Caffe2 - C++ API
A deep learning, cross platform ML framework
clip_tensor_op.h
1 #ifndef CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_
2 #define CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_
3 
4 #include <vector>
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/core/tensor.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <typename Context>
13 class ClipTensorByScalingOp final : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16 
17  ClipTensorByScalingOp(const OperatorDef& operator_def, Workspace* ws)
18  : Operator<Context>(operator_def, ws) {
19  threshold_ = this->template GetSingleArgument<float>("threshold", 0.0);
20  CAFFE_ENFORCE_GT(threshold_, 0, "Threshold must be greater than 0");
21  }
22 
23  bool RunOnDevice() override {
24  const auto& input_tensor = Input(0);
25  CAFFE_ENFORCE_GT(input_tensor.numel(), 0);
26  const auto& val = Input(1);
27  CAFFE_ENFORCE_EQ(val.numel(), 1);
28 
29  const auto* input_tensor_data = input_tensor.template data<float>();
30  const auto* val_data = val.template data<float>();
31 
32  auto* clipped = Output(0, input_tensor.sizes(), at::dtype<float>());
33  float* clipped_tensor_data = clipped->template mutable_data<float>();
34 
35  if (InputSize() > 2) {
36  const auto& additional_threshold = Input(2);
37  CAFFE_ENFORCE_EQ(additional_threshold.numel(), 1);
38 
39  threshold_ *= *(additional_threshold.template data<float>());
40  }
41 
42  if (*val_data > threshold_) {
43  float ratio = threshold_ / *val_data;
44 
45  math::Scale<float, float, Context>(
46  clipped->numel(),
47  ratio,
48  input_tensor_data,
49  clipped_tensor_data,
50  &context_);
51  } else {
52  if (input_tensor_data != clipped_tensor_data) {
53  clipped->CopyFrom(input_tensor, /*async*/ true);
54  }
55  }
56 
57  return true;
58  }
59 
60  private:
61  float threshold_;
62 };
63 
64 } // namespace caffe2
65 
66 #endif // CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13