1 #include <torch/data/samplers/random.h> 2 #include <torch/serialize/archive.h> 3 #include <torch/types.h> 13 : indices_(
torch::randperm(size, index_dtype)) {}
18 const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
19 indices_ = torch::randperm(size, indices_.
options());
24 AT_ASSERT(index_ <= indices_.numel());
25 const size_t remaining_indices = indices_.numel() - index_;
26 if (remaining_indices == 0) {
29 std::vector<size_t> index_batch(std::min(batch_size, remaining_indices));
30 auto slice = indices_.slice(0, index_, index_ + index_batch.size());
36 slice = slice.to(torch::kInt64);
37 const auto* data = slice.data<int64_t>();
38 std::copy(data, data + index_batch.size(), index_batch.begin());
39 index_ += index_batch.size();
46 torch::tensor(static_cast<int64_t>(index_), torch::kInt64),
55 auto tensor = torch::empty(1, torch::kInt64);
60 index_ = tensor.item<int64_t>();
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the RandomSampler to a new set of indices.
TensorOptions options() const
Returns the TensorOptions corresponding to this Tensor.
TORCH_API RandomSampler(int64_t size, Dtype index_dtype=torch::kInt64)
Constructs a RandomSampler with a size and dtype for the stored indices.
TORCH_API size_t index() const noexcept
Returns the current index of the RandomSampler.
TORCH_API void load(serialize::InputArchive &archive) override
Deserializes the RandomSampler from the archive.
void write(const std::string &key, const Tensor &tensor, bool is_buffer=false)
Writes a (key, tensor) pair to the OutputArchive, and marks it as being or not being a buffer (non-differentiable tensor).
TORCH_API void save(serialize::OutputArchive &archive) const override
Serializes the RandomSampler to the archive.
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices.