1 #include <torch/data/samplers/sequential.h>     2 #include <torch/serialize/archive.h>     3 #include <torch/types.h>    15   if (new_size.has_value()) {
    22   const auto remaining_indices = size_ - index_;
    23   if (remaining_indices == 0) {
    26   std::vector<size_t> index_batch(std::min(batch_size, remaining_indices));
    27   for (
auto& i : index_batch) {
    36       torch::tensor(static_cast<int64_t>(index_), torch::kInt64),
    41   auto tensor = torch::empty(1, torch::kInt64);
    46   index_ = tensor.item<int64_t>();
 
TORCH_API SequentialSampler(size_t size)
Creates a SequentialSampler that will return indices in the range `0` to `size - 1`.
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices. 
TORCH_API void load(serialize::InputArchive &archive) override
Deserializes the SequentialSampler from the archive. 
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the SequentialSampler's current index to zero, optionally applying a new size when `new_size` is provided.
TORCH_API void save(serialize::OutputArchive &archive) const override
Serializes the SequentialSampler to the archive. 
void write(const std::string &key, const Tensor &tensor, bool is_buffer=false)
Writes a `(key, tensor)` pair to the `OutputArchive`, and marks it as being or not being a buffer (non-differentiable key).
TORCH_API size_t index() const noexcept
Returns the current index of the SequentialSampler.