1 #include <torch/data/datasets/mnist.h> 3 #include <torch/data/example.h> 4 #include <torch/types.h> 6 #include <c10/util/Exception.h> 17 constexpr uint32_t kTrainSize = 60000;
// Example count expected in the header of the test-split files.
constexpr uint32_t kTestSize{10000};

// First 32-bit header value expected in an images file / a targets file.
constexpr uint32_t kImageMagicNumber{2051};
constexpr uint32_t kTargetMagicNumber{2049};

// Per-image dimensions expected in the images header; also used as the last
// two dimensions of the image tensor.
constexpr uint32_t kImageRows{28};
constexpr uint32_t kImageColumns{28};

// File names that are joined onto the dataset root directory.
constexpr const char* kTrainImagesFilename{"train-images-idx3-ubyte"};
constexpr const char* kTrainTargetsFilename{"train-labels-idx1-ubyte"};
constexpr const char* kTestImagesFilename{"t10k-images-idx3-ubyte"};
constexpr const char* kTestTargetsFilename{"t10k-labels-idx1-ubyte"};
/// Returns true when the lowest-addressed byte of a 32-bit word holding the
/// value 1 is itself 1, i.e. when the host stores the least significant byte
/// first.
bool check_is_little_endian() {
  const uint32_t probe = 1;
  const auto* first_byte = reinterpret_cast<const uint8_t*>(&probe);
  return *first_byte == 1;
}
/// Reverses the order of the four bytes in `value`.
///
/// Written as a single return statement so it remains usable in constant
/// expressions.
constexpr uint32_t flip_endianness(uint32_t value) {
  return (value << 24u) | ((value << 8u) & 0x00ff0000u) |
      ((value >> 8u) & 0x0000ff00u) | (value >> 24u);
}
38 uint32_t read_int32(std::ifstream& stream) {
39 static const bool is_little_endian = check_is_little_endian();
41 AT_ASSERT(stream.read(reinterpret_cast<char*>(&value),
sizeof value));
42 return is_little_endian ? flip_endianness(value) : value;
45 uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
46 const auto value = read_int32(stream);
48 AT_CHECK(value == expected,
49 "Expected to read number ", expected,
" but found ", value,
" instead");
/// Joins `head` and `tail` with exactly one '/' between them.
///
/// A separator is appended to `head` only when it does not already end in one.
/// An empty `head` yields `tail` unchanged: calling `back()` on an empty string
/// is undefined behavior, and prepending '/' would turn a relative file name
/// into an absolute path.
///
/// \param head Leading path component (taken by value and extended in place).
/// \param tail Trailing path component appended after the separator.
/// \returns The combined path.
std::string join_paths(std::string head, const std::string& tail) {
  if (head.empty()) {
    return tail;
  }
  if (head.back() != '/') {
    head.push_back('/');
  }
  head += tail;
  return head;
}
62 Tensor read_images(
const std::string& root,
bool train) {
64 join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);
65 std::ifstream images(path, std::ios::binary);
66 AT_CHECK(images,
"Error opening images file at ", path);
68 const auto count = train ? kTrainSize : kTestSize;
71 expect_int32(images, kImageMagicNumber);
72 expect_int32(images, count);
73 expect_int32(images, kImageRows);
74 expect_int32(images, kImageColumns);
77 torch::empty({count, 1, kImageRows, kImageColumns}, torch::kByte);
78 images.read(reinterpret_cast<char*>(tensor.data_ptr()), tensor.numel());
79 return tensor.to(torch::kFloat32).div_(255);
82 Tensor read_targets(
const std::string& root,
bool train) {
84 join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);
85 std::ifstream targets(path, std::ios::binary);
86 AT_CHECK(targets,
"Error opening targets file at ", path);
88 const auto count = train ? kTrainSize : kTestSize;
90 expect_int32(targets, kTargetMagicNumber);
91 expect_int32(targets, count);
93 auto tensor = torch::empty(count, torch::kByte);
94 targets.read(reinterpret_cast<char*>(tensor.data_ptr()), count);
95 return tensor.to(torch::kInt64);
100 : images_(read_images(root, mode ==
Mode::kTrain)),
101 targets_(read_targets(root, mode ==
Mode::kTrain)) {}
104 return {images_[index], targets_[index]};
108 return images_.size(0);
112 return images_.size(0) == kTrainSize;
optional< size_t > size() const override
Returns the size of the dataset.
An Example from a dataset.
const Tensor & targets() const
Returns all targets stacked into a single tensor.
Example get(size_t index) override
Returns the Example at the given index.
bool is_train() const noexcept
Returns true if this is the training subset of MNIST.
const Tensor & images() const
Returns all images stacked into a single tensor.
Mode
The mode in which the dataset is loaded.
MNIST(const std::string &root, Mode mode=Mode::kTrain)
Loads the MNIST dataset from the root path.