Caffe2 - C++ API
A deep learning, cross-platform ML framework
tensor.h
1 #pragma once
2 
3 #include <torch/data/datasets/base.h>
4 #include <torch/data/example.h>
5 #include <torch/types.h>
6 
7 #include <cstddef>
8 #include <vector>
9 
10 namespace torch {
11 namespace data {
12 namespace datasets {
13 
16 struct TensorDataset : public Dataset<TensorDataset, TensorExample> {
18  explicit TensorDataset(const std::vector<Tensor>& tensors)
19  : TensorDataset(torch::stack(tensors)) {}
20 
21  explicit TensorDataset(torch::Tensor tensor) : tensor(std::move(tensor)) {}
22 
24  TensorExample get(size_t index) override {
25  return tensor[index];
26  }
27 
29  optional<size_t> size() const override {
30  return tensor.size(0);
31  }
32 
33  Tensor tensor;
34 };
35 
36 } // namespace datasets
37 } // namespace data
38 } // namespace torch
optional< size_t > size() const override
Returns the number of tensors in the dataset.
Definition: tensor.h:29
An Example from a dataset.
Definition: example.h:12
A dataset of tensors.
Definition: tensor.h:16
A dataset that can yield data in batches, or as individual examples.
Definition: base.h:76
Definition: jit_type.h:17
TensorDataset(const std::vector< Tensor > &tensors)
Creates a TensorDataset from a vector of tensors.
Definition: tensor.h:18