Caffe2 - C++ API
A deep-learning, cross-platform ML framework
conv_transpose_op.h
1 
17 #ifndef CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_H_
18 #define CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/operators/conv_transpose_unpool_op_base.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context>
27 class ConvTransposeOp final : public ConvTransposeUnpoolBase<Context> {
28  public:
29  USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context);
30  ConvTransposeOp(const OperatorDef& operator_def, Workspace* ws)
31  : ConvTransposeUnpoolBase<Context>(operator_def, ws) {}
32 
33  bool RunOnDeviceWithOrderNCHW() override;
34  bool RunOnDeviceWithOrderNHWC() override;
35 
36  private:
37  Tensor<Context> col_buffer_;
38  Tensor<Context> bias_multiplier_;
39  // Input: X, W, b
40  // Output: Y
41  INPUT_TAGS(INPUT, FILTER, BIAS);
42 };
43 
44 template <typename T, class Context>
45 class ConvTransposeGradientOp final : public ConvTransposeUnpoolBase<Context> {
46  public:
47  USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context);
48  ConvTransposeGradientOp(const OperatorDef& operator_def, Workspace* ws)
49  : ConvTransposeUnpoolBase<Context>(operator_def, ws),
50  no_bias_(OperatorBase::GetSingleArgument<bool>("no_bias", false)) {
51  CAFFE_ENFORCE(
52  !(no_bias_ && OutputSize() == 3),
53  "If bias is not present, you should not have 3 grad output.");
54  }
55 
56  bool RunOnDeviceWithOrderNCHW() override;
57  bool RunOnDeviceWithOrderNHWC() override;
58 
59  private:
60  Tensor<Context> col_buffer_;
61  Tensor<Context> bias_multiplier_;
62  const bool no_bias_;
63  // input: X, W, dY
64  // output: dW, optionally db and dX
65  INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
66  OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
67 };
68 
69 } // namespace caffe2
70 
71 #endif // CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.