Caffe2 - C++ API
A deep learning, cross-platform ML framework
mnist.h
1 #pragma once
2 
3 #include <torch/data/datasets/base.h>
4 #include <torch/data/example.h>
5 #include <torch/types.h>
6 
7 #include <torch/csrc/WindowsTorchApiMacro.h>
8 
9 #include <cstddef>
10 #include <string>
11 
12 namespace torch {
13 namespace data {
14 namespace datasets {
16 class TORCH_API MNIST : public Dataset<MNIST> {
17  public:
19  enum class Mode { kTrain, kTest };
20 
25  explicit MNIST(const std::string& root, Mode mode = Mode::kTrain);
26 
28  Example<> get(size_t index) override;
29 
31  optional<size_t> size() const override;
32 
34  bool is_train() const noexcept;
35 
37  const Tensor& images() const;
38 
40  const Tensor& targets() const;
41 
42  private:
43  Tensor images_, targets_;
44 };
45 } // namespace datasets
46 } // namespace data
47 } // namespace torch
An Example from a dataset.
Definition: example.h:12
The MNIST dataset.
Definition: mnist.h:16
A dataset that can yield data in batches, or as individual examples.
Definition: base.h:76
Definition: jit_type.h:17
Mode
The mode in which the dataset is loaded.
Definition: mnist.h:19