// NOTE(review): this text looks like a lossy extraction of a C++ header:
// original source line numbers are fused into each line, and several lines
// (the class declaration, the reads that define `mat` and `w`, the head of the
// dimension-check call, and the closing braces) are not present. The comments
// below describe only what the visible lines show.
1 #ifndef CAFFE2_OPERATORS_ROW_MUL_H_ 2 #define CAFFE2_OPERATORS_ROW_MUL_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 13 template <
typename T,
class Context>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Row-wise scaling: allocates an output with the same sizes as `mat`, then
// multiplies every element of row i of `mat` by the scalar w_data[i].
19 bool RunOnDevice()
override {
// Output takes mat's sizes, so it is filled element-for-element below.
23 auto* output = Output(0, mat.sizes(), at::dtype<T>());
24 T* output_data = output->template mutable_data<T>();
25 const T* mat_data = mat.template data<T>();
26 const T* w_data = w.template data<T>();
// The message below implies a check that w has one entry per row of mat
// (mat's first dimension); the check itself is not visible here -- TODO
// confirm it is enforced before the loop runs.
32 "Length of w should be equal to the first dim of mat");
// block_size = number of elements per row (presumably the product of mat's
// dims from 1 onward); assumes mat is contiguous so row i starts at
// i * block_size -- TODO confirm.
34 auto block_size = mat.size_from_dim(1);
// NOTE(review): `i` is a plain int compared against w.numel(), which is
// presumably a 64-bit count; a 64-bit index would avoid truncation on very
// large inputs.
35 for (
int i = 0; i < w.numel(); i++) {
36 size_t offset = i * block_size;
37 for (
int j = 0; j < block_size; j++) {
38 output_data[offset + j] = mat_data[offset + j] * w_data[i];
// NOTE(review): lossy extraction; the class declaration, the read that
// defines `mat`, the definition of `N`, and the closing braces are not
// visible. Comments describe only the lines shown.
47 template <
typename T,
class Context>
50 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Row-wise sum: writes an output of shape {N} where output_data[i] is the sum
// of the block_size consecutive elements of `mat` that start at
// i * block_size.
53 bool RunOnDevice()
override {
// NOTE(review): block_size is narrowed into a plain int, and the offset below
// is computed as `int * int` (i * block_size) before widening to size_t, so a
// tensor with more than INT_MAX elements overflows (signed overflow is UB).
// Prefer int64_t for block_size, i and the offset arithmetic.
57 int block_size = mat.size_from_dim(1);
// N is defined on a line not visible here (presumably mat's first dimension).
59 auto* output = Output(0, {N}, at::dtype<T>());
60 T* output_data = output->template mutable_data<T>();
61 const T* mat_data = mat.template data<T>();
// NOTE(review): the visible loop only accumulates into output_data[i] with
// `+=`; no visible line zeroes it first. Confirm the full source initializes
// each output_data[i] (or accumulates into a local) -- otherwise the sum
// starts from whatever the freshly allocated output contains.
63 for (
int i = 0; i < N; i++) {
65 size_t offset = i * block_size;
66 for (
int j = 0; j < block_size; j++) {
67 output_data[i] += mat_data[offset + j];
76 #endif // CAFFE2_OPERATORS_ROW_MUL_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...