5 #include <unordered_map> 8 #include <ATen/core/aten_interned_strings.h> 9 #include <c10/macros/Macros.h> 14 #define FORALL_NS_SYMBOLS(_) \ 19 _(namespaces, scope) \ 21 _(namespaces, namespaces) \ 23 _(prim, BroadcastingChunk) \ 24 _(prim, BroadcastSizes) \ 30 _(prim, FusionGroup) \ 31 _(prim, DifferentiableGraph) \ 41 _(prim, Placeholder) \ 44 _(prim, IgnoredPythonOp) \ 48 _(prim, AutogradZero) \ 49 _(prim, AutogradAnyNonZero) \ 51 _(prim, TupleConstruct) \ 52 _(prim, TupleUnpack) \ 55 _(prim, ListConstruct) \ 57 _(prim, DictConstruct) \ 59 _(prim, NumToTensor) \ 60 _(prim, ImplicitTensorToNum) \ 67 _(prim, requires_grad) \ 68 _(prim, AutogradAdd) \ 70 _(prim, FusedConcat) \ 71 _(prim, ConstantChunk) \ 72 _(prim, MMTreeReduce) \ 73 _(prim, MMBatchSide) \ 77 _(aten, _grad_sum_to_size) \ 78 _(aten, _ncf_unsqueeze) \ 81 _(aten, __round_to_zero_floordiv)\ 82 _(aten, _unwrap_optional) \ 84 _(prim, RaiseException) \ 86 _(prim, CreateObject) \ 102 _(prim, unchecked_unwrap_optional)\ 103 FORALL_ATEN_BASE_SYMBOLS(_) \ 107 _(onnx, ConstantFill) \ 132 _(onnx, ConstantOfShape) \ 133 FORALL_ATTR_BASE_SYMBOLS(_) \ 135 _(attr, ReverseSubgraph) \ 136 _(attr, f_real_outputs) \ 137 _(attr, df_input_vjps) \ 138 _(attr, df_input_captured_inputs) \ 139 _(attr, df_input_captured_outputs) \ 140 _(attr, df_output_vjps) \ 147 _(attr, input_as_shape) \ 162 #define FORALL_NS_SYMBOLS(_) \ 163 _(namespaces, prim) \ 164 _(namespaces, aten) \ 165 _(namespaces, onnx) \ 166 _(namespaces, attr) \ 167 _(namespaces, scope) \ 168 _(namespaces, user) \ 169 _(namespaces, namespaces) 208 using unique_t = uint32_t;
210 const std::string& domain_prefix();
216 explicit constexpr
Symbol() : value(0) {};
217 explicit constexpr Symbol(unique_t uniq)
221 static Symbol fromQualString(
const std::string & s);
224 static Symbol fromDomainAndUnqualString(
const std::string & d,
const std::string & s);
231 static Symbol attr(
const std::string & s);
232 static Symbol aten(
const std::string & s);
233 static Symbol
onnx(
const std::string & s);
234 static Symbol prim(
const std::string & s);
235 static Symbol user(
const std::string & s);
237 static Symbol scope(
const std::string & s);
239 bool is_attr()
const;
240 bool is_aten()
const;
241 bool is_prim()
const;
242 bool is_onnx()
const;
243 bool is_user()
const;
246 constexpr
operator unique_t()
const {
255 const char * toUnqualString()
const;
260 const char * toQualString()
const;
265 const char * toDisplayString()
const;
269 std::string domainString()
const;
272 explicit Symbol(Symbol ns,
const std::string & s);
277 return static_cast<unique_t
>(lhs) == static_cast<unique_t>(rhs);
280 enum class _keys : unique_t {
281 #define DEFINE_KEY(ns, s) ns##_##s, 282 FORALL_NS_SYMBOLS(DEFINE_KEY)
287 #define DEFINE_SYMBOL(s) \ 288 constexpr Symbol s(static_cast<unique_t>(_keys::s)); 292 #define DEFINE_SYMBOL(ns, s) \ 293 namespace ns { constexpr Symbol s(static_cast<unique_t>(_keys::ns##_##s)); } 294 FORALL_NS_SYMBOLS(DEFINE_SYMBOL)
297 inline Symbol Symbol::attr(
const std::string & s) {
return Symbol::fromQualString(
"attr::" + s); }
298 inline Symbol Symbol::aten(
const std::string & s) {
return Symbol::fromQualString(
"aten::" + s); }
299 inline Symbol Symbol::onnx(
const std::string & s) {
return Symbol::fromQualString(
"onnx::" + s); }
300 inline Symbol Symbol::prim(
const std::string & s) {
return Symbol::fromQualString(
"prim::" + s); }
301 inline Symbol Symbol::scope(
const std::string & s) {
return Symbol::fromQualString(
"scope::" + s); }
302 inline Symbol Symbol::user(
const std::string & s) {
return Symbol::fromQualString(
"user::" + s); }
303 inline bool Symbol::is_attr()
const {
return ns() == namespaces::attr; }
304 inline bool Symbol::is_aten()
const {
return ns() == namespaces::aten; }
305 inline bool Symbol::is_prim()
const {
return ns() == namespaces::prim; }
306 inline bool Symbol::is_onnx()
const {
return ns() == namespaces::onnx; }
307 inline bool Symbol::is_user()
const {
return ns() == namespaces::user; }
314 struct hash<
c10::Symbol> {
316 return std::hash<uint32_t>()(static_cast<uint32_t>(s));
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...