// random.cpp — RandomSampler implementation for the PyTorch (Caffe2) C++ API,
// a deep-learning, cross-platform ML framework.
#include <torch/data/samplers/random.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>
8 
9 namespace torch {
10 namespace data {
11 namespace samplers {
12 RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
13  : indices_(torch::randperm(size, index_dtype)) {}
14 
16  // This allocates a new chunk of memory every time (just FYI). It should be
17  // amortized over the entire epoch hopefully.
18  const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
19  indices_ = torch::randperm(size, indices_.options());
20  index_ = 0;
21 }
22 
24  AT_ASSERT(index_ <= indices_.numel());
25  const size_t remaining_indices = indices_.numel() - index_;
26  if (remaining_indices == 0) {
27  return nullopt;
28  }
29  std::vector<size_t> index_batch(std::min(batch_size, remaining_indices));
30  auto slice = indices_.slice(/*dim=*/0, index_, index_ + index_batch.size());
31  // You may want to store your indices with 32-bit or less, but here we need
32  // to upcast to 64-bit. A batch itself won't hold too many indices, so that
33  // should be ok. Note that if this indeed results in a type promotion, there
34  // will be two allocations: one for the upcast slice, and one for the
35  // returned `index_batch` vector.
36  slice = slice.to(torch::kInt64);
37  const auto* data = slice.data<int64_t>();
38  std::copy(data, data + index_batch.size(), index_batch.begin());
39  index_ += index_batch.size();
40  return index_batch;
41 }
42 
44  archive.write(
45  "index",
46  torch::tensor(static_cast<int64_t>(index_), torch::kInt64),
47  /*is_buffer=*/true);
48  archive.write(
49  "indices",
50  indices_,
51  /*is_buffer=*/true);
52 }
53 
55  auto tensor = torch::empty(1, torch::kInt64);
56  archive.read(
57  "index",
58  tensor,
59  /*is_buffer=*/true);
60  index_ = tensor.item<int64_t>();
61  archive.read(
62  "indices",
63  indices_,
64  /*is_buffer=*/true);
65 }
66 
67 size_t RandomSampler::index() const noexcept {
68  return index_;
69 }
70 
71 } // namespace samplers
72 } // namespace data
73 } // namespace torch
// --- Doxygen cross-reference residue (retained for reference) ---
// TORCH_API void reset(optional<size_t> new_size = nullopt) override
//   Resets the RandomSampler to a new set of indices. (random.cpp:15)
// TensorOptions options() const
//   Returns the TensorOptions corresponding to this Tensor. (TensorMethods.h:42)
// TORCH_API RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64)
//   Constructs a RandomSampler with a size and dtype for the stored indices.
//   (random.cpp:12)
// void read(const std::string& key, Tensor& tensor, bool is_buffer = false)
//   Reads a tensor associated with a given key. (jit_type.h:17)
// TORCH_API size_t index() const noexcept
//   Returns the current index of the RandomSampler. (random.cpp:67)
// TORCH_API void load(serialize::InputArchive& archive) override
//   Deserializes the RandomSampler from the archive. (random.cpp:54)
// void write(const std::string& key, const Tensor& tensor, bool is_buffer = false)
//   Writes a (key, tensor) pair to the OutputArchive, and marks it as being or
//   not being a buffer (non-differentiable state).
// serialize::InputArchive
//   A recursive representation of tensors that can be deserialized from a file
//   or stream. (input-archive.h:32)
// TORCH_API void save(serialize::OutputArchive& archive) const override
//   Serializes the RandomSampler to the archive. (random.cpp:43)
// TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override
//   Returns the next batch of indices. (random.cpp:23)