3 #include <torch/data/datasets/base.h> 4 #include <torch/types.h> 6 #include <c10/util/ArrayRef.h> 16 template <
bool C,
typename T>
17 using optional_if_t =
typename std::conditional<C, torch::optional<T>,
T>::type;
/// A `MapDataset` is a dataset that applies a transform to a source dataset.
///
/// @tparam SourceDataset    the wrapped dataset; the code below reads its
///                          `is_stateful` flag and `BatchRequestType` alias.
/// @tparam AppliedTransform the transform; the code below reads its
///                          `OutputBatchType` alias and calls `apply_batch()`.
///
/// NOTE(review): this text is a Doxygen source-listing extract. The leading
/// number on many lines is the original file's line number, not code, and
/// gaps in that numbering (e.g. 27 -> 29, 42 -> 47, 47 -> 72, 87 -> 93) mark
/// lines dropped from this copy. Restore from the upstream header before
/// compiling.
21 template <
typename SourceDataset,
typename AppliedTransform>
22 class MapDataset :
// Base class: a BatchDataset whose first argument is this class itself,
// whose batch type is the transform's OutputBatchType (wrapped in
// torch::optional when the source dataset is stateful, via optional_if_t),
// and whose request type is the source dataset's BatchRequestType.
public BatchDataset<
23 MapDataset<SourceDataset, AppliedTransform>,
24 detail::optional_if_t<
25 SourceDataset::is_stateful,
26 typename AppliedTransform::OutputBatchType>,
27 typename SourceDataset::BatchRequestType> {
// Aliases re-exporting the template arguments and the request/output types.
29 using DatasetType = SourceDataset;
30 using TransformType = AppliedTransform;
31 using BatchRequestType =
typename SourceDataset::BatchRequestType;
// torch::optional<AppliedTransform::OutputBatchType> for stateful sources,
// the bare AppliedTransform::OutputBatchType otherwise.
32 using OutputBatchType = detail::optional_if_t<
33 SourceDataset::is_stateful,
34 typename AppliedTransform::OutputBatchType>;
/// Takes the source dataset and the transform by value and moves each into
/// its member (`dataset_`, `transform_`).
36 MapDataset(DatasetType dataset, TransformType transform)
37 : dataset_(
std::move(dataset)), transform_(
std::move(transform)) {}
/// Gets a batch from the source dataset and applies the transform to it,
/// returning the result. Overrides the BatchDataset hook and forwards the
/// request to whichever get_batch_impl() overload is enabled for
/// SourceDataset.
41 OutputBatchType
get_batch(BatchRequestType indices)
override {
42 return get_batch_impl(std::move(indices));
// NOTE(review): original lines 43-46 are elided (closing brace of
// get_batch() and the signature of the next member). Line 47 below appears to
// be the body of `optional<size_t> size() const noexcept override`, which
// returns the size of the source dataset.
47 return dataset_.size();
// NOTE(review): original lines 48-71 are elided. They presumably hold the
// reset(), dataset() and transform() members and the `template <` that opens
// the declaration below -- confirm against upstream.
/// Stateless implementation of get_batch(): removed from overload resolution
/// when `D::is_stateful` is true (torch::disable_if_t). Fetches a batch from
/// the source dataset and returns the transform applied to it.
72 typename D = SourceDataset,
73 typename = torch::disable_if_t<D::is_stateful>>
74 OutputBatchType get_batch_impl(BatchRequestType indices) {
75 return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
// NOTE(review): original lines 76-82 are elided (end of the stateless
// overload and the doc comment of the next one).
/// Stateful implementation of get_batch(): enabled only when
/// `D::is_stateful` is true (torch::enable_if_t). The source's result is
/// tested for truthiness; when it holds a batch, the transform is applied to
/// the contained value.
83 template <
typename D = SourceDataset>
84 torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
85 BatchRequestType indices) {
86 if (
auto batch = dataset_.get_batch(std::move(indices))) {
87 return transform_.apply_batch(std::move(*batch));
// NOTE(review): original lines 88-92 are elided. They presumably return an
// empty optional when the source yielded no batch and close the function --
// confirm against upstream.
/// The source dataset whose batches are transformed.
93 SourceDataset dataset_;
/// The transform applied to each batch obtained from `dataset_`.
96 AppliedTransform transform_;
// NOTE(review): the class's closing `};` (original lines 97-99) is elided.
/// Factory that pairs a dataset with a transform and returns the combined
/// object, after statically checking that the two types are compatible.
///
/// NOTE(review): original lines 101-102 (presumably the return type, the
/// function name and the `DatasetType dataset,` parameter) and 104-105
/// (presumably the opening of a `static_assert(std::is_same<` check) are
/// elided from this listing -- restore from upstream before compiling.
100 template <
typename DatasetType,
typename TransformType>
103 TransformType transform) {
// Compile-time compatibility check. The batch type compared is
// DatasetType::BatchType::value_type for a stateful dataset (the batch held
// inside its optional-like wrapper) and DatasetType::BatchType otherwise; it
// is compared against TransformType::InputBatchType by a trait on an elided
// line. The message below is emitted when they do not match.
106 typename std::conditional<
107 DatasetType::is_stateful,
108 typename DatasetType::BatchType::value_type,
109 typename DatasetType::BatchType>::type,
110 typename TransformType::InputBatchType>::value,
111 "BatchType type of dataset does not match input type of transform");
// Brace-initializes the return value from the moved dataset and transform,
// presumably via the (dataset, transform) constructor of the returned type.
112 return {std::move(dataset), std::move(transform)};
// NOTE(review): the function's closing brace (original line 113) is elided.
const AppliedTransform & transform() noexcept
Returns the transform being applied.
OutputBatchType get_batch(BatchRequestType indices) override
Gets a batch from the source dataset and applies the transform to it, returning the result.
optional<size_t> size() const noexcept override
Returns the size of the source dataset.
void reset()
Calls reset() on the underlying dataset.
const SourceDataset & dataset() noexcept
Returns the underlying dataset.
A MapDataset is a dataset that applies a transform to a source dataset.