Caffe2 - C++ API
A deep learning, cross platform ML framework
tensor.h
1 #pragma once
2 
3 #include <torch/data/example.h>
4 #include <torch/data/transforms/base.h>
5 #include <torch/types.h>
6 
7 #include <functional>
8 #include <utility>
9 
10 namespace torch {
11 namespace data {
12 namespace transforms {
13 
17 template <typename Target = Tensor>
19  : public Transform<Example<Tensor, Target>, Example<Tensor, Target>> {
20  public:
21  using E = Example<Tensor, Target>;
22  using typename Transform<E, E>::InputType;
23  using typename Transform<E, E>::OutputType;
24 
26  virtual Tensor operator()(Tensor input) = 0;
27 
29  OutputType apply(InputType input) override {
30  input.data = (*this)(std::move(input.data));
31  return input;
32  }
33 };
34 
36 template <typename Target = Tensor>
37 class TensorLambda : public TensorTransform<Target> {
38  public:
39  using FunctionType = std::function<Tensor(Tensor)>;
40 
42  explicit TensorLambda(FunctionType function)
43  : function_(std::move(function)) {}
44 
46  Tensor operator()(Tensor input) override {
47  return function_(std::move(input));
48  }
49 
50  private:
51  FunctionType function_;
52 };
53 
56 template <typename Target = Tensor>
57 struct Normalize : public TensorTransform<Target> {
62  : mean(torch::tensor(mean, torch::kFloat32)
63  .unsqueeze(/*dim=*/1)
64  .unsqueeze(/*dim=*/2)),
65  stddev(torch::tensor(stddev, torch::kFloat32)
66  .unsqueeze(/*dim=*/1)
67  .unsqueeze(/*dim=*/2)) {}
68 
70  return input.sub(mean).div(stddev);
71  }
72 
73  torch::Tensor mean, stddev;
74 };
75 } // namespace transforms
76 } // namespace data
77 } // namespace torch
virtual Tensor operator()(Tensor input)=0
Transforms a single input tensor to an output tensor.
Definition: tensor.h:26
An Example from a dataset.
Definition: example.h:12
Normalize(ArrayRef< double > mean, ArrayRef< double > stddev)
Constructs a Normalize transform.
Definition: tensor.h:61
torch::Tensor operator()(Tensor input)
Transforms a single input tensor to an output tensor.
Definition: tensor.h:69
Normalizes input tensors by subtracting the supplied mean and dividing by the given standard deviation.
Definition: tensor.h:57
Definition: jit_type.h:17
ArrayRef - Represents a constant reference to an array (0 or more elements stored consecutively in memory), i.e. a start pointer and a length.
Definition: ArrayRef.h:41
Tensor operator()(Tensor input) override
Applies the user-provided functor to the input tensor.
Definition: tensor.h:46
A Lambda specialized for the typical Example<Tensor, Tensor> input type.
Definition: tensor.h:37
A transformation of individual input examples to individual output examples.
Definition: base.h:32
TensorLambda(FunctionType function)
Creates a TensorLambda from the given function.
Definition: tensor.h:42
A Transform that is specialized for the typical Example<Tensor, Tensor> combination.
Definition: tensor.h:18
OutputType apply(InputType input) override
Implementation of Transform::apply that calls operator().
Definition: tensor.h:29