Caffe2 - C++ API
A deep learning, cross platform ML framework
upsample_op.h
1 
17 #pragma once
18 
19 #include "caffe2/core/context.h"
20 #include "caffe2/core/operator.h"
21 
22 namespace caffe2 {
23 
24 template <typename T, class Context>
25 class UpsampleBilinearOp final : public Operator<Context> {
26  public:
27  template <class... Args>
28  explicit UpsampleBilinearOp(Args&&... args)
29  : Operator<Context>(std::forward<Args>(args)...),
30  width_scale_(1),
31  height_scale_(1) {
32  if (HasArgument("width_scale")) {
33  width_scale_ = static_cast<T>(
34  this->template GetSingleArgument<float>("width_scale", 1));
35  }
36  if (HasArgument("height_scale")) {
37  height_scale_ = static_cast<T>(
38  this->template GetSingleArgument<float>("height_scale", 1));
39  }
40  CAFFE_ENFORCE_GT(width_scale_, 0);
41  CAFFE_ENFORCE_GT(height_scale_, 0);
42  }
43  USE_OPERATOR_CONTEXT_FUNCTIONS;
44 
45  bool RunOnDevice() override;
46 
47  protected:
48  T width_scale_;
49  T height_scale_;
50 };
51 
52 template <typename T, class Context>
54  public:
55  template <class... Args>
56  explicit UpsampleBilinearGradientOp(Args&&... args)
57  : Operator<Context>(std::forward<Args>(args)...),
58  width_scale_(1),
59  height_scale_(1) {
60  width_scale_ = static_cast<T>(
61  this->template GetSingleArgument<float>("width_scale", 1));
62  height_scale_ = static_cast<T>(
63  this->template GetSingleArgument<float>("height_scale", 1));
64  CAFFE_ENFORCE_GT(width_scale_, 0);
65  CAFFE_ENFORCE_GT(height_scale_, 0);
66  }
67  USE_OPERATOR_CONTEXT_FUNCTIONS;
68 
69  bool RunOnDevice() override;
70 
71  protected:
72  T width_scale_;
73  T height_scale_;
74 };
75 
76 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70