Caffe2 - C++ API
A deep learning, cross platform ML framework
half_float_ops.h
1 #ifndef CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_
2 #define CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <class Context>
10 class FloatToHalfOp : public Operator<Context> {
11  public:
12  USE_OPERATOR_CONTEXT_FUNCTIONS;
13  USE_SIMPLE_CTOR_DTOR(FloatToHalfOp);
14 
15  bool RunOnDevice() override;
16 };
17 
18 template <class Context>
19 class HalfToFloatOp : public Operator<Context> {
20  public:
21  USE_OPERATOR_CONTEXT_FUNCTIONS;
22  USE_SIMPLE_CTOR_DTOR(HalfToFloatOp);
23 
24  bool RunOnDevice() override;
25 };
26 
27 class Float16ConstantFillOp : public Operator<CPUContext> {
28  public:
29  template <class... Args>
30  explicit Float16ConstantFillOp(Args&&... args)
31  : Operator<CPUContext>(std::forward<Args>(args)...),
32  shape_(this->template GetRepeatedArgument<int64_t>("shape")) {}
33 
34  USE_OPERATOR_FUNCTIONS(CPUContext);
35  virtual ~Float16ConstantFillOp() {}
36 
37  bool RunOnDevice() override;
38 
39  private:
40  vector<int64_t> shape_;
41 };
42 
43 class Float16UniformFillOp : public Operator<CPUContext> {
44  public:
45  template <class... Args>
46  explicit Float16UniformFillOp(Args&&... args)
47  : Operator<CPUContext>(std::forward<Args>(args)...),
48  shape_(this->template GetRepeatedArgument<int64_t>("shape")),
49  min_(this->template GetSingleArgument<float>("min", 0)),
50  max_(this->template GetSingleArgument<float>("max", 1)) {
51  if (InputSize() == 3) {
52  CAFFE_ENFORCE(
53  !this->template HasSingleArgumentOfType<float>("min"),
54  "Cannot set both min arg and min input blob");
55  CAFFE_ENFORCE(
56  !this->template HasSingleArgumentOfType<float>("max"),
57  "Cannot set both max arg and max input blob");
58  } else {
59  CAFFE_ENFORCE_LT(
60  min_, max_, "Max value should be bigger than min value.");
61  }
62  }
63 
64  USE_OPERATOR_FUNCTIONS(CPUContext);
65  virtual ~Float16UniformFillOp() {}
66 
67  bool RunOnDevice() override;
68 
69  private:
70  vector<int64_t> shape_;
71  float min_;
72  float max_;
73 };
74 
75 inline std::vector<TensorShape> Float16FillerTensorInference(
76  const OperatorDef& def,
77  const vector<TensorShape>& in) {
78  vector<TensorShape> out(1);
79  ArgumentHelper helper(def);
80  out[0].set_data_type(static_cast<TensorProto_DataType>(
81  helper.GetSingleArgument<int>("dtype", TensorProto_DataType_FLOAT16)));
82  auto shape = helper.GetRepeatedArgument<int>("shape");
83  for (int d : shape) {
84  out[0].add_dims(d);
85  }
86  return out;
87 }
88 
89 } // namespace caffe2
90 
91 #endif // CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
A helper class to index into arguments.
Definition: proto_utils.h:200
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13