1 #ifndef CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_ 2 #define CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 9 template <
class Context>
12 USE_OPERATOR_CONTEXT_FUNCTIONS;
15 bool RunOnDevice()
override;
18 template <
class Context>
21 USE_OPERATOR_CONTEXT_FUNCTIONS;
24 bool RunOnDevice()
override;
29 template <
class... Args>
32 shape_(this->
template GetRepeatedArgument<int64_t>(
"shape")) {}
35 virtual ~Float16ConstantFillOp() {}
37 bool RunOnDevice()
override;
40 vector<int64_t> shape_;
45 template <
class... Args>
48 shape_(this->
template GetRepeatedArgument<int64_t>(
"shape")),
49 min_(this->
template GetSingleArgument<float>(
"min", 0)),
50 max_(this->
template GetSingleArgument<float>(
"max", 1)) {
51 if (InputSize() == 3) {
53 !this->
template HasSingleArgumentOfType<float>(
"min"),
54 "Cannot set both min arg and min input blob");
56 !this->
template HasSingleArgumentOfType<float>(
"max"),
57 "Cannot set both max arg and max input blob");
60 min_, max_,
"Max value should be bigger than min value.");
65 virtual ~Float16UniformFillOp() {}
67 bool RunOnDevice()
override;
70 vector<int64_t> shape_;
75 inline std::vector<TensorShape> Float16FillerTensorInference(
76 const OperatorDef& def,
77 const vector<TensorShape>& in) {
78 vector<TensorShape> out(1);
80 out[0].set_data_type(static_cast<TensorProto_DataType>(
81 helper.GetSingleArgument<
int>(
"dtype", TensorProto_DataType_FLOAT16)));
82 auto shape = helper.GetRepeatedArgument<
int>(
"shape");
91 #endif // CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
A helper class to index into arguments.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...