Caffe2 - C++ API
A deep learning, cross platform ML framework
stream.cpp
1 #include <torch/data/samplers/stream.h>
2 #include <torch/serialize/archive.h>
3 #include <torch/types.h>
4 
5 #include <c10/util/Exception.h>
6 
7 #include <cstddef>
8 
9 namespace torch {
10 namespace data {
11 namespace samplers {
12 
13 BatchSize::BatchSize(size_t size) : size_(size) {}
14 size_t BatchSize::size() const noexcept {
15  return size_;
16 }
17 BatchSize::operator size_t() const noexcept {
18  return size_;
19 }
20 
21 StreamSampler::StreamSampler(size_t epoch_size) : epoch_size_(epoch_size) {}
22 
24  if (new_size.has_value()) {
25  epoch_size_ = *new_size;
26  }
27  examples_retrieved_so_far_ = 0;
28 }
29 
31  AT_ASSERT(examples_retrieved_so_far_ <= epoch_size_);
32  if (examples_retrieved_so_far_ == epoch_size_) {
33  return nullopt;
34  }
35  if (examples_retrieved_so_far_ + batch_size > epoch_size_) {
36  batch_size = epoch_size_ - examples_retrieved_so_far_;
37  }
38  examples_retrieved_so_far_ += batch_size;
39  return BatchSize(batch_size);
40 }
41 
43  archive.write(
44  "examples_retrieved_so_far",
45  torch::tensor(
46  static_cast<int64_t>(examples_retrieved_so_far_), torch::kInt64),
47  /*is_buffer=*/true);
48 }
49 
51  auto tensor = torch::empty(1, torch::kInt64);
52  archive.read(
53  "examples_retrieved_so_far",
54  tensor,
55  /*is_buffer=*/true);
56  examples_retrieved_so_far_ = tensor.item<int64_t>();
57 }
58 
59 } // namespace samplers
60 } // namespace data
61 } // namespace torch
TORCH_API optional< BatchSize > next(size_t batch_size) override
Returns a BatchSize object with the number of elements to fetch in the next batch.
Definition: stream.cpp:30
void read(const std::string &key, Tensor &tensor, bool is_buffer=false)
Reads a tensor associated with a given key.
TORCH_API StreamSampler(size_t epoch_size)
Constructs the StreamSampler with the number of individual examples that should be fetched until the sampler is exhausted.
Definition: stream.cpp:21
TORCH_API void load(serialize::InputArchive &archive) override
Deserializes the StreamSampler from the archive.
Definition: stream.cpp:50
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the internal state of the sampler.
Definition: stream.cpp:23
Definition: jit_type.h:17
size_t size() const noexcept override
The number of elements accessed by this index.
Definition: stream.cpp:14
A wrapper around a batch size value, which implements the CustomBatchRequest interface.
Definition: stream.h:23
void write(const std::string &key, const Tensor &tensor, bool is_buffer=false)
Writes a (key, tensor) pair to the OutputArchive, and marks it as being or not being a buffer (non-differentiable tensor).
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32
TORCH_API void save(serialize::OutputArchive &archive) const override
Serializes the StreamSampler to the archive.
Definition: stream.cpp:42