Caffe2 - C++ API
A deep learning, cross-platform ML framework
tt_contraction_op.h
1 
17 #ifndef CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_
18 #define CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context, class Engine = DefaultEngine>
27 class TTContractionOp final : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  TTContractionOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  K_(OperatorBase::GetSingleArgument<TIndex>("K", 0)),
33  M_(OperatorBase::GetSingleArgument<TIndex>("M", 0)),
34  N_(OperatorBase::GetSingleArgument<TIndex>("N", 0)) {
35  CAFFE_ENFORCE(OperatorBase::HasArgument("K"), "Argument `K` is missing.");
36  CAFFE_ENFORCE(OperatorBase::HasArgument("M"), "Argument `M` is missing.");
37  CAFFE_ENFORCE(OperatorBase::HasArgument("N"), "Argument `N` is missing.");
38  }
39 
40  bool RunOnDevice() override {
41  const auto& A = Input(0);
42  const auto& B = Input(1);
43  auto* C = Output(0);
44 
45  CAFFE_ENFORCE(A.ndim() == 2, A.ndim());
46 
47  TIndex A_size = A.size_from_dim(0);
48  TIndex B_size = B.size_from_dim(0);
49 
50  CAFFE_ENFORCE(
51  K_ * M_ == A_size,
52  "Argument `K` and `M` do not agree with the size of A.");
53 
54  CAFFE_ENFORCE(
55  B_size % (K_ * N_) == 0,
56  "Argument `K` and `N` do not agree with the size of B.");
57 
58  TIndex D_ = B_size / (K_ * N_);
59 
60  TIndex C_size = D_ * M_ * N_;
61  C->Resize(vector<TIndex>{C_size});
62 
63  TIndex B_stride = K_ * N_;
64  TIndex C_stride = M_ * N_;
65 
66  const T* A_data = A.template data<T>();
67  const T* B_data = B.template data<T>();
68  T* C_data = C->template mutable_data<T>();
69 
70  for (TIndex B_index = 0; B_index < B_size; B_index += B_stride) {
71  math::Gemm<T, Context, Engine>(
72  CblasTrans,
73  CblasNoTrans,
74  M_, N_, K_, 1,
75  A_data,
76  B_data + B_index,
77  0,
78  C_data,
79  &context_);
80  C_data += C_stride;
81  }
82 
83  return true;
84  }
85 
86  protected:
87  TIndex K_;
88  TIndex M_;
89  TIndex N_;
90 };
91 
92 template <typename T, class Context, class Engine = DefaultEngine>
93 class TTContractionGradientOp final : public Operator<Context> {
94  public:
95  USE_OPERATOR_CONTEXT_FUNCTIONS;
96  TTContractionGradientOp(const OperatorDef& operator_def, Workspace* ws)
97  : Operator<Context>(operator_def, ws),
98  K_(OperatorBase::GetSingleArgument<TIndex>("K", 0)),
99  M_(OperatorBase::GetSingleArgument<TIndex>("M", 0)),
100  N_(OperatorBase::GetSingleArgument<TIndex>("N", 0)) {}
101 
102  bool RunOnDevice() override {
103  const auto& G = Input(0);
104  const auto& A = Input(1);
105  const auto& B = Input(2);
106  auto* dA = Output(0);
107  auto* dB = Output(1);
108 
109  TIndex G_size = G.size_from_dim(0);
110  TIndex D_ = G_size / (M_ * N_);
111 
112  TIndex dB_size = D_ * K_ * N_;
113 
114  dA->Resize(A.dims());
115  dB->Resize(B.dims());
116 
117  TIndex B_stride = K_ * N_;
118  TIndex G_stride = M_ * N_;
119 
120  const T* G_data = G.template data<T>();
121  const T* A_data = A.template data<T>();
122  const T* B_data = B.template data<T>();
123 
124  T* dA_data = dA->template mutable_data<T>();
125  T* dB_data = dB->template mutable_data<T>();
126 
127  const T* G_ptr = G_data;
128  for (TIndex B_index = 0; B_index < dB_size; B_index += B_stride) {
129  math::Gemm<T, Context, Engine>(
130  CblasNoTrans,
131  CblasTrans,
132  K_, M_, N_, 1,
133  B_data + B_index,
134  G_ptr,
135  B_index == 0 ? 0 : 1,
136  dA_data,
137  &context_);
138  G_ptr += G_stride;
139  }
140 
141  G_ptr = G_data;
142  for (TIndex B_index = 0; B_index < dB_size; B_index += B_stride) {
143  math::Gemm<T, Context, Engine>(
144  CblasNoTrans,
145  CblasNoTrans,
146  K_, N_, M_, 1,
147  A_data,
148  G_ptr,
149  0,
150  dB_data + B_index,
151  &context_);
152  G_ptr += G_stride;
153  }
154 
155  return true;
156  }
157 
158  protected:
159  TIndex K_;
160  TIndex M_;
161  TIndex N_;
162 };
163 
164 } // namespace caffe2
165 
166 #endif // CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:52