Caffe2 - C++ API
A deep learning, cross platform ML framework
swish_op.h
1 
17 #pragma once
18 
19 #include "caffe2/core/operator.h"
20 #include "caffe2/utils/math.h"
21 
22 namespace caffe2 {
23 template <class Context>
24 class SwishGradientOp final : public Operator<Context> {
25  public:
26  USE_SIMPLE_CTOR_DTOR(SwishGradientOp)
27  USE_OPERATOR_CONTEXT_FUNCTIONS;
28 
29  template <typename T>
30  bool DoRunWithType();
31 
32  bool RunOnDevice() override {
33  return DispatchHelper<TensorTypes<float, double>>::call(this, Input(X));
34  }
35 
36  protected:
37  INPUT_TAGS(X, Y, DY);
38  OUTPUT_TAGS(DX);
39 };
40 
42  using GradientMakerBase::GradientMakerBase;
43  vector<OperatorDef> GetGradientDefs() override {
44  return SingleGradientDef(
45  "SwishGradient",
46  "",
47  vector<string>{I(0), O(0), GO(0)},
48  vector<string>{GI(0)});
49  }
50 };
51 
52 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.