Caffe2 - C++ API
A deep learning, cross-platform ML framework
half_float_ops.h
1 
17 #ifndef CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_
18 #define CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 
23 namespace caffe2 {
24 
25 template <class Context>
26 class FloatToHalfOp : public Operator<Context> {
27  public:
28  USE_OPERATOR_CONTEXT_FUNCTIONS;
29  USE_SIMPLE_CTOR_DTOR(FloatToHalfOp);
30 
31  bool RunOnDevice() override;
32 };
33 
34 template <class Context>
35 class HalfToFloatOp : public Operator<Context> {
36  public:
37  USE_OPERATOR_CONTEXT_FUNCTIONS;
38  USE_SIMPLE_CTOR_DTOR(HalfToFloatOp);
39 
40  bool RunOnDevice() override;
41 };
42 
43 class Float16ConstantFillOp : public Operator<CPUContext> {
44  public:
45  Float16ConstantFillOp(const OperatorDef& operator_def, Workspace* ws)
46  : Operator<CPUContext>(operator_def, ws),
47  shape_(
48  ToVectorTIndex(OperatorBase::GetRepeatedArgument<int>("shape"))) {}
49 
50  USE_OPERATOR_FUNCTIONS(CPUContext);
51  virtual ~Float16ConstantFillOp() {}
52 
53  bool RunOnDevice() override;
54 
55  private:
56  vector<TIndex> shape_;
57 };
58 
59 inline std::vector<TensorShape> Float16FillerTensorInference(
60  const OperatorDef& def,
61  const vector<TensorShape>& in) {
62  vector<TensorShape> out(1);
63  ArgumentHelper helper(def);
64  out[0].set_data_type(static_cast<TensorProto_DataType>(
65  helper.GetSingleArgument<int>("dtype", TensorProto_DataType_FLOAT)));
66  auto shape = helper.GetRepeatedArgument<int>("shape");
67  for (int d : shape) {
68  out[0].add_dims(d);
69  }
70  return out;
71 }
72 
73 } // namespace caffe2
74 
75 #endif // CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:82
A helper class to index into arguments.
Definition: proto_utils.h:198
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
vector<TIndex> ToVectorTIndex(const std::vector<int>& src)
A utility function to convert vector<int> to vector<TIndex>.
Definition: tensor.h:49