Caffe2 - C++ API
A deep learning, cross platform ML framework
cast_op.cc
1 
17 #include "caffe2/operators/cast_op.h"
18 
19 namespace caffe2 {
20 
21 template <>
22 template <typename DstType, typename SrcType>
23 bool CastOp<CPUContext>::DoRunWithType() {
24  auto& input = Input(0);
25  auto* output = Output(0);
26  output->ResizeLike(input);
27  const auto* data = input.template data<SrcType>();
28  auto* out = output->template mutable_data<DstType>();
29  auto N = input.size();
30  for (TIndex i = 0; i < N; ++i) {
31  out[i] = static_cast<DstType>(data[i]);
32  }
33  return true;
34 }
35 
36 template <>
37 void CastOp<CPUContext>::SetBody(TensorProto_DataType to) {
38  switch (to) {
39  case TensorProto_DataType_FLOAT:
40  // body_ = &CastOp::DoRunIncFp16WithDstType<float>;
41  body_ = &CastOp<CPUContext>::DoRunWithDstType<float>;
42  break;
43  case TensorProto_DataType_INT32:
44  body_ = &CastOp<CPUContext>::DoRunWithDstType<int>;
45  break;
46  case TensorProto_DataType_BYTE:
47  LOG(FATAL) << "BYTE is deprecated";
48  break;
49  case TensorProto_DataType_STRING:
50  CAFFE_THROW("Casting to and from strings is not supported yet");
51  // break;
52  case TensorProto_DataType_BOOL:
53  body_ = &CastOp<CPUContext>::DoRunWithDstType<bool>;
54  break;
55  case TensorProto_DataType_UINT8:
56  body_ = &CastOp<CPUContext>::DoRunWithDstType<uint8_t>;
57  break;
58  case TensorProto_DataType_INT8:
59  body_ = &CastOp<CPUContext>::DoRunWithDstType<int8_t>;
60  break;
61  case TensorProto_DataType_UINT16:
62  body_ = &CastOp<CPUContext>::DoRunWithDstType<uint16_t>;
63  break;
64  case TensorProto_DataType_INT16:
65  body_ = &CastOp<CPUContext>::DoRunWithDstType<int16_t>;
66  break;
67  case TensorProto_DataType_INT64:
68  body_ = &CastOp<CPUContext>::DoRunWithDstType<int64_t>;
69  break;
70  case TensorProto_DataType_FLOAT16:
71  CAFFE_THROW("Casting to and from float16 on CPU is not supported yet");
72  // break;
73  case TensorProto_DataType_DOUBLE:
74  //body_ = &CastOp::DoRunIncFp16WithDstType<double>;
75  body_ = &CastOp<CPUContext>::DoRunWithDstType<double>;
76  break;
77  case TensorProto_DataType_UNDEFINED:
78  CAFFE_THROW("Cast op must have 'to' argument of type DataType");
79  // break;
80  default:
81  CAFFE_THROW("Unexpected 'to' argument value: ", to);
82  }
83 }
84 
85 template <>
86 template <typename DstType>
87 bool CastOp<CPUContext>::DoRunWithDstType() {
88  return DispatchHelper<
89  TensorTypes<
90  float,
91  int32_t,
92  bool,
93  uint8_t,
94  int8_t,
95  uint16_t,
96  int16_t,
97  int64_t,
98  double>,
99  DstType>::call(this, Input(0));
100 }
101 
102 REGISTER_CPU_OPERATOR(Cast, CastOp<CPUContext>);
103 
104 OPERATOR_SCHEMA(Cast)
105  .NumInputs(1)
106  .NumOutputs(1)
107  .TensorInferenceFunction(
108  [](const OperatorDef& def, const vector<TensorShape>& in) {
109  ArgumentHelper helper(def);
110  vector<TensorShape> out;
111  out.push_back(in[0]);
112  out[0].set_data_type(cast::GetCastDataType(helper, "to"));
113  return out;
114  })
115  .SetDoc(R"DOC(
116 The operator casts the elements of a given input tensor to a data type
117 specified by the 'to' argument and returns an output tensor of the same size in
118 the converted type. The 'to' argument must be one of the data types specified
119 in the 'DataType' enum field in the TensorProto message. If the 'to' argument
120 is not provided or is not one of the enumerated types in DataType, Caffe2
121 throws an Enforce error.
122 
123 NOTE: Casting to and from strings is not supported yet.
124 )DOC")
125  .Arg(
126  "to",
127  "The data type to which the elements of the input tensor are cast."
128  "Strictly must be one of the types from DataType enum in TensorProto")
129  .Input(0, "input", "Input tensor to be cast.")
130  .Output(
131  0,
132  "output",
133  "Output tensor with the same shape as input with type "
134  "specified by the 'to' argument");
135 
136 // Cast's gradient is another Cast back to 'from_type' (see GetCastGradient), so the placeholder below stays disabled:
137 // GRADIENT_NOT_IMPLEMENTED_YET(Cast);
138 
139 class GetCastGradient : public GradientMakerBase {
140  using GradientMakerBase::GradientMakerBase;
141  vector<OperatorDef> GetGradientDefs() override {
142 
143  vector<OperatorDef> defs = SingleGradientDef("Cast", "", vector<string>{GO(0)}, vector<string>{GI(0)});
144 
145  // now modify the arguments in defs[0]
146  ArgumentHelper argsHelper(def_);
147 
148  auto to_name = cast::GetCastDataType(argsHelper, "to");
149 
150  CAFFE_ENFORCE(
151  argsHelper.HasSingleArgumentOfType<string>("from_type") ||
152  argsHelper.HasSingleArgumentOfType<int>("from_type"),
153  "Argument 'from_type' of type int or string"
154  " is required to get the gradient of CastOp");
155 
156  auto from_name = cast::GetCastDataType(argsHelper, "from_type");
157  Argument *to = defs[0].add_arg();
158  to->set_name("to");
159  to->set_i(from_name);
160 
161  Argument *from = defs[0].add_arg();
162  from->set_name("from_type");
163  from->set_i(to_name);
164 
165  return defs;
166  }
167 
168  bool CopyArguments() const override {
169  return false;
170  }
171 };
172 
173 REGISTER_GRADIENT(Cast, GetCastGradient);
174 
175 
176 
177 
178 } // namespace caffe2
A helper class to index into arguments.
Definition: proto_utils.h:198
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...