Caffe2 - C++ API
A deep learning, cross platform ML framework
resize_op.h
1 
18 #pragma once
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 
23 namespace caffe2 {
24 
25 template <typename T, class Context>
26 class ResizeNearestOp final : public Operator<Context> {
27  public:
28  ResizeNearestOp(const OperatorDef& operator_def, Workspace* ws)
29  : Operator<Context>(operator_def, ws), width_scale_(1), height_scale_(1) {
30  if (HasArgument("width_scale")) {
31  width_scale_ = static_cast<T>(
32  OperatorBase::GetSingleArgument<float>("width_scale", 1));
33  }
34  if (HasArgument("height_scale")) {
35  height_scale_ = static_cast<T>(
36  OperatorBase::GetSingleArgument<float>("height_scale", 1));
37  }
38  CAFFE_ENFORCE_GT(width_scale_, 0);
39  CAFFE_ENFORCE_GT(height_scale_, 0);
40  }
41  USE_OPERATOR_CONTEXT_FUNCTIONS;
42 
43  bool RunOnDevice() override;
44 
45  protected:
46  T width_scale_;
47  T height_scale_;
48 };
49 
50 template <typename T, class Context>
51 class ResizeNearestGradientOp final : public Operator<Context> {
52  public:
53  ResizeNearestGradientOp(const OperatorDef& operator_def, Workspace* ws)
54  : Operator<Context>(operator_def, ws), width_scale_(1), height_scale_(1) {
55  width_scale_ = static_cast<T>(
56  OperatorBase::GetSingleArgument<float>("width_scale", 1));
57  height_scale_ = static_cast<T>(
58  OperatorBase::GetSingleArgument<float>("height_scale", 1));
59  CAFFE_ENFORCE_GT(width_scale_, 0);
60  CAFFE_ENFORCE_GT(height_scale_, 0);
61  }
62  USE_OPERATOR_CONTEXT_FUNCTIONS;
63 
64  bool RunOnDevice() override;
65 
66  protected:
67  T width_scale_;
68  T height_scale_;
69 };
70 
71 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:52