1 #ifndef CAFFE2_OPERATORS_CHANNEL_SHUFFLE_OP_H_ 2 #define CAFFE2_OPERATORS_CHANNEL_SHUFFLE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 10 template <
typename T,
class Context>
11 class ChannelShuffleOp final :
public Operator<Context> {
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
15 template <
class... Args>
16 explicit ChannelShuffleOp(Args&&... args)
17 : Operator<Context>(
std::forward<Args>(args)...),
18 order_(StringToStorageOrder(
19 this->template GetSingleArgument<
std::string>(
"order",
"NCHW"))),
20 OP_SINGLE_ARG(int,
"group", group_, 1) {
21 CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
24 bool RunOnDevice()
override {
25 return order_ == StorageOrder::NCHW ? RunOnDeviceWithOrderNCHW()
26 : RunOnDeviceWithOrderNHWC();
29 bool RunOnDeviceWithOrderNCHW();
31 bool RunOnDeviceWithOrderNHWC();
34 const StorageOrder order_;
38 template <
typename T,
class Context>
39 class ChannelShuffleGradientOp final :
public Operator<Context> {
41 USE_OPERATOR_CONTEXT_FUNCTIONS;
43 template <
class... Args>
44 explicit ChannelShuffleGradientOp(Args&&... args)
45 : Operator<Context>(
std::forward<Args>(args)...),
46 order_(StringToStorageOrder(
47 this->template GetSingleArgument<
std::string>(
"order",
"NCHW"))),
48 OP_SINGLE_ARG(int,
"group", group_, 1) {
49 CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
52 bool RunOnDevice()
override {
53 return order_ == StorageOrder::NCHW ? RunOnDeviceWithOrderNCHW()
54 : RunOnDeviceWithOrderNHWC();
57 bool RunOnDeviceWithOrderNCHW();
59 bool RunOnDeviceWithOrderNHWC();
62 const StorageOrder order_;
68 #endif // CAFFE2_OPERATORS_CHANNEL_SHUFFLE_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...