Caffe2 - C++ API
A deep learning, cross platform ML framework
ensure_clipped_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 #include "caffe2/utils/eigen_utils.h"
5 #include "caffe2/utils/math.h"
6 
7 namespace caffe2 {
8 
9 template <typename T, class Context>
10 class EnsureClippedOp final : public Operator<Context> {
11  public:
12  USE_OPERATOR_CONTEXT_FUNCTIONS;
13 
14  template <class... Args>
15  explicit EnsureClippedOp(Args&&... args)
16  : Operator<Context>(std::forward<Args>(args)...),
17  min_(std::numeric_limits<T>::lowest()),
18  max_(std::numeric_limits<T>::max()) {
19  if (HasArgument("min")) {
20  min_ = static_cast<T>(this->template GetSingleArgument<float>("min", 0));
21  }
22  if (HasArgument("max")) {
23  max_ = static_cast<T>(this->template GetSingleArgument<float>("max", 0));
24  }
25  }
26 
27  bool RunOnDevice() override {
28  if (InputSize() > INDICES) {
29  // spares gradient, selective checking clipping
30  CAFFE_ENFORCE_EQ(
31  Input(PARAM).size_from_dim(1),
32  Input(GRAD).size_from_dim(Input(INDICES).dim()));
34  this, Input(INDICES));
35  } else {
36  auto& X = Input(PARAM);
37 
38  auto* Y = Output(OUTPUT_PARAM, X.sizes(), at::dtype<float>());
39  EigenVectorMap<float>(Y->template mutable_data<float>(), Y->numel()) =
40  ConstEigenVectorMap<float>(X.template data<float>(), X.numel())
41  .cwiseMax(min_)
42  .cwiseMin(max_);
43  return true;
44  }
45  }
46 
47  template <typename SIndex>
48  bool DoRunWithType();
49 
50  protected:
51  T min_;
52  T max_;
53  INPUT_TAGS(PARAM, INDICES, GRAD);
54  OUTPUT_TAGS(OUTPUT_PARAM);
55 };
56 
57 } // namespace caffe2
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70