// ===========================================================================
// caffe2 TT (Tensor-Train) linear operator header.
// NOTE(review): this chunk is a garbled extraction -- the integers embedded
// at the start of each fragment are the ORIGINAL file's line numbers, and
// several original lines are missing here (e.g. the `#ifdef CAFFE2_USE_MKL`
// / `#include <mkl.h>` pair that must precede the stray `#endif` below, and
// the `namespace caffe2 {` opener). Reconcile against upstream
// caffe2/operators/tt_linear_op.h before compiling.
// ===========================================================================
1 #ifndef CAFFE2_OPERATORS_TT_LINEAR_OP_H_ 2 #define CAFFE2_OPERATORS_TT_LINEAR_OP_H_ 6 #endif // CAFFE2_USE_MKL 10 #include "caffe2/core/context.h" 11 #include "caffe2/core/operator.h" 12 #include "caffe2/utils/eigen_utils.h" 13 #include "caffe2/utils/math.h" 17 template <
// Template parameters for the (not fully visible) TTLinearOp class:
// T = element type, Context = device context, Engine = GEMM engine choice.
typename T,
class Context,
class Engine = DefaultEngine>
// TTLinearOp: fully-connected layer whose weight matrix is stored in
// Tensor-Train (TT) factorized form; the factor cores arrive as a flat
// input tensor at run time (see RunOnDevice below).
// NOTE(review): garbled extraction -- the class declaration (presumably
// `class TTLinearOp final : public Operator<Context>`) and the variadic
// constructor signature with its base-class initializer are missing from
// view; only the member-initializer list survives. TODO: confirm against
// upstream caffe2/operators/tt_linear_op.h.
20 USE_OPERATOR_CONTEXT_FUNCTIONS;
21 template <
class... Args>
// Member initializers: read the TT decomposition hyperparameters from the
// operator's repeated int arguments. inp_sizes_/out_sizes_ give the
// per-mode input/output dimensions; tt_ranks_ the TT ranks (indexed up to
// [d] in RunOnDevice, so presumably d + 1 entries -- verify at call sites).
24 inp_sizes_(this->
template GetRepeatedArgument<int>(
"inp_sizes")),
25 out_sizes_(this->
template GetRepeatedArgument<int>(
"out_sizes")),
26 tt_ranks_(this->
template GetRepeatedArgument<int>(
"tt_ranks")),
// Owned scratch blob, reused between the per-core GEMMs in RunOnDevice().
27 Y_temp_(unique_ptr<Blob>(
new Blob())) {}
// Forward pass: treats the flat `cores` input as the TT factorization of a
// weight matrix and computes Y = X * W_TT + b by contracting X with one TT
// core at a time (i = d-1 down to 0), then broadcasting the bias.
// NOTE(review): garbled extraction -- several original lines are missing
// (most GEMM arguments, the heads of some CAFFE_ENFORCE_* macros, closing
// braces and the trailing `return true;`). Comments below document only the
// visible structure; do not treat this fragment as compilable.
30 bool RunOnDevice()
override {
// X: input activations; must be at least 2-D, leading dim is the batch.
31 const auto& X =
Input(0);
// b: 1-D bias vector.
32 const auto& b =
Input(1);
// cores: all TT cores concatenated into one flat 1-D tensor.
33 const auto& cores =
Input(2);
35 CAFFE_ENFORCE(X.dim() > 1,
"Number of dimensions in X: ", X.dim());
36 CAFFE_ENFORCE(b.dim() == 1,
"Number of dimensions in b: ", b.dim());
// inp_sizes and out_sizes must pair up: one (inp, out) mode per TT core.
38 inp_sizes_.size() == out_sizes_.size(),
39 "inp_sizes has size: ",
41 ", out_sizes has size: ",
44 cores.dim() == 1,
"Number of dimensions in cores: ", cores.dim());
46 const int batch_size = X.dim() > 1 ? X.dim32(0) : 1;
// d = number of TT cores / factorized modes.
49 const int d = inp_sizes_.size();
// Scratch buffer holding the running contraction result between cores.
55 auto Y_buf = BlobGetMutableTensor(Y_temp_.get(), Context::GetDeviceType());
// Contract with each core, last mode first. `cores_idx` walks the flat
// cores buffer; its declaration/initialization is among the missing lines.
63 for (
int i = (d - 1); i >= 0; --i) {
// Core i is consumed as a (curr_rows x curr_cols) matrix slice.
64 int curr_rows = inp_sizes_[i] * tt_ranks_[i + 1];
65 int curr_cols = tt_ranks_[i] * out_sizes_[i];
// Fold curr_rows into the trailing dimension of the running buffer ...
68 Y_buf->Resize(Y_buf->numel() / curr_rows, curr_rows);
// ... and size this step's output (same row count, curr_cols wide).
// NOTE(review): the `Output(` head of this call is among the missing lines.
70 0, {Y_buf->numel() / curr_rows, curr_cols}, at::dtype<float>());
73 CAFFE_ENFORCE(Y_buf->numel() % curr_rows == 0, Y_buf->numel(), curr_rows);
// Bounds check: slice [cores_idx, cores_idx + rows*cols) must fit in cores.
75 cores_idx + curr_rows * curr_cols <= cores.numel(),
76 cores_idx + curr_rows * curr_cols,
// GEMM: (numel/curr_rows x curr_rows) buffer times the core slice.
// NOTE(review): transpose flags, M/N/K, lda/ldb and alpha/beta arguments
// are missing from this extraction.
80 math::Gemm<float, Context, Engine>(
83 Y_buf->numel() / curr_rows,
87 Y_buf->template data<float>(),
88 cores.template data<float>() + cores_idx,
90 Y->template mutable_data<float>(),
93 CAFFE_ENFORCE(Y->numel() % out_sizes_[i] == 0, Y->numel(), out_sizes_[i]);
// In-place transpose via Eigen maps so the out_sizes_[i] mode moves to the
// position the next contraction step expects.
96 auto Y_mat = EigenMatrixMap<float>(
97 Y->template mutable_data<float>(),
98 Y->numel() / out_sizes_[i],
100 Y_mat = ConstEigenMatrixMap<float>(
101 Y->template data<float>(),
103 Y->numel() / out_sizes_[i])
// Copy this step's result back into the scratch buffer for the next core.
108 Y_buf->Resize(Y->dim32(0), Y->dim32(1));
109 context_.template CopyFromCPU<float>(
111 Y->template data<float>(),
112 Y_buf->template mutable_data<float>());
// Advance past the core slice just consumed.
114 cores_idx += curr_rows * curr_cols;
// After the loop: transpose again so the batch dimension leads.
118 auto Y_mat = EigenMatrixMap<float>(
119 Y->template mutable_data<float>(), batch_size, Y->numel() / batch_size);
120 Y_mat = ConstEigenMatrixMap<float>(
121 Y->template data<float>(), Y->numel() / batch_size, batch_size)
125 Y = Output(0, {batch_size, Y->numel() / batch_size}, at::dtype<float>());
// Sanity check: output width must equal prod(out_sizes).
128 int prod_out_sizes = 1;
129 for (
int i = 0; i < out_sizes_.size(); i++) {
130 prod_out_sizes *= out_sizes_[i];
133 Y->dim32(1) == prod_out_sizes,
134 "Output dimension of Y: ",
136 ", product of out_sizes: ",
// Bias: (re)build a length-batch_size multiplier vector (presumably all
// ones -- the math::Set value argument is among the missing lines), then
// add b to every row of Y via a rank-1 GEMM update.
140 if (bias_multiplier_.numel() != batch_size) {
145 at::dtype<T>().device(Context::GetDeviceType()));
146 math::Set<T, Context>(
149 bias_multiplier_.template mutable_data<T>(),
152 math::Gemm<T, Context, Engine>(
159 bias_multiplier_.template data<T>(),
160 b.template data<T>(),
162 Y->template mutable_data<T>(),
// Per-mode input sizes of the TT factorization (length d).
169 std::vector<int> inp_sizes_;
// Per-mode output sizes, paired 1:1 with inp_sizes_.
170 std::vector<int> out_sizes_;
// TT ranks; RunOnDevice indexes up to [d], so presumably d + 1 entries.
171 std::vector<int> tt_ranks_;
// Owned scratch blob reused across RunOnDevice calls.
172 std::unique_ptr<Blob> Y_temp_;
// Gradient operator for TTLinearOp.
// NOTE(review): garbled extraction -- the class declaration, the variadic
// constructor and the body of RunOnDevice() are missing from view; only the
// template head, destructor and a member survive. Confirm against upstream
// caffe2/operators/tt_linear_op.h before relying on this operator.
176 template <
typename T,
class Context,
class Engine = DefaultEngine>
179 USE_OPERATOR_CONTEXT_FUNCTIONS;
180 template <
class... Args>
183 ~TTLinearGradientOp() {}
// NOTE(review): no implementation is visible between the opening brace and
// the member below -- presumably unimplemented upstream; verify.
185 bool RunOnDevice()
override {
// Scratch tensor on the operator's device; presumably a bias-multiplier
// vector mirroring TTLinearOp's -- confirm once the body is visible.
190 Tensor bias_multiplier_{Context::GetDeviceType()};
195 #endif // CAFFE2_OPERATORS_TT_LINEAR_OP_H_ Blob is a general container that hosts a typed pointer.
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to the given dims and options if necessary; note that this will not do anything if the tensor already matches the requested dims and options.
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.