Caffe2 - C++ API
A deep learning, cross-platform ML framework
rowmul_op.h
1 
17 #ifndef CAFFE2_OPERATORS_ROW_MUL_H_
18 #define CAFFE2_OPERATORS_ROW_MUL_H_
19 
#include <cstdint>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 // A hacky version of Mul with broadcast
28 // RowMul([mat, w], [output])
29 template <typename T, class Context>
30 class RowMulOp : public Operator<Context> {
31  public:
32  USE_OPERATOR_CONTEXT_FUNCTIONS;
33  USE_SIMPLE_CTOR_DTOR(RowMulOp);
34 
35  bool RunOnDevice() override {
36  auto& mat = Input(0);
37  auto& w = Input(1);
38  auto* output = Output(0);
39 
40  output->ResizeLike(mat);
41  T* output_data = output->template mutable_data<T>();
42  const T* mat_data = mat.template data<T>();
43  const T* w_data = w.template data<T>();
44 
45  // Dimension checking
46  CAFFE_ENFORCE_EQ(
47  w.size(),
48  mat.dim32(0),
49  "Length of w should be equal to the first dim of mat");
50 
51  auto block_size = mat.size_from_dim(1);
52  for (int i = 0; i < w.size(); i++) {
53  size_t offset = i * block_size;
54  for (int j = 0; j < block_size; j++) {
55  output_data[offset + j] = mat_data[offset + j] * w_data[i];
56  }
57  }
58 
59  return true;
60  }
61 };
62 
63 // A hacky version
64 template <typename T, class Context>
65 class ReduceTailSumOp : public Operator<Context> {
66  public:
67  USE_OPERATOR_CONTEXT_FUNCTIONS;
68  USE_SIMPLE_CTOR_DTOR(ReduceTailSumOp);
69 
70  bool RunOnDevice() override {
71  auto& mat = Input(0);
72  auto* output = Output(0);
73 
74  int N = mat.dim32(0);
75  int block_size = mat.size_from_dim(1);
76 
77  output->Resize(N);
78  T* output_data = output->template mutable_data<T>();
79  const T* mat_data = mat.template data<T>();
80 
81  for (int i = 0; i < N; i++) {
82  output_data[i] = 0;
83  size_t offset = i * block_size;
84  for (int j = 0; j < block_size; j++) {
85  output_data[i] += mat_data[offset + j];
86  }
87  }
88  return true;
89  }
90 };
91 
92 } // namespace caffe2
93 
94 #endif // CAFFE2_OPERATORS_ROW_MUL_H_
Copyright (c) 2016-present, Facebook, Inc.