Caffe2 - C++ API
A deep-learning, cross-platform ML framework
map.h
1 #pragma once
2 
3 #include <torch/data/datasets/base.h>
4 #include <torch/types.h>
5 
6 #include <c10/util/ArrayRef.h>
7 
8 #include <cstddef>
9 #include <type_traits>
10 #include <utility>
11 
12 namespace torch {
13 namespace data {
14 namespace datasets {
15 namespace detail {
16 template <bool C, typename T>
17 using optional_if_t = typename std::conditional<C, torch::optional<T>, T>::type;
18 } // namespace detail
19 
21 template <typename SourceDataset, typename AppliedTransform>
22 class MapDataset : public BatchDataset<
23  MapDataset<SourceDataset, AppliedTransform>,
24  detail::optional_if_t<
25  SourceDataset::is_stateful,
26  typename AppliedTransform::OutputBatchType>,
27  typename SourceDataset::BatchRequestType> {
28  public:
29  using DatasetType = SourceDataset;
30  using TransformType = AppliedTransform;
31  using BatchRequestType = typename SourceDataset::BatchRequestType;
32  using OutputBatchType = detail::optional_if_t<
33  SourceDataset::is_stateful,
34  typename AppliedTransform::OutputBatchType>;
35 
36  MapDataset(DatasetType dataset, TransformType transform)
37  : dataset_(std::move(dataset)), transform_(std::move(transform)) {}
38 
41  OutputBatchType get_batch(BatchRequestType indices) override {
42  return get_batch_impl(std::move(indices));
43  }
44 
46  optional<size_t> size() const noexcept override {
47  return dataset_.size();
48  }
49 
54  void reset() {
55  dataset_.reset();
56  }
57 
59  const SourceDataset& dataset() noexcept {
60  return dataset_;
61  }
62 
64  const AppliedTransform& transform() noexcept {
65  return transform_;
66  }
67 
68  private:
71  template <
72  typename D = SourceDataset,
73  typename = torch::disable_if_t<D::is_stateful>>
74  OutputBatchType get_batch_impl(BatchRequestType indices) {
75  return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
76  }
77 
83  template <typename D = SourceDataset>
84  torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
85  BatchRequestType indices) {
86  if (auto batch = dataset_.get_batch(std::move(indices))) {
87  return transform_.apply_batch(std::move(*batch));
88  }
89  return nullopt;
90  }
91 
93  SourceDataset dataset_;
94 
95  // The transformation that is applied to batches received from the dataset.
96  AppliedTransform transform_;
97 };
98 
100 template <typename DatasetType, typename TransformType>
102  DatasetType dataset,
103  TransformType transform) {
104  static_assert(
105  std::is_same<
106  typename std::conditional<
107  DatasetType::is_stateful,
108  typename DatasetType::BatchType::value_type,
109  typename DatasetType::BatchType>::type,
110  typename TransformType::InputBatchType>::value,
111  "BatchType type of dataset does not match input type of transform");
112  return {std::move(dataset), std::move(transform)};
113 }
114 
115 } // namespace datasets
116 } // namespace data
117 } // namespace torch
const AppliedTransform & transform() noexcept
Returns the transform being applied.
Definition: map.h:64
OutputBatchType get_batch(BatchRequestType indices) override
Gets a batch from the source dataset and applies the transform to it, returning the result...
Definition: map.h:41
optional< size_t > size() const noexcept override
Returns the size of the source dataset.
Definition: map.h:46
Definition: jit_type.h:17
void reset()
Calls reset() on the underlying dataset.
Definition: map.h:54
Definition: static.cpp:70
const SourceDataset & dataset() noexcept
Returns the underlying dataset.
Definition: map.h:59
A MapDataset is a dataset that applies a transform to a source dataset.
Definition: base.h:18