Caffe2 - C++ API
A deep learning, cross platform ML framework
resize_op.h
1 
2 #pragma once
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <typename T, class Context>
10 class ResizeNearestOp final : public Operator<Context> {
11  public:
12  template <class... Args>
13  explicit ResizeNearestOp(Args&&... args)
14  : Operator<Context>(std::forward<Args>(args)...),
15  width_scale_(1),
16  height_scale_(1),
17  order_(StringToStorageOrder(
18  this->template GetSingleArgument<std::string>("order", "NCHW"))) {
19  if (HasArgument("width_scale")) {
20  width_scale_ = static_cast<T>(
21  this->template GetSingleArgument<float>("width_scale", 1));
22  }
23  if (HasArgument("height_scale")) {
24  height_scale_ = static_cast<T>(
25  this->template GetSingleArgument<float>("height_scale", 1));
26  }
27 
28  CAFFE_ENFORCE_GT(width_scale_, 0);
29  CAFFE_ENFORCE_GT(height_scale_, 0);
30 
31  CAFFE_ENFORCE(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC);
32  }
33 
34  USE_OPERATOR_CONTEXT_FUNCTIONS;
35 
36  bool RunOnDevice() override;
37  bool RunOnDeviceWithOrderNCHW();
38  bool RunOnDeviceWithOrderNHWC();
39 
40  protected:
41  T width_scale_;
42  T height_scale_;
43  StorageOrder order_;
44 };
45 
46 template <typename T, class Context>
48  public:
49  template <class... Args>
50  explicit ResizeNearestGradientOp(Args&&... args)
51  : Operator<Context>(std::forward<Args>(args)...),
52  width_scale_(1),
53  height_scale_(1),
54  order_(StringToStorageOrder(
55  this->template GetSingleArgument<std::string>("order", "NCHW"))) {
56  width_scale_ = static_cast<T>(
57  this->template GetSingleArgument<float>("width_scale", 1));
58  height_scale_ = static_cast<T>(
59  this->template GetSingleArgument<float>("height_scale", 1));
60 
61  CAFFE_ENFORCE_GT(width_scale_, 0);
62  CAFFE_ENFORCE_GT(height_scale_, 0);
63 
64  CAFFE_ENFORCE(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC);
65  }
66 
67  USE_OPERATOR_CONTEXT_FUNCTIONS;
68 
69  bool RunOnDevice() override;
70  bool RunOnDeviceWithOrderNCHW();
71  bool RunOnDeviceWithOrderNHWC();
72 
73  protected:
74  T width_scale_;
75  T height_scale_;
76  StorageOrder order_;
77 };
78 
79 } // namespace caffe2
caffe2 (namespace): A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70