Caffe2 - C++ API
A deep learning, cross-platform ML framework
doubler.h
1 #include <torch/extension.h>
2 
3 struct Doubler {
4  Doubler(int A, int B) {
5  tensor_ =
6  torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
7  }
8  torch::Tensor forward() {
9  return tensor_ * 2;
10  }
11  torch::Tensor get() const {
12  return tensor_;
13  }
14 
15  private:
16  torch::Tensor tensor_;
17 };
Definition: doubler.h:3
Definition: static.cpp:52
Definition: static.cpp:58