17 #ifndef CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_ 18 #define CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_ 20 #include "caffe2/core/context.h" 21 #include "caffe2/core/operator.h" 22 #include "caffe2/utils/math.h" 26 template <
typename T,
class Context,
class Engine = DefaultEngine>
29 USE_OPERATOR_CONTEXT_FUNCTIONS;
32 K_(OperatorBase::GetSingleArgument<int64_t>(
"K", 0)),
33 M_(OperatorBase::GetSingleArgument<int64_t>(
"M", 0)),
34 N_(OperatorBase::GetSingleArgument<int64_t>(
"N", 0)) {
40 bool RunOnDevice()
override {
44 CAFFE_ENFORCE(
A.dim() == 2,
A.dim());
46 int64_t A_size =
A.numel();
47 int64_t B_size =
B.numel();
51 "Argument `K` and `M` do not agree with the size of A.");
54 B_size % (K_ * N_) == 0,
55 "Argument `K` and `N` do not agree with the size of B.");
57 int64_t D_ = B_size / (K_ * N_);
59 int64_t C_size = D_ * M_ * N_;
60 auto*
C = Output(0, vector<int64_t>{C_size}, at::dtype<T>());
62 int64_t B_stride = K_ * N_;
63 int64_t C_stride = M_ * N_;
65 const T* A_data =
A.template data<T>();
66 const T* B_data =
B.template data<T>();
67 T* C_data =
C->template mutable_data<T>();
69 for (int64_t B_index = 0; B_index < B_size; B_index += B_stride) {
70 math::Gemm<T, Context, Engine>(
91 template <
typename T,
class Context,
class Engine = DefaultEngine>
94 USE_OPERATOR_CONTEXT_FUNCTIONS;
97 K_(OperatorBase::GetSingleArgument<int64_t>(
"K", 0)),
98 M_(OperatorBase::GetSingleArgument<int64_t>(
"M", 0)),
99 N_(OperatorBase::GetSingleArgument<int64_t>(
"N", 0)) {}
101 bool RunOnDevice()
override {
102 const auto& G =
Input(0);
106 int64_t G_size = G.numel();
107 int64_t D_ = G_size / (M_ * N_);
109 int64_t dB_size = D_ * K_ * N_;
111 auto* dA = Output(0,
A.sizes(), at::dtype<T>());
112 auto* dB = Output(1,
B.sizes(), at::dtype<T>());
114 int64_t B_stride = K_ * N_;
115 int64_t G_stride = M_ * N_;
117 const T* G_data = G.template data<T>();
118 const T* A_data =
A.template data<T>();
119 const T* B_data =
B.template data<T>();
121 T* dA_data = dA->template mutable_data<T>();
122 T* dB_data = dB->template mutable_data<T>();
124 const T* G_ptr = G_data;
125 for (int64_t B_index = 0; B_index < dB_size; B_index += B_stride) {
126 math::Gemm<T, Context, Engine>(
132 B_index == 0 ? 0 : 1,
139 for (int64_t B_index = 0; B_index < dB_size; B_index += B_stride) {
140 math::Gemm<T, Context, Engine>(
163 #endif // CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.