Caffe2 - C++ API
A deep learning, cross platform ML framework
base.h
1 #pragma once
2 
3 #include <torch/data/example.h>
4 #include <torch/types.h>
5 
6 #include <c10/util/ArrayRef.h>
7 
8 #include <cstddef>
9 #include <cstdint>
10 #include <type_traits>
11 #include <utility>
12 #include <vector>
13 
// Forward declarations only: BatchDataset::map() below names MapDataset in
// its return type and forwards to datasets::map(), so both must be declared
// before BatchDataset is defined.
namespace torch {
namespace data {
namespace datasets {

template <typename SourceDataset, typename AppliedTransform>
class MapDataset;

template <typename DatasetType, typename TransformType>
MapDataset<DatasetType, TransformType> map(DatasetType, TransformType); // NOLINT

} // namespace datasets
} // namespace data
} // namespace torch
24 
25 namespace torch {
26 namespace data {
27 namespace datasets {
namespace detail {
/// Type trait: `is_optional<T>::value` is true exactly when `T` is a
/// specialization of `optional`. The primary template handles every other
/// type and yields false.
template <typename T>
struct is_optional : std::integral_constant<bool, false> {};

/// Partial specialization selected for `optional<Wrapped>`, for any `Wrapped`.
template <typename Wrapped>
struct is_optional<optional<Wrapped>> : std::integral_constant<bool, true> {};
} // namespace detail
34 
36 template <
37  typename Self,
38  typename Batch = std::vector<Example<>>,
39  typename BatchRequest = ArrayRef<size_t>>
40 class BatchDataset {
41  public:
42  using SelfType = Self;
43  using BatchType = Batch;
44  using BatchRequestType = BatchRequest;
45  constexpr static bool is_stateful = detail::is_optional<BatchType>::value;
46 
47  virtual ~BatchDataset() = default;
48 
50  virtual Batch get_batch(BatchRequest request) = 0;
51 
53  virtual optional<size_t> size() const = 0;
54 
56  template <typename TransformType>
57  MapDataset<Self, TransformType> map(TransformType transform) & {
58  return datasets::map(static_cast<Self&>(*this), std::move(transform));
59  }
60 
62  template <typename TransformType>
63  MapDataset<Self, TransformType> map(TransformType transform) && {
64  return datasets::map(
65  std::move(static_cast<Self&>(*this)), std::move(transform));
66  }
67 };
68 
75 template <typename Self, typename SingleExample = Example<>>
76 class Dataset : public BatchDataset<Self, std::vector<SingleExample>> {
77  public:
78  using ExampleType = SingleExample;
79 
81  virtual ExampleType get(size_t index) = 0;
82 
86  std::vector<ExampleType> get_batch(ArrayRef<size_t> indices) override {
87  std::vector<ExampleType> batch;
88  batch.reserve(indices.size());
89  for (const auto i : indices) {
90  batch.push_back(get(i));
91  }
92  return batch;
93  }
94 };
95 
99 template <typename Self, typename Batch = std::vector<Example<>>>
100 using StreamDataset = BatchDataset<Self, Batch, /*BatchRequest=*/size_t>;
101 } // namespace datasets
102 } // namespace data
103 } // namespace torch
MapDataset< Self, TransformType > map(TransformType transform)&
Creates a MapDataset that applies the given transform to this dataset.
Definition: base.h:57
std::vector< ExampleType > get_batch(ArrayRef< size_t > indices) override
Returns a batch of data.
Definition: base.h:86
A dataset that can yield data only in batches.
Definition: base.h:40
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
A dataset that can yield data in batches, or as individual examples.
Definition: base.h:76
MapDataset< Self, TransformType > map(TransformType transform)&&
Creates a MapDataset that applies the given transform to this dataset.
Definition: base.h:63
Definition: jit_type.h:17
ArrayRef - Represents a constant reference to an array (zero or more elements stored consecutively in memory)...
Definition: ArrayRef.h:41
Definition: static.cpp:70
A MapDataset is a dataset that applies a transform to a source dataset.
Definition: base.h:18