Caffe2 - C++ API
A deep learning, cross platform ML framework
accumulate_op.h
1 #ifndef CAFFE2_OPERATORS_ACCUMULATE_OP_H_
2 #define CAFFE2_OPERATORS_ACCUMULATE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 class AccumulateOp final : public Operator<Context> {
12  public:
13  template <class... Args>
14  explicit AccumulateOp(Args&&... args)
15  : Operator<Context>(std::forward<Args>(args)...),
16  gamma_(static_cast<T>(
17  this->template GetSingleArgument<float>("gamma", 1.0))) {}
18  USE_OPERATOR_CONTEXT_FUNCTIONS;
19 
20  bool RunOnDevice() override {
21  auto& input = Input(0);
22 
23  // TODO: the operator depends on output being set to 0 before the run
24  auto* output = Output(0, input.sizes(), at::dtype<T>());
25  math::Axpby<T, T, Context>(
26  input.numel(),
27  static_cast<T>(1),
28  input.template data<T>(),
29  gamma_,
30  output->template mutable_data<T>(),
31  &context_);
32  return true;
33  }
34 
35  protected:
36  T gamma_;
37 };
38 
39 } // namespace caffe2
40 
41 #endif // CAFFE2_OPERATORS_ACCUMULATE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13