Caffe2 - C++ API
A deep learning, cross-platform ML framework
channel_shuffle_op.h
1 
17 #pragma once
#include <cstdint>

#include "caffe2/operators/conv_pool_op_base.h"
19 
20 namespace caffe2 {
21 
22 template <typename Context>
23 class ChannelShuffleOp final : public ConvPoolOpBase<Context> {
24  public:
25  USE_OPERATOR_FUNCTIONS(Context);
26  ChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws)
27  : ConvPoolOpBase<Context>(operator_def, ws) {
28  OPERATOR_NEEDS_FEATURE(
29  this->order_ == StorageOrder::NCHW,
30  "ChannelShuffleOp only supports NCHW order");
31  }
32 
33  bool RunOnDeviceWithOrderNCHW() override {
34  const auto& X = Input(0);
35  auto* Y = Output(0);
36  Y->ResizeLike(X);
37  const auto C = X.dim32(1);
38  CAFFE_ENFORCE(C % this->group_ == 0, "");
39  const auto K = C / this->group_;
40  const auto S = X.dim32(2) * X.dim32(3);
41  const auto G = this->group_;
42  for (auto n = 0; n < X.dim32(0); ++n) {
43  for (auto g = 0; g < G; ++g) {
44  // Scatter the group g block (of size KxS) to output channels
45  // g + 0 * G, g + 1 * G, g + 2 * G, g + G * (K - 1) etc.
46  math::CopyMatrix<Context>(
47  X.itemsize(),
48  K,
49  S,
50  X.template data<float>() + g * K * S + n * C * S,
51  S,
52  Y->template mutable_data<float>() + g * S + n * C * S,
53  G * S,
54  &context_,
55  X.meta().copy());
56  }
57  }
58  return true;
59  }
60 };
61 
62 template <typename Context>
63 class ChannelShuffleGradientOp final : public ConvPoolOpBase<Context> {
64  public:
65  USE_OPERATOR_FUNCTIONS(Context);
66  ChannelShuffleGradientOp(const OperatorDef& operator_def, Workspace* ws)
67  : ConvPoolOpBase<Context>(operator_def, ws) {
68  OPERATOR_NEEDS_FEATURE(
69  this->order_ == StorageOrder::NCHW,
70  "ChannelShuffleOp only supports NCHW order");
71  }
72 
73  bool RunOnDeviceWithOrderNCHW() override {
74  const auto& dY = Input(0);
75  auto* dX = Output(0);
76  dX->ResizeLike(dY);
77  const auto C = dY.dim32(1);
78  CAFFE_ENFORCE(C % this->group_ == 0, "");
79  const auto K = C / this->group_;
80  const auto S = dY.dim32(2) * dY.dim32(3);
81  const auto G = this->group_;
82  for (auto n = 0; n < dY.dim32(0); ++n) {
83  for (auto g = 0; g < G; ++g) {
84  // Gather the group g block (of size KxS) from output channels
85  // g + 0 * G, g + 1 * G, g + 2 * G, g + G * (K - 1) etc.
86  math::CopyMatrix<Context>(
87  dY.itemsize(),
88  K,
89  S,
90  dY.template data<float>() + g * S + n * C * S,
91  G * S,
92  dX->template mutable_data<float>() + g * K * S + n * C * S,
93  S,
94  &context_,
95  dY.meta().copy());
96  }
97  }
98  return true;
99  }
100 };
101 }
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.