Caffe2 - C++ API
A deep learning, cross platform ML framework
cast_op.cc
1 #include "caffe2/operators/cast_op.h"
2 
3 namespace caffe2 {
4 
// Generic per-element converter used by CastOp<CPUContext>::DoRunWithType:
// the value is converted with a plain static_cast from SrcType to DstType.
template <typename DstType, typename SrcType>
struct CastHelper {
  static auto call(SrcType value) -> DstType {
    return static_cast<DstType>(value);
  }
};
11 
12 template <typename SrcType>
13 struct CastHelper<std::string, SrcType> {
14  static std::string call(SrcType data) {
15  return caffe2::to_string(data);
16  }
17 };
18 
19 template <>
20 template <typename DstType, typename SrcType>
22  auto& input = Input(0);
23 
24  auto* output = Output(0, input.sizes(), at::dtype<DstType>());
25  const auto* data = input.template data<SrcType>();
26  auto* out = output->template mutable_data<DstType>();
27  auto N = input.numel();
28  for (int64_t i = 0; i < N; ++i) {
29  out[i] = CastHelper<DstType, SrcType>::call(data[i]);
30  }
31  return true;
32 }
33 
34 template <>
35 void CastOp<CPUContext>::SetBody(TensorProto_DataType to) {
36  switch (to) {
37  case TensorProto_DataType_FLOAT:
38  // body_ = &CastOp::DoRunIncFp16WithDstType<float>;
40  break;
41  case TensorProto_DataType_INT32:
43  break;
44  case TensorProto_DataType_BYTE:
45  LOG(FATAL) << "BYTE is deprecated";
46  break;
47  case TensorProto_DataType_STRING:
48  body_ = &CastOp<CPUContext>::DoRunWithDstType<std::string>;
49  break;
50  case TensorProto_DataType_BOOL:
52  break;
53  case TensorProto_DataType_UINT8:
54  body_ = &CastOp<CPUContext>::DoRunWithDstType<uint8_t>;
55  break;
56  case TensorProto_DataType_INT8:
57  body_ = &CastOp<CPUContext>::DoRunWithDstType<int8_t>;
58  break;
59  case TensorProto_DataType_UINT16:
60  body_ = &CastOp<CPUContext>::DoRunWithDstType<uint16_t>;
61  break;
62  case TensorProto_DataType_INT16:
63  body_ = &CastOp<CPUContext>::DoRunWithDstType<int16_t>;
64  break;
65  case TensorProto_DataType_INT64:
66  body_ = &CastOp<CPUContext>::DoRunWithDstType<int64_t>;
67  break;
68  case TensorProto_DataType_FLOAT16:
69  CAFFE_THROW("Casting to and from at::Half on CPU is not supported yet");
70  // break;
71  case TensorProto_DataType_DOUBLE:
72  // body_ = &CastOp::DoRunIncFp16WithDstType<double>;
73  body_ = &CastOp<CPUContext>::DoRunWithDstType<double>;
74  break;
75  case TensorProto_DataType_UNDEFINED:
76  CAFFE_THROW("Cast op must have 'to' argument of type DataType");
77  // break;
78  default:
79  CAFFE_THROW("Unexpected 'to' argument value: ", to);
80  }
81 }
82 
83 template <>
84 template <typename DstType>
86  return DispatchHelper<
88  float,
89  int32_t,
90  bool,
91  uint8_t,
92  int8_t,
93  uint16_t,
94  int16_t,
95  int64_t,
96  double>,
97  DstType>::call(this, Input(0));
98 }
99 
100 REGISTER_CPU_OPERATOR(Cast, CastOp<CPUContext>);
101 
102 OPERATOR_SCHEMA(Cast)
103  .NumInputs(1)
104  .NumOutputs(1)
105  .TensorInferenceFunction([](const OperatorDef& def,
106  const vector<TensorShape>& in) {
107  ArgumentHelper helper(def);
108  vector<TensorShape> out;
109  out.push_back(in[0]);
110  out[0].set_data_type(cast::GetCastDataType(helper, "to"));
111  return out;
112  })
113  .SetDoc(R"DOC(
114 Casts the elements of a given input tensor to a data type specified by the `to`
115 argument and returns an output tensor of the same size in the converted type.
116 The `to` argument must be one of the data types specified in the *DataType*
117 enum field in the TensorProto message (see below). If the `to` argument is not
118 provided or is not one of the enumerated types in *DataType*, Caffe2 throws an
119 Enforce error.
120 
121 NOTE: Casting from strings is not supported, and casting to strings is only
122 supported on CPU.
123 
124 TensorProto *DataType* field:
125 ```
126 message TensorProto {
127  ...
128  enum DataType {
129  UNDEFINED = 0;
130  FLOAT = 1; // float
131  INT32 = 2; // int
132  BYTE = 3; // BYTE, when deserialized, is going to be restored as uint8.
133  STRING = 4; // string
134  BOOL = 5; // bool
135  UINT8 = 6; // uint8_t
136  INT8 = 7; // int8_t
137  UINT16 = 8; // uint16_t
138  INT16 = 9; // int16_t
139  INT64 = 10; // int64_t
140  FLOAT16 = 12; // at::Half
141  DOUBLE = 13; // double
142  }
143 ```
144 
145 Github Links:
146 
147 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/cast_op.cc
148 
149 <details>
150 
151 <summary> <b>Example</b> </summary>
152 
153 **Code**
154 
155 ```
156 workspace.ResetWorkspace()
157 
158 op = core.CreateOperator(
159  "Cast",
160  ["X"],
161  ["Y"],
162  to=2
163 )
164 
165 workspace.FeedBlob("X", (np.random.rand(3,3)).astype(np.float32)*10)
166 print("X:", workspace.FetchBlob("X"))
167 workspace.RunOperatorOnce(op)
168 print("Y:", workspace.FetchBlob("Y"))
169 ```
170 
171 **Result**
172 
173 ```
174 X: [[9.436466 5.8529844 0.54932857]
175  [1.1583444 2.9936118 0.22950427]
176  [3.9143739 3.4040766 8.905341 ]]
177 Y: [[9 5 0]
178  [1 2 0]
179  [3 3 8]]
180 ```
181 
182 </details>
183 
184 )DOC")
185  .Arg(
186  "to",
187  "*(type: int)* Data type to which the elements of the input tensor are "
188  "cast. Strictly must be one of the types from *DataType* enum in "
189  "TensorProto.")
190  .Input(0, "X", "*(type: Tensor)* Input tensor to be cast.")
191  .Output(
192  0,
193  "Y",
194  "*(type: Tensor`<'to' type>`)* Output tensor with the same shape as "
195  "input with type specified by the `to` argument.")
196  .InheritOnnxSchema();
197 
198 // Some Casts are compatible with gradients, but for now we don't support it
199 // GRADIENT_NOT_IMPLEMENTED_YET(Cast);
200 
201 class GetCastGradient : public GradientMakerBase {
202  using GradientMakerBase::GradientMakerBase;
203  vector<OperatorDef> GetGradientDefs() override {
204 
205  vector<OperatorDef> defs = SingleGradientDef("Cast", "", vector<string>{GO(0)}, vector<string>{GI(0)});
206 
207  // now modify the arguments in defs[0]
208  ArgumentHelper argsHelper(def_);
209 
210  auto to_name = cast::GetCastDataType(argsHelper, "to");
211 
212  CAFFE_ENFORCE(
213  argsHelper.HasSingleArgumentOfType<string>("from_type") ||
214  argsHelper.HasSingleArgumentOfType<int>("from_type"),
215  "Argument 'from_type' of type int or string"
216  " is required to get the gradient of CastOp");
217 
218  auto from_name = cast::GetCastDataType(argsHelper, "from_type");
219  Argument *to = defs[0].add_arg();
220  to->set_name("to");
221  to->set_i(from_name);
222 
223  Argument *from = defs[0].add_arg();
224  from->set_name("from_type");
225  from->set_i(to_name);
226 
227  return defs;
228  }
229 
230  bool CopyArguments() const override {
231  return false;
232  }
233 };
234 
235 REGISTER_GRADIENT(Cast, GetCastGradient);
236 
237 
238 
239 
240 } // namespace caffe2
A helper class to index into arguments.
Definition: proto_utils.h:200
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13