Caffe2 - C++ API
A deep learning, cross platform ML framework
scale_op.h
1 #ifndef CAFFE2_OPERATORS_SCALE_OP_H_
2 #define CAFFE2_OPERATORS_SCALE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <class Context>
11 class ScaleOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  template <class... Args>
15  explicit ScaleOp(Args&&... args)
16  : Operator<Context>(std::forward<Args>(args)...),
17  scale_(this->template GetSingleArgument<float>("scale", 1.0)) {}
18 
19  template <typename T>
20  bool DoRunWithType() {
21  auto& X = Input(0);
22 
23  auto* Y = Output(0, X.sizes(), at::dtype<T>());
24  math::Scale<float, T, Context>(
25  X.numel(),
26  scale_,
27  X.template data<T>(),
28  Y->template mutable_data<T>(),
29  &context_);
30  return true;
31  }
32 
33  bool RunOnDevice() override {
34  return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
35  }
36 
37  protected:
38  float scale_;
39 };
40 
41 } // namespace caffe2
42 
43 #endif // CAFFE2_OPERATORS_SCALE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13