Caffe2 - C++ API
A deep learning, cross platform ML framework
mnist.cpp
1 #include <torch/data/datasets/mnist.h>
2 
3 #include <torch/data/example.h>
4 #include <torch/types.h>
5 
6 #include <c10/util/Exception.h>
7 
8 #include <cstddef>
9 #include <fstream>
10 #include <string>
11 #include <vector>
12 
13 namespace torch {
14 namespace data {
15 namespace datasets {
16 namespace {
// Number of examples in each MNIST split.
constexpr uint32_t kTrainSize = 60000;
constexpr uint32_t kTestSize = 10000;

// IDX header magic numbers, written in hex so the byte layout is visible:
// the low byte matches the dimensionality in the file names below
// ("idx3" for images, "idx1" for labels). 0x803 == 2051, 0x801 == 2049.
constexpr uint32_t kImageMagicNumber = 0x00000803;
constexpr uint32_t kTargetMagicNumber = 0x00000801;

// Spatial size of every image.
constexpr uint32_t kImageRows = 28;
constexpr uint32_t kImageColumns = 28;

// File names (relative to the dataset root) of the training split ...
constexpr const char* kTrainImagesFilename = "train-images-idx3-ubyte";
constexpr const char* kTrainTargetsFilename = "train-labels-idx1-ubyte";
// ... and of the test split.
constexpr const char* kTestImagesFilename = "t10k-images-idx3-ubyte";
constexpr const char* kTestTargetsFilename = "t10k-labels-idx1-ubyte";
27 
/// Returns true when the host stores multi-byte integers least-significant
/// byte first, i.e. when the host is little-endian.
bool check_is_little_endian() {
  const uint32_t probe = 0x00000001u;
  // Inspect the byte at the lowest address of the probe word.
  const auto* first_byte = reinterpret_cast<const uint8_t*>(&probe);
  return *first_byte == 0x01u;
}
32 
/// Reverses the byte order of a 32-bit word (0x11223344 -> 0x44332211).
/// Used to turn big-endian file data into host order on little-endian hosts.
constexpr uint32_t flip_endianness(uint32_t value) {
  return (value >> 24u) | ((value >> 8u) & 0x0000ff00u) |
      ((value << 8u) & 0x00ff0000u) | (value << 24u);
}
37 
38 uint32_t read_int32(std::ifstream& stream) {
39  static const bool is_little_endian = check_is_little_endian();
40  uint32_t value;
41  AT_ASSERT(stream.read(reinterpret_cast<char*>(&value), sizeof value));
42  return is_little_endian ? flip_endianness(value) : value;
43 }
44 
45 uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
46  const auto value = read_int32(stream);
47  // clang-format off
48  AT_CHECK(value == expected,
49  "Expected to read number ", expected, " but found ", value, " instead");
50  // clang-format on
51  return value;
52 }
53 
/// Joins `head` and `tail` with a single '/' separator.
///
/// A trailing '/' on `head` is not duplicated. An empty `head` yields `tail`
/// unchanged (a relative path) rather than calling back() on an empty string,
/// which is undefined behavior.
std::string join_paths(std::string head, const std::string& tail) {
  if (!head.empty() && head.back() != '/') {
    head.push_back('/');
  }
  head += tail;
  return head;
}
61 
62 Tensor read_images(const std::string& root, bool train) {
63  const auto path =
64  join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);
65  std::ifstream images(path, std::ios::binary);
66  AT_CHECK(images, "Error opening images file at ", path);
67 
68  const auto count = train ? kTrainSize : kTestSize;
69 
70  // From http://yann.lecun.com/exdb/mnist/
71  expect_int32(images, kImageMagicNumber);
72  expect_int32(images, count);
73  expect_int32(images, kImageRows);
74  expect_int32(images, kImageColumns);
75 
76  auto tensor =
77  torch::empty({count, 1, kImageRows, kImageColumns}, torch::kByte);
78  images.read(reinterpret_cast<char*>(tensor.data_ptr()), tensor.numel());
79  return tensor.to(torch::kFloat32).div_(255);
80 }
81 
82 Tensor read_targets(const std::string& root, bool train) {
83  const auto path =
84  join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);
85  std::ifstream targets(path, std::ios::binary);
86  AT_CHECK(targets, "Error opening targets file at ", path);
87 
88  const auto count = train ? kTrainSize : kTestSize;
89 
90  expect_int32(targets, kTargetMagicNumber);
91  expect_int32(targets, count);
92 
93  auto tensor = torch::empty(count, torch::kByte);
94  targets.read(reinterpret_cast<char*>(tensor.data_ptr()), count);
95  return tensor.to(torch::kInt64);
96 }
97 } // namespace
98 
99 MNIST::MNIST(const std::string& root, Mode mode)
100  : images_(read_images(root, mode == Mode::kTrain)),
101  targets_(read_targets(root, mode == Mode::kTrain)) {}
102 
103 Example<> MNIST::get(size_t index) {
104  return {images_[index], targets_[index]};
105 }
106 
108  return images_.size(0);
109 }
110 
111 bool MNIST::is_train() const noexcept {
112  return images_.size(0) == kTrainSize;
113 }
114 
115 const Tensor& MNIST::images() const {
116  return images_;
117 }
118 
119 const Tensor& MNIST::targets() const {
120  return targets_;
121 }
122 
123 } // namespace datasets
124 } // namespace data
125 } // namespace torch
optional< size_t > size() const override
Returns the size of the dataset.
Definition: mnist.cpp:107
An Example from a dataset.
Definition: example.h:12
const Tensor & targets() const
Returns all targets stacked into a single tensor.
Definition: mnist.cpp:119
Example get(size_t index) override
Returns the Example at the given index.
Definition: mnist.cpp:103
bool is_train() const noexcept
Returns true if this is the training subset of MNIST.
Definition: mnist.cpp:111
Definition: jit_type.h:17
const Tensor & images() const
Returns all images stacked into a single tensor.
Definition: mnist.cpp:115
Mode
The mode in which the dataset is loaded.
Definition: mnist.h:19
MNIST(const std::string &root, Mode mode=Mode::kTrain)
Loads the MNIST dataset from the root path.
Definition: mnist.cpp:99