1 #include <torch/data/samplers/stream.h> 2 #include <torch/serialize/archive.h> 3 #include <torch/types.h> 5 #include <c10/util/Exception.h> 13 BatchSize::BatchSize(
size_t size) : size_(size) {}
17 BatchSize::operator size_t()
const noexcept {
24 if (new_size.has_value()) {
25 epoch_size_ = *new_size;
27 examples_retrieved_so_far_ = 0;
31 AT_ASSERT(examples_retrieved_so_far_ <= epoch_size_);
32 if (examples_retrieved_so_far_ == epoch_size_) {
35 if (examples_retrieved_so_far_ + batch_size > epoch_size_) {
36 batch_size = epoch_size_ - examples_retrieved_so_far_;
38 examples_retrieved_so_far_ += batch_size;
44 "examples_retrieved_so_far",
46 static_cast<int64_t>(examples_retrieved_so_far_), torch::kInt64),
51 auto tensor = torch::empty(1, torch::kInt64);
53 "examples_retrieved_so_far",
56 examples_retrieved_so_far_ = tensor.item<int64_t>();
TORCH_API optional< BatchSize > next(size_t batch_size) override
Returns a BatchSize object with the number of elements to fetch in the next batch.
TORCH_API StreamSampler(size_t epoch_size)
Constructs the StreamSampler with the number of individual examples that should be fetched until the sampler is exhausted.
TORCH_API void load(serialize::InputArchive &archive) override
Deserializes the StreamSampler from the archive.
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the internal state of the sampler.
size_t size() const noexcept override
The number of elements accessed by this index.
A wrapper around a batch size value, which implements the CustomBatchRequest interface.
void write(const std::string &key, const Tensor &tensor, bool is_buffer=false)
Writes a (key, tensor) pair to the OutputArchive, and marks it as being or not being a buffer (non-differentiable tensor).
TORCH_API void save(serialize::OutputArchive &archive) const override
Serializes the StreamSampler to the archive.