1 #include "caffe2/operators/cast_op.h" 5 template <
typename DstType,
typename SrcType>
7 static DstType call(SrcType data) {
8 return static_cast<DstType
>(data);
12 template <
typename SrcType>
14 static std::string call(SrcType data) {
15 return caffe2::to_string(data);
20 template <
typename DstType,
typename SrcType>
22 auto& input = Input(0);
24 auto* output = Output(0, input.sizes(), at::dtype<DstType>());
25 const auto* data = input.template data<SrcType>();
26 auto* out = output->template mutable_data<DstType>();
27 auto N = input.numel();
28 for (int64_t i = 0; i < N; ++i) {
37 case TensorProto_DataType_FLOAT:
41 case TensorProto_DataType_INT32:
44 case TensorProto_DataType_BYTE:
45 LOG(FATAL) <<
"BYTE is deprecated";
47 case TensorProto_DataType_STRING:
50 case TensorProto_DataType_BOOL:
53 case TensorProto_DataType_UINT8:
56 case TensorProto_DataType_INT8:
59 case TensorProto_DataType_UINT16:
62 case TensorProto_DataType_INT16:
65 case TensorProto_DataType_INT64:
68 case TensorProto_DataType_FLOAT16:
69 CAFFE_THROW(
"Casting to and from at::Half on CPU is not supported yet");
71 case TensorProto_DataType_DOUBLE:
75 case TensorProto_DataType_UNDEFINED:
76 CAFFE_THROW(
"Cast op must have 'to' argument of type DataType");
79 CAFFE_THROW(
"Unexpected 'to' argument value: ", to);
84 template <
typename DstType>
97 DstType>::call(
this, Input(0));
102 OPERATOR_SCHEMA(Cast)
105 .TensorInferenceFunction([](
const OperatorDef& def,
106 const vector<TensorShape>& in) {
108 vector<TensorShape> out;
109 out.push_back(in[0]);
110 out[0].set_data_type(cast::GetCastDataType(helper,
"to"));
114 Casts the elements of a given input tensor to a data type specified by the `to` 115 argument and returns an output tensor of the same size in the converted type. 116 The `to` argument must be one of the data types specified in the *DataType* 117 enum field in the TensorProto message (see below). If the `to` argument is not 118 provided or is not one of the enumerated types in *DataType*, Caffe2 throws an 121 NOTE: Casting from strings is not supported, and casting to strings is only 124 TensorProto *DataType* field: 126 message TensorProto { 132 BYTE = 3; // BYTE, when deserialized, is going to be restored as uint8. 133 STRING = 4; // string 135 UINT8 = 6; // uint8_t 137 UINT16 = 8; // uint16_t 138 INT16 = 9; // int16_t 139 INT64 = 10; // int64_t 140 FLOAT16 = 12; // at::Half 141 DOUBLE = 13; // double 147 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/cast_op.cc 151 <summary> <b>Example</b> </summary> 156 workspace.ResetWorkspace() 158 op = core.CreateOperator( 165 workspace.FeedBlob("X", (np.random.rand(3,3)).astype(np.float32)*10) 166 print("X:", workspace.FetchBlob("X")) 167 workspace.RunOperatorOnce(op) 168 print("Y:", workspace.FetchBlob("Y")) 174 X: [[9.436466 5.8529844 0.54932857] 175 [1.1583444 2.9936118 0.22950427] 176 [3.9143739 3.4040766 8.905341 ]] 187 "*(type: int)* Data type to which the elements of the input tensor are " 188 "cast. Strictly must be one of the types from *DataType* enum in " 190 .Input(0,
"X",
"*(type: Tensor)* Input tensor to be cast.")
194 "*(type: Tensor`<'to' type>`)* Output tensor with the same shape as " 195 "input with type specified by the `to` argument.")
196 .InheritOnnxSchema();
202 using GradientMakerBase::GradientMakerBase;
203 vector<OperatorDef> GetGradientDefs()
override {
205 vector<OperatorDef> defs = SingleGradientDef(
"Cast",
"", vector<string>{GO(0)}, vector<string>{GI(0)});
210 auto to_name = cast::GetCastDataType(argsHelper,
"to");
213 argsHelper.HasSingleArgumentOfType<
string>(
"from_type") ||
214 argsHelper.HasSingleArgumentOfType<
int>(
"from_type"),
215 "Argument 'from_type' of type int or string" 216 " is required to get the gradient of CastOp");
218 auto from_name = cast::GetCastDataType(argsHelper,
"from_type");
221 to->set_i(from_name);
224 from->set_name(
"from_type");
225 from->set_i(to_name);
230 bool CopyArguments()
const override {
A helper class to index into arguments.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...