3 #include <torch/data/datasets/base.h> 20 template <
typename UnderlyingDataset>
22 SharedBatchDataset<UnderlyingDataset>,
23 typename UnderlyingDataset::BatchType,
24 typename UnderlyingDataset::BatchRequestType> {
26 using BatchType =
typename UnderlyingDataset::BatchType;
27 using BatchRequestType =
typename UnderlyingDataset::BatchRequestType;
32 std::shared_ptr<UnderlyingDataset> shared_dataset)
33 : dataset_(
std::move(shared_dataset)) {}
36 BatchType
get_batch(BatchRequestType request)
override {
37 return dataset_->get_batch(std::move(request));
42 return dataset_->size();
57 return dataset_.get();
62 return dataset_.get();
71 std::shared_ptr<UnderlyingDataset> dataset_;
77 template <
typename UnderlyingDataset,
typename... Args>
79 return std::make_shared<UnderlyingDataset>(std::forward<Args>(args)...);
UnderlyingDataset * operator->()
Accesses the underlying dataset.
optional< size_t > size() const override
Returns the size from the underlying dataset.
BatchType get_batch(BatchRequestType request) override
Calls get_batch on the underlying dataset.
const UnderlyingDataset & operator*() const
Accesses the underlying dataset.
UnderlyingDataset & operator*()
Accesses the underlying dataset.
A dataset that can yield data only in batches.
SharedBatchDataset(std::shared_ptr< UnderlyingDataset > shared_dataset)
Constructs a new SharedBatchDataset from a shared_ptr to the UnderlyingDataset.
A dataset that wraps another dataset in a shared pointer and implements the BatchDataset API, delegating all calls (get_batch, size, reset) to the shared instance.
const UnderlyingDataset * operator->() const
Accesses the underlying dataset.
void reset()
Calls reset() on the underlying dataset.