Caffe2 - C++ API
A deep learning, cross-platform ML framework
channel_shuffle_op.cc
1 #include <caffe2/ideep/operators/conv_pool_base_op.h>
2 
3 namespace caffe2 {
4 
5 class ChannelShuffleOp final : public IDEEPConvPoolOpBase {
6  public:
7  USE_IDEEP_DEF_ALIASES();
8  USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
9 
10  ChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws)
11  : IDEEPConvPoolOpBase(operator_def, ws) {}
12 
13  bool RunOnDeviceWithOrderNCHW() override {
14  const auto& X = Input(INPUT);
15  auto* Y = Output(OUTPUT);
16 
17  ideep::channel_shuffle_forward::compute(X, *Y, group_);
18 
19  return true;
20  }
21 
22  private:
23  INPUT_TAGS(INPUT);
24  OUTPUT_TAGS(OUTPUT);
25 };
26 
28  public:
29  USE_IDEEP_DEF_ALIASES();
30  USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
31 
32  ChannelShuffleGradientOp(const OperatorDef& operator_def, Workspace* ws)
33  : IDEEPConvPoolOpBase(operator_def, ws) {}
34 
35  bool RunOnDeviceWithOrderNCHW() override {
36  const auto& dY = Input(OUTPUT_GRAD);
37  auto* dX = Output(INPUT_GRAD);
38 
39  ideep::channel_shuffle_backward::compute(dY, *dX, group_);
40 
41  return true;
42  }
43 
44  private:
45  INPUT_TAGS(OUTPUT_GRAD);
46  OUTPUT_TAGS(INPUT_GRAD);
47 };
48 
49 
50 REGISTER_IDEEP_OPERATOR(ChannelShuffle, ChannelShuffleOp);
51 REGISTER_IDEEP_OPERATOR(ChannelShuffleGradient, ChannelShuffleGradientOp);
52 
53 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13