Caffe2 - C++ API
A deep learning, cross platform ML framework
sequential.cpp
1 #include <torch/data/samplers/sequential.h>
2 #include <torch/serialize/archive.h>
3 #include <torch/types.h>
4 
5 #include <algorithm>
6 #include <cstddef>
7 #include <vector>
8 
9 namespace torch {
10 namespace data {
11 namespace samplers {
12 SequentialSampler::SequentialSampler(size_t size) : size_(size) {}
13 
15  if (new_size.has_value()) {
16  size_ = *new_size;
17  }
18  index_ = 0;
19 }
20 
22  const auto remaining_indices = size_ - index_;
23  if (remaining_indices == 0) {
24  return nullopt;
25  }
26  std::vector<size_t> index_batch(std::min(batch_size, remaining_indices));
27  for (auto& i : index_batch) {
28  i = index_++;
29  }
30  return index_batch;
31 }
32 
34  archive.write(
35  "index",
36  torch::tensor(static_cast<int64_t>(index_), torch::kInt64),
37  /*is_buffer=*/true);
38 }
39 
41  auto tensor = torch::empty(1, torch::kInt64);
42  archive.read(
43  "index",
44  tensor,
45  /*is_buffer=*/true);
46  index_ = tensor.item<int64_t>();
47 }
48 
49 size_t SequentialSampler::index() const noexcept {
50  return index_;
51 }
52 
53 } // namespace samplers
54 } // namespace data
55 } // namespace torch
TORCH_API SequentialSampler(size_t size)
Creates a SequentialSampler that will return indices in the range 0...size - 1.
Definition: sequential.cpp:12
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices.
Definition: sequential.cpp:21
void read(const std::string &key, Tensor &tensor, bool is_buffer=false)
Reads a tensor associated with a given key.
Definition: jit_type.h:17
TORCH_API void load(serialize::InputArchive &archive) override
Deserializes the SequentialSampler from the archive.
Definition: sequential.cpp:40
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the SequentialSampler to zero.
Definition: sequential.cpp:14
TORCH_API void save(serialize::OutputArchive &archive) const override
Serializes the SequentialSampler to the archive.
Definition: sequential.cpp:33
void write(const std::string &key, const Tensor &tensor, bool is_buffer=false)
Writes a (key, tensor) pair to the OutputArchive, and marks it as being or not being a buffer (non-differentiable tensor).
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32
TORCH_API size_t index() const noexcept
Returns the current index of the SequentialSampler.
Definition: sequential.cpp:49