Caffe2 - C++ API
A deep learning, cross-platform ML framework
Scalar.h
1 #pragma once
2 
3 #include <assert.h>
4 #include <stdint.h>
5 #include <stdexcept>
6 #include <string>
7 #include <utility>
8 
9 #include <c10/macros/Macros.h>
10 #include <c10/core/ScalarType.h>
11 #include <c10/util/Half.h>
12 
13 namespace c10 {
14 
22 class C10_API Scalar {
23  public:
24  Scalar() : Scalar(int64_t(0)) {}
25 
26 #define DEFINE_IMPLICIT_CTOR(type,name,member) \
27  Scalar(type vv) \
28  : tag(Tag::HAS_##member) { \
29  v . member = convert<decltype(v.member),type>(vv); \
30  }
31  // We can't set v in the initializer list using the
32  // syntax v{ .member = ... } because it doesn't work on MSVC
33 
34  AT_FORALL_SCALAR_TYPES(DEFINE_IMPLICIT_CTOR)
35 
36 #undef DEFINE_IMPLICIT_CTOR
37 
38 #define DEFINE_IMPLICIT_COMPLEX_CTOR(type, name, member) \
39  Scalar(type vv) : tag(Tag::HAS_##member) { \
40  v.member[0] = c10::convert<double>(vv.real()); \
41  v.member[1] = c10::convert<double>(vv.imag()); \
42  }
43 
44  DEFINE_IMPLICIT_COMPLEX_CTOR(at::ComplexHalf,ComplexHalf,z)
45  DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<float>,ComplexFloat,z)
46  DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<double>,ComplexDouble,z)
47 
48 #undef DEFINE_IMPLICIT_COMPLEX_CTOR
49 
50 #define DEFINE_ACCESSOR(type,name,member) \
51  type to##name () const { \
52  if (Tag::HAS_d == tag) { \
53  return checked_convert<type, double>(v.d, #type); \
54  } else if (Tag::HAS_z == tag) { \
55  return checked_convert<type, std::complex<double>>({v.z[0], v.z[1]}, #type); \
56  } else { \
57  return checked_convert<type, int64_t>(v.i, #type); \
58  } \
59  }
60 
61  // TODO: Support ComplexHalf accessor
62  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR)
63 
64  //also support scalar.to<int64_t>();
65  template<typename T>
66  T to();
67 
68 #undef DEFINE_ACCESSOR
69  bool isFloatingPoint() const {
70  return Tag::HAS_d == tag;
71  }
72  bool isIntegral() const {
73  return Tag::HAS_i == tag;
74  }
75  bool isComplex() const {
76  return Tag::HAS_z == tag;
77  }
78 
79  Scalar operator-() const;
80 
81 private:
82  enum class Tag { HAS_d, HAS_i, HAS_z };
83  Tag tag;
84  union {
85  double d;
86  int64_t i;
87  // Can't do put std::complex in the union, because it triggers
88  // an nvcc bug:
89  // error: designator may not specify a non-POD subobject
90  double z[2];
91  } v;
92 };
93 
94 // define the scalar.to<int64_t>() specializations
95 template<typename T>
96 inline T Scalar::to() {
97  throw std::runtime_error("to() cast to unexpected type.");
98 }
99 
100 #define DEFINE_TO(T,name,_) \
101 template<> \
102 inline T Scalar::to<T>() { \
103  return to##name(); \
104 }
105 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO)
106 #undef DEFINE_TO
107 }
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7