#include <cstdint>
#include <stdexcept>
#include <string>

#include <ATen/core/dispatch/KernelRegistration.h>
#include "caffe2/operators/experimental/c10/schemas/cast.h"
#include "caffe2/utils/math.h"
#include "caffe2/core/tensor.h"

using caffe2::TensorProto_DataType;
13 template <
typename DstType,
typename SrcType>
14 void do_cast_(
const Tensor& input,
const Tensor& output) {
16 const auto* data = input.template data<SrcType>();
17 auto* out = output.template mutable_data<DstType>();
18 auto N = input.
numel();
19 for (int64_t i = 0; i < N; ++i) {
20 out[i] =
static_cast<DstType
>(data[i]);
24 template <
class SrcType>
25 void cast_op_cpu_impl(
29 Tensor input{C10Tensor(input_)};
30 Tensor output{C10Tensor(output_)};
31 TensorProto_DataType to =
static_cast<TensorProto_DataType
>(to_);
34 case caffe2::TensorProto_DataType_FLOAT:
35 do_cast_<float, SrcType>(input, output);
37 case caffe2::TensorProto_DataType_INT32:
38 do_cast_<int32_t, SrcType>(input, output);
40 case caffe2::TensorProto_DataType_BYTE:
41 LOG(FATAL) <<
"BYTE is deprecated";
43 case caffe2::TensorProto_DataType_STRING:
44 CAFFE_THROW(
"Casting to and from strings is not supported yet");
46 case caffe2::TensorProto_DataType_BOOL:
47 do_cast_<bool, SrcType>(input, output);
49 case caffe2::TensorProto_DataType_UINT8:
50 do_cast_<uint8_t, SrcType>(input, output);
52 case caffe2::TensorProto_DataType_INT8:
53 do_cast_<int8_t, SrcType>(input, output);
55 case caffe2::TensorProto_DataType_UINT16:
56 do_cast_<uint16_t, SrcType>(input, output);
58 case caffe2::TensorProto_DataType_INT16:
59 do_cast_<int16_t, SrcType>(input, output);
61 case caffe2::TensorProto_DataType_INT64:
62 do_cast_<int64_t, SrcType>(input, output);
64 case caffe2::TensorProto_DataType_FLOAT16:
65 CAFFE_THROW(
"Casting to and from Half on CPU is not supported yet");
67 case caffe2::TensorProto_DataType_DOUBLE:
68 do_cast_<double, SrcType>(input, output);
70 case caffe2::TensorProto_DataType_UNDEFINED:
71 CAFFE_THROW(
"Cast op must have 'to' argument of type DataType");
74 CAFFE_THROW(
"Unexpected 'to' argument value: ", to);
81 switch (input.scalar_type()) {
82 #define CASE(ctype,name,_2) case ScalarType:: name : return cast_op_cpu_impl<ctype>(input, output, to); 83 AT_FORALL_SCALAR_TYPES(CASE)
85 default:
throw std::runtime_error(
string() +
"Unsupported scalar type " + toString(input.scalar_type()));
92 C10_REGISTER_KERNEL(caffe2::ops::Cast)
93 .kernel<decltype(caffe2::cast_op_cpu), &caffe2::cast_op_cpu>()
94 .dispatchKey(CPUTensorId());
/*
 * API reference notes (doxygen tooltip text captured with this listing;
 * entries marked "[truncated]" were cut off in the original capture):
 *
 * - int64_t numel() const
 *     Returns the number of items of the tensor.
 * - Tensor class holds a shared pointer to the implementation TensorImpl,
 *   redirects API calls to TensorImpl. [truncated]
 * - The CPU Context, representing the bare minimum of what a Context class
 *   in Caffe2 should implement. [truncated]
 * - A global dictionary that holds information about what Caffe2 modules
 *   have been loaded in the current ... [truncated]
 * - To register your own kernel for an operator, do in one (!) cpp file:
 *   C10_REGISTER_KERNEL(OperatorHand... [truncated]
 * - void ResizeLike(const Tensor &src_tensor) const
 *     Resize the tensor like the source tensor.
 */