Caffe2 - C++ API
A deep learning, cross-platform ML framework
shared.h
1 #pragma once
2 
3 #include <torch/data/datasets/base.h>
4 
5 #include <memory>
6 #include <utility>
7 
8 namespace torch {
9 namespace data {
10 namespace datasets {
11 
20 template <typename UnderlyingDataset>
22  SharedBatchDataset<UnderlyingDataset>,
23  typename UnderlyingDataset::BatchType,
24  typename UnderlyingDataset::BatchRequestType> {
25  public:
26  using BatchType = typename UnderlyingDataset::BatchType;
27  using BatchRequestType = typename UnderlyingDataset::BatchRequestType;
28 
31  /* implicit */ SharedBatchDataset(
32  std::shared_ptr<UnderlyingDataset> shared_dataset)
33  : dataset_(std::move(shared_dataset)) {}
34 
36  BatchType get_batch(BatchRequestType request) override {
37  return dataset_->get_batch(std::move(request));
38  }
39 
41  optional<size_t> size() const override {
42  return dataset_->size();
43  }
44 
46  UnderlyingDataset& operator*() {
47  return *dataset_;
48  }
49 
51  const UnderlyingDataset& operator*() const {
52  return *dataset_;
53  }
54 
56  UnderlyingDataset* operator->() {
57  return dataset_.get();
58  }
59 
61  const UnderlyingDataset* operator->() const {
62  return dataset_.get();
63  }
64 
66  void reset() {
67  dataset_->reset();
68  }
69 
70  private:
71  std::shared_ptr<UnderlyingDataset> dataset_;
72 };
73 
77 template <typename UnderlyingDataset, typename... Args>
78 SharedBatchDataset<UnderlyingDataset> make_shared_dataset(Args&&... args) {
79  return std::make_shared<UnderlyingDataset>(std::forward<Args>(args)...);
80 }
81 } // namespace datasets
82 } // namespace data
83 } // namespace torch
UnderlyingDataset * operator->()
Accesses the underlying dataset.
Definition: shared.h:56
optional< size_t > size() const override
Returns the size from the underlying dataset.
Definition: shared.h:41
BatchType get_batch(BatchRequestType request) override
Calls get_batch on the underlying dataset.
Definition: shared.h:36
const UnderlyingDataset & operator*() const
Accesses the underlying dataset.
Definition: shared.h:51
UnderlyingDataset & operator*()
Accesses the underlying dataset.
Definition: shared.h:46
A dataset that can yield data only in batches.
Definition: base.h:40
SharedBatchDataset(std::shared_ptr< UnderlyingDataset > shared_dataset)
Constructs a new SharedBatchDataset from a shared_ptr to the UnderlyingDataset.
Definition: shared.h:31
A dataset that wraps another dataset in a shared pointer and implements the BatchDataset API by delegating every call to the shared instance.
Definition: shared.h:21
Definition: jit_type.h:17
const UnderlyingDataset * operator->() const
Accesses the underlying dataset.
Definition: shared.h:61
void reset()
Calls reset() on the underlying dataset.
Definition: shared.h:66