Caffe2 - C++ API
A deep learning, cross platform ML framework
cast_cpu.cc
1 #include <ATen/core/dispatch/KernelRegistration.h>
2 #include "caffe2/operators/experimental/c10/schemas/cast.h"
3 #include "caffe2/utils/math.h"
4 #include "caffe2/core/tensor.h"
5 
7 using caffe2::Tensor;
8 using caffe2::TensorProto_DataType;
9 
10 namespace caffe2 {
11 namespace {
12 
13 template <typename DstType, typename SrcType>
14 void do_cast_(const Tensor& input, const Tensor& output) {
15  output.ResizeLike(input);
16  const auto* data = input.template data<SrcType>();
17  auto* out = output.template mutable_data<DstType>();
18  auto N = input.numel();
19  for (int64_t i = 0; i < N; ++i) {
20  out[i] = static_cast<DstType>(data[i]);
21  }
22 }
23 
24 template <class SrcType>
25 void cast_op_cpu_impl(
26  const at::Tensor& input_,
27  const at::Tensor& output_,
28  int64_t to_) {
29  Tensor input{C10Tensor(input_)};
30  Tensor output{C10Tensor(output_)};
31  TensorProto_DataType to = static_cast<TensorProto_DataType>(to_);
32 
33  switch (to) {
34  case caffe2::TensorProto_DataType_FLOAT:
35  do_cast_<float, SrcType>(input, output);
36  break;
37  case caffe2::TensorProto_DataType_INT32:
38  do_cast_<int32_t, SrcType>(input, output);
39  break;
40  case caffe2::TensorProto_DataType_BYTE:
41  LOG(FATAL) << "BYTE is deprecated";
42  break;
43  case caffe2::TensorProto_DataType_STRING:
44  CAFFE_THROW("Casting to and from strings is not supported yet");
45  // break;
46  case caffe2::TensorProto_DataType_BOOL:
47  do_cast_<bool, SrcType>(input, output);
48  break;
49  case caffe2::TensorProto_DataType_UINT8:
50  do_cast_<uint8_t, SrcType>(input, output);
51  break;
52  case caffe2::TensorProto_DataType_INT8:
53  do_cast_<int8_t, SrcType>(input, output);
54  break;
55  case caffe2::TensorProto_DataType_UINT16:
56  do_cast_<uint16_t, SrcType>(input, output);
57  break;
58  case caffe2::TensorProto_DataType_INT16:
59  do_cast_<int16_t, SrcType>(input, output);
60  break;
61  case caffe2::TensorProto_DataType_INT64:
62  do_cast_<int64_t, SrcType>(input, output);
63  break;
64  case caffe2::TensorProto_DataType_FLOAT16:
65  CAFFE_THROW("Casting to and from Half on CPU is not supported yet");
66  // break;
67  case caffe2::TensorProto_DataType_DOUBLE:
68  do_cast_<double, SrcType>(input, output);
69  break;
70  case caffe2::TensorProto_DataType_UNDEFINED:
71  CAFFE_THROW("Cast op must have 'to' argument of type DataType");
72  // break;
73  default:
74  CAFFE_THROW("Unexpected 'to' argument value: ", to);
75  }
76 }
77 void cast_op_cpu(
78  const at::Tensor& input,
79  const at::Tensor& output,
80  int64_t to) {
81  switch (input.scalar_type()) {
82 #define CASE(ctype,name,_2) case ScalarType:: name : return cast_op_cpu_impl<ctype>(input, output, to);
83  AT_FORALL_SCALAR_TYPES(CASE)
84 #undef CASE
85  default: throw std::runtime_error(string() + "Unsupported scalar type " + toString(input.scalar_type()));
86  }
87 }
88 } // namespace
89 } // namespace caffe2
90 
// Registers caffe2::cast_op_cpu (defined in caffe2's anonymous namespace in
// this file) as the kernel for the caffe2::ops::Cast operator. The kernel is
// selected under the CPUTensorId() dispatch key.
91 namespace c10 {
92 C10_REGISTER_KERNEL(caffe2::ops::Cast)
93  .kernel<decltype(caffe2::cast_op_cpu), &caffe2::cast_op_cpu>()
94  .dispatchKey(CPUTensorId());
95 } // namespace c10
int64_t numel() const
Returns the number of items of the tensor.
Definition: tensor.h:483
Tensor class holds a shared pointer to the implementation TensorImpl and redirects API calls to TensorImpl.
Definition: tensor.h:25
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement.
Definition: context.h:40
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
void ResizeLike(const Tensor &src_tensor) const
Resize the tensor like the source tensor.
Definition: tensor.h:322