Caffe2 - C++ API
A deep learning, cross-platform ML framework
fully_connected_op.cc
1 #include <caffe2/ideep/ideep_utils.h>
2 
3 namespace caffe2 {
4 
5 USE_IDEEP_DEF_ALIASES();
6 
7 static inline itensor::dims CanonicalDims(itensor::dims adims, int32_t axis) {
8  CAFFE_ENFORCE(axis < (int32_t)adims.size(), "Invalid axis!");
9  CAFFE_ENFORCE(axis > (int32_t)-adims.size(), "Invalid axis!");
10  if (adims.size() == 2 || axis == 1) {
11  return adims;
12  }
13  if (axis < 0) {
14  axis += (int32_t)adims.size();
15  }
16 
17  auto dim0 = std::accumulate(
18  adims.begin(),
19  adims.begin() + axis,
20  1,
21  std::multiplies<itensor::dim_t>());
22  auto dim1 = std::accumulate(
23  adims.begin() + axis, adims.end(), 1, std::multiplies<itensor::dim_t>());
24  return itensor::dims({dim0, dim1});
25 }
26 
27 class IDEEPFullyConnectedOp final : public IDEEPOperator {
28  public:
29  USE_IDEEP_DEF_ALIASES();
30  USE_IDEEP_OPERATOR_FUNCTIONS();
31 
32  IDEEPFullyConnectedOp(const OperatorDef& operator_def, Workspace* ws)
33  : IDEEPOperator(operator_def, ws),
34  axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
35  axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)) {}
36  ~IDEEPFullyConnectedOp() override {}
37 
38  bool RunOnDevice() override {
39  const auto& X = Input(INPUT);
40  const auto& filter = Input(FILTER);
41  auto* Y = Output(OUTPUT);
42 
43  itensor X_in = X;
44  auto X_dims = CanonicalDims(X_in.get_dims(), axis_);
45  if (X_in.get_dims() != X_dims) {
46  X_in.reshape(X_dims);
47  }
48 
49  itensor filter_in = filter;
50  auto filter_dims = CanonicalDims(filter_in.get_dims(), axis_w_);
51  if (filter_in.get_dims() != filter_dims) {
52  filter_in.reshape(filter_dims);
53  }
54 
55  if (InputSize() > BIAS) {
56  ideep::inner_product_forward::compute(X_in, filter_in, Input(BIAS), *Y);
57  } else {
58  ideep::inner_product_forward::compute(X_in, filter_in, *Y);
59  }
60 
61  return true;
62  }
63 
64  private:
65  size_t axis_{1};
66  size_t axis_w_{1};
67 
68  INPUT_TAGS(INPUT, FILTER, BIAS);
69  OUTPUT_TAGS(OUTPUT);
70 };
71 
73  public:
74  USE_IDEEP_DEF_ALIASES();
75  USE_IDEEP_OPERATOR_FUNCTIONS();
76 
77  IDEEPFullyConnectedGradientOp(const OperatorDef& operator_def, Workspace* ws)
78  : IDEEPOperator(operator_def, ws),
79  axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
80  axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)) {}
81  ~IDEEPFullyConnectedGradientOp() override {}
82 
83  bool RunOnDevice() override {
84  const auto& X = Input(INPUT);
85  const auto& filter = Input(FILTER);
86  const auto& dY = Input(OUTPUT_GRAD);
87  auto* dfilter = Output(FILTER_GRAD);
88  auto* dbias = Output(BIAS_GRAD);
89 
90  itensor X_in = X;
91  auto X_dims = CanonicalDims(X_in.get_dims(), axis_);
92  if (X_in.get_dims() != X_dims) {
93  X_in.reshape(X_dims);
94  }
95 
96  itensor filter_in = filter;
97  auto filter_dims = CanonicalDims(filter_in.get_dims(), axis_w_);
98  if (filter_in.get_dims() != filter_dims) {
99  filter_in.reshape(filter_dims);
100  }
101 
102  ideep::inner_product_backward_weights::compute(X_in, dY, *dfilter, *dbias);
103 
108  if (dfilter->get_dims() != filter.get_dims()) {
109  dfilter->reshape(filter.get_dims());
110  }
111 
112  if (OutputSize() > INPUT_GRAD) {
113  ideep::inner_product_backward_data::compute(
114  dY, filter_in, X.get_dims(), *Output(INPUT_GRAD));
115  }
116 
117  return true;
118  }
119 
120  private:
121  size_t axis_{1};
122  size_t axis_w_{1};
123 
124  INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
125  OUTPUT_TAGS(FILTER_GRAD, BIAS_GRAD, INPUT_GRAD);
126 };
127 
128 REGISTER_IDEEP_OPERATOR(FC, IDEEPFullyConnectedOp);
129 REGISTER_IDEEP_OPERATOR(FCGradient, IDEEPFullyConnectedGradientOp);
130 
131 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: OpClasses.h:566