Caffe2 - C++ API
A deep learning, cross-platform ML framework
tt_contraction_op.h
1 
17 #ifndef CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_
18 #define CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context, class Engine = DefaultEngine>
27 class TTContractionOp final : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  TTContractionOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  K_(OperatorBase::GetSingleArgument<int64_t>("K", 0)),
33  M_(OperatorBase::GetSingleArgument<int64_t>("M", 0)),
34  N_(OperatorBase::GetSingleArgument<int64_t>("N", 0)) {
35  CAFFE_ENFORCE(OperatorBase::HasArgument("K"), "Argument `K` is missing.");
36  CAFFE_ENFORCE(OperatorBase::HasArgument("M"), "Argument `M` is missing.");
37  CAFFE_ENFORCE(OperatorBase::HasArgument("N"), "Argument `N` is missing.");
38  }
39 
40  bool RunOnDevice() override {
41  const auto& A = Input(0);
42  const auto& B = Input(1);
43 
44  CAFFE_ENFORCE(A.dim() == 2, A.dim());
45 
46  int64_t A_size = A.numel();
47  int64_t B_size = B.numel();
48 
49  CAFFE_ENFORCE(
50  K_ * M_ == A_size,
51  "Argument `K` and `M` do not agree with the size of A.");
52 
53  CAFFE_ENFORCE(
54  B_size % (K_ * N_) == 0,
55  "Argument `K` and `N` do not agree with the size of B.");
56 
57  int64_t D_ = B_size / (K_ * N_);
58 
59  int64_t C_size = D_ * M_ * N_;
60  auto* C = Output(0, vector<int64_t>{C_size}, at::dtype<T>());
61 
62  int64_t B_stride = K_ * N_;
63  int64_t C_stride = M_ * N_;
64 
65  const T* A_data = A.template data<T>();
66  const T* B_data = B.template data<T>();
67  T* C_data = C->template mutable_data<T>();
68 
69  for (int64_t B_index = 0; B_index < B_size; B_index += B_stride) {
70  math::Gemm<T, Context, Engine>(
71  CblasTrans,
72  CblasNoTrans,
73  M_, N_, K_, 1,
74  A_data,
75  B_data + B_index,
76  0,
77  C_data,
78  &context_);
79  C_data += C_stride;
80  }
81 
82  return true;
83  }
84 
85  protected:
86  int64_t K_;
87  int64_t M_;
88  int64_t N_;
89 };
90 
91 template <typename T, class Context, class Engine = DefaultEngine>
92 class TTContractionGradientOp final : public Operator<Context> {
93  public:
94  USE_OPERATOR_CONTEXT_FUNCTIONS;
95  TTContractionGradientOp(const OperatorDef& operator_def, Workspace* ws)
96  : Operator<Context>(operator_def, ws),
97  K_(OperatorBase::GetSingleArgument<int64_t>("K", 0)),
98  M_(OperatorBase::GetSingleArgument<int64_t>("M", 0)),
99  N_(OperatorBase::GetSingleArgument<int64_t>("N", 0)) {}
100 
101  bool RunOnDevice() override {
102  const auto& G = Input(0);
103  const auto& A = Input(1);
104  const auto& B = Input(2);
105 
106  int64_t G_size = G.numel();
107  int64_t D_ = G_size / (M_ * N_);
108 
109  int64_t dB_size = D_ * K_ * N_;
110 
111  auto* dA = Output(0, A.sizes(), at::dtype<T>());
112  auto* dB = Output(1, B.sizes(), at::dtype<T>());
113 
114  int64_t B_stride = K_ * N_;
115  int64_t G_stride = M_ * N_;
116 
117  const T* G_data = G.template data<T>();
118  const T* A_data = A.template data<T>();
119  const T* B_data = B.template data<T>();
120 
121  T* dA_data = dA->template mutable_data<T>();
122  T* dB_data = dB->template mutable_data<T>();
123 
124  const T* G_ptr = G_data;
125  for (int64_t B_index = 0; B_index < dB_size; B_index += B_stride) {
126  math::Gemm<T, Context, Engine>(
127  CblasNoTrans,
128  CblasTrans,
129  K_, M_, N_, 1,
130  B_data + B_index,
131  G_ptr,
132  B_index == 0 ? 0 : 1,
133  dA_data,
134  &context_);
135  G_ptr += G_stride;
136  }
137 
138  G_ptr = G_data;
139  for (int64_t B_index = 0; B_index < dB_size; B_index += B_stride) {
140  math::Gemm<T, Context, Engine>(
141  CblasNoTrans,
142  CblasNoTrans,
143  K_, N_, M_, 1,
144  A_data,
145  G_ptr,
146  0,
147  dB_data + B_index,
148  &context_);
149  G_ptr += G_stride;
150  }
151 
152  return true;
153  }
154 
155  protected:
156  int64_t K_;
157  int64_t M_;
158  int64_t N_;
159 };
160 
161 } // namespace caffe2
162 
163 #endif // CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
Definition: static.cpp:52
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64
Definition: static.cpp:58
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70