1 #include <caffe2/ideep/ideep_utils.h> 5 USE_IDEEP_DEF_ALIASES();
// Builds a two-element shape {dim0, dim1} from `adims`, where dim1 is the
// product of the extents from position `axis` to the end.
//
// `adims` is taken by value; `axis` is validated with two CAFFE_ENFORCE checks
// that both report "Invalid axis!".
//
// NOTE(review): several lines of this function are not present in this listing
// (the body of the first `if`, whatever guards the `axis +=` adjustment, the
// range arguments of the dim0 accumulate, and the closing braces). Comments
// below describe only what is visible.
7 static inline itensor::dims CanonicalDims(itensor::dims adims, int32_t axis) {
// Upper bound: axis must be strictly less than the number of dimensions.
8 CAFFE_ENFORCE(axis < (int32_t)adims.size(),
"Invalid axis!");
// Lower bound: axis must be strictly greater than -size, so axis == -size is
// rejected. NOTE(review): confirm that excluding -size is intended.
// The unary minus is applied to adims.size() before the cast; if size() returns
// an unsigned type, the negative value only appears after conversion to int32_t.
9 CAFFE_ENFORCE(axis > (int32_t)-adims.size(),
"Invalid axis!");
// Special case for shapes that already have two dimensions, or axis == 1.
// The branch body is not visible here; presumably it returns adims unchanged
// -- TODO confirm.
10 if (adims.size() == 2 || axis == 1) {
// Shifts axis by the number of dimensions. Presumably executed only for a
// negative axis (the guarding condition is not visible) -- TODO confirm.
14 axis += (int32_t)adims.size();
// dim0: a product (std::multiplies) over a range whose bounds are not visible
// here; presumably the extents before `axis` -- TODO confirm.
17 auto dim0 = std::accumulate(
21 std::multiplies<itensor::dim_t>());
// dim1: product of the extents in [begin + axis, end).
// NOTE(review): the initial value is the int literal 1, and std::accumulate
// carries its running value in the type of that initial value; confirm this
// cannot truncate if itensor::dim_t is wider than int.
22 auto dim1 = std::accumulate(
23 adims.begin() + axis, adims.end(), 1, std::multiplies<itensor::dim_t>());
24 return itensor::dims({dim0, dim1});
// Interior of the forward operator class. The class header and the constructor
// signature are not present in this listing.
29 USE_IDEEP_DEF_ALIASES();
30 USE_IDEEP_OPERATOR_FUNCTIONS();
// Constructor initializers: axis_ and axis_w_ are read from the operator
// arguments "axis" and "axis_w"; 1 is passed as the second argument to
// GetSingleArgument (presumably the value used when the argument is absent).
34 axis_(OperatorBase::GetSingleArgument<int32_t>(
"axis", 1)),
35 axis_w_(OperatorBase::GetSingleArgument<int32_t>(
"axis_w", 1)) {}
// Forward pass: computes OUTPUT from INPUT and FILTER, plus BIAS when that
// input is supplied.
38 bool RunOnDevice()
override {
39 const auto& X = Input(INPUT);
40 const auto& filter = Input(FILTER);
41 auto* Y = Output(OUTPUT);
// X_in is declared on a line not visible here; presumably a local itensor
// initialised from X, mirroring filter_in below -- TODO confirm.
// Its dims are compared with the two-element shape from CanonicalDims(axis_).
44 auto X_dims = CanonicalDims(X_in.get_dims(), axis_);
// Body of this branch is not visible; presumably X_in.reshape(X_dims), as done
// for the filter below -- TODO confirm.
45 if (X_in.get_dims() != X_dims) {
// Work on a local itensor so that the reshape is applied to filter_in rather
// than through the const reference `filter`.
49 itensor filter_in = filter;
50 auto filter_dims = CanonicalDims(filter_in.get_dims(), axis_w_);
// Reshape only when the canonical shape differs from the current one.
51 if (filter_in.get_dims() != filter_dims) {
52 filter_in.reshape(filter_dims);
// A bias is present when the number of inputs exceeds the BIAS index.
55 if (InputSize() > BIAS) {
56 ideep::inner_product_forward::compute(X_in, filter_in, Input(BIAS), *Y);
// Variant without bias; presumably the else-branch of the check above (the
// `else` line is not visible) -- TODO confirm.
58 ideep::inner_product_forward::compute(X_in, filter_in, *Y);
// Names the input indices used above: INPUT, FILTER, BIAS.
68 INPUT_TAGS(INPUT, FILTER, BIAS);
// Interior of the gradient operator class. The class header, the constructor
// signature and the header of the method containing the statements below are
// not present in this listing.
74 USE_IDEEP_DEF_ALIASES();
75 USE_IDEEP_OPERATOR_FUNCTIONS();
// Constructor initializers: axis_ and axis_w_ are read from the operator
// arguments "axis" and "axis_w"; 1 is passed as the second argument to
// GetSingleArgument (presumably the value used when the argument is absent).
79 axis_(OperatorBase::GetSingleArgument<int32_t>(
"axis", 1)),
80 axis_w_(OperatorBase::GetSingleArgument<int32_t>(
"axis_w", 1)) {}
// Backward pass inputs: INPUT, FILTER and OUTPUT_GRAD (bound to dY).
84 const auto& X = Input(INPUT);
85 const auto& filter = Input(FILTER);
86 const auto& dY = Input(OUTPUT_GRAD);
// Outputs written by the weights-gradient computation below.
87 auto* dfilter = Output(FILTER_GRAD);
88 auto* dbias = Output(BIAS_GRAD);
// X_in is declared on a line not visible here; presumably a local itensor
// initialised from X, mirroring filter_in below -- TODO confirm.
91 auto X_dims = CanonicalDims(X_in.get_dims(), axis_);
// Body of this branch is not visible; presumably X_in.reshape(X_dims), as done
// for the filter below -- TODO confirm.
92 if (X_in.get_dims() != X_dims) {
// Work on a local itensor so that the reshape is applied to filter_in rather
// than through the const reference `filter`.
96 itensor filter_in = filter;
97 auto filter_dims = CanonicalDims(filter_in.get_dims(), axis_w_);
98 if (filter_in.get_dims() != filter_dims) {
99 filter_in.reshape(filter_dims);
// Computes the filter and bias gradients from X_in and dY.
102 ideep::inner_product_backward_weights::compute(X_in, dY, *dfilter, *dbias);
// Give the filter gradient the dims of the original `filter` input (not the
// reshaped filter_in) when they differ.
108 if (dfilter->get_dims() != filter.get_dims()) {
109 dfilter->reshape(filter.get_dims());
// The input gradient is computed only when the operator has an INPUT_GRAD
// output. Note that the dims passed are X.get_dims(), i.e. those of the
// original input rather than X_in.
112 if (OutputSize() > INPUT_GRAD) {
113 ideep::inner_product_backward_data::compute(
114 dY, filter_in, X.get_dims(), *Output(INPUT_GRAD));
// Names the input and output indices used above.
124 INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
125 OUTPUT_TAGS(FILTER_GRAD, BIAS_GRAD, INPUT_GRAD);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool RunOnDevice() override