// NOTE(review): this text appears to be scraped from a Doxygen source view.
// The leading integers on each line are the original file's line numbers, and
// several original lines (e.g. 3-4, 9-11, 13-14, 16-18) are missing, including
// what is presumably the class declaration and constructor signature. It does
// not compile as-is; restore from the upstream header before editing code.
1 #ifndef CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_ 2 #define CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_ 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/core/tensor.h" 8 #include "caffe2/utils/math.h" 12 template <
typename Context>
// Presumably brings the operator helpers used below (Input, Output,
// InputSize, context_) into scope; the macro body is not shown here.
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor body: read the float "threshold" argument (default 0.0) and
// require it to be strictly greater than 0. Because the default is 0.0, an
// operator created without an explicit "threshold" fails this enforce.
19 threshold_ = this->
template GetSingleArgument<float>(
"threshold", 0.0);
20 CAFFE_ENFORCE_GT(threshold_, 0,
"Threshold must be greater than 0");
// RunOnDevice (fragment: original lines 28, 31, 34, 38, 40-41, 44, 46-51 and
// everything after 53 are missing from this listing).
// Inputs, as checked below:
//   Input(0) - tensor to clip; must be non-empty; read as float.
//   Input(1) - exactly one float value compared against the threshold.
//   Input(2) - optional, exactly one float that multiplies the threshold.
// Output(0) gets Input(0)'s sizes and float dtype.
23 bool RunOnDevice()
override {
24 const auto& input_tensor =
Input(0);
25 CAFFE_ENFORCE_GT(input_tensor.numel(), 0);
26 const auto& val =
Input(1);
27 CAFFE_ENFORCE_EQ(val.numel(), 1);
// Both inputs are accessed as float data.
29 const auto* input_tensor_data = input_tensor.template data<float>();
30 const auto* val_data = val.template data<float>();
// Output(0) is requested with Input(0)'s sizes and float dtype.
32 auto* clipped = Output(0, input_tensor.sizes(), at::dtype<float>());
33 float* clipped_tensor_data = clipped->template mutable_data<float>();
// Optional third input: a one-element tensor that scales the threshold.
// NOTE(review): this multiplies the *member* threshold_ in place. Every run
// that supplies Input(2) therefore compounds onto whatever value earlier runs
// left behind, instead of scaling the constructor's "threshold" argument.
// Using a per-call local (e.g. `float threshold = threshold_;`) would avoid
// the drift. TODO confirm and fix once the full header is available.
35 if (InputSize() > 2) {
36 const auto& additional_threshold =
Input(2);
37 CAFFE_ENFORCE_EQ(additional_threshold.numel(), 1);
39 threshold_ *= *(additional_threshold.template data<float>());
// Clip only when the value exceeds the (possibly scaled) threshold. In that
// case ratio = threshold_ / *val_data, which lies in (0, 1) whenever
// threshold_ is positive.
42 if (*val_data > threshold_) {
43 float ratio = threshold_ / *val_data;
// Presumably writes input * ratio into the output buffer. The call's
// arguments (original lines 46-51) are missing from this listing.
45 math::Scale<float, float, Context>(
// Presumably the non-clipping path (its enclosing else is not visible).
// The data is copied through only when the output buffer is not the input
// buffer itself (i.e. not running in place).
// NOTE(review): the meaning of CopyFrom's second argument `true` (an async
// flag?) is not shown here; confirm.
52 if (input_tensor_data != clipped_tensor_data) {
53 clipped->CopyFrom(input_tensor,
true);
66 #endif // CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...