Caffe2 - C++ API
A deep learning, cross-platform ML framework
stack.h
1 #pragma once
2 
3 #include <torch/data/example.h>
4 #include <torch/data/transforms/collate.h>
5 #include <torch/types.h>
6 
7 #include <utility>
8 #include <vector>
9 
10 namespace torch {
11 namespace data {
12 namespace transforms {
13 
14 template <typename T = Example<>>
15 struct Stack;
16 
19 template <>
20 struct Stack<Example<>> : public Collation<Example<>> {
21  Example<> apply_batch(std::vector<Example<>> examples) override {
22  std::vector<torch::Tensor> data, targets;
23  data.reserve(examples.size());
24  targets.reserve(examples.size());
25  for (auto& example : examples) {
26  data.push_back(std::move(example.data));
27  targets.push_back(std::move(example.target));
28  }
29  return {torch::stack(data), torch::stack(targets)};
30  }
31 };
32 
35 template <>
37  : public Collation<Example<Tensor, example::NoTarget>> {
38  TensorExample apply_batch(std::vector<TensorExample> examples) override {
39  std::vector<torch::Tensor> data;
40  data.reserve(examples.size());
41  for (auto& example : examples) {
42  data.push_back(std::move(example.data));
43  }
44  return torch::stack(data);
45  }
46 };
47 } // namespace transforms
48 } // namespace data
49 } // namespace torch
An Example from a dataset.
Definition: example.h:12
Definition: jit_type.h:17
A transformation of a batch to a new batch.
Definition: base.h:14