1 #include <ATen/core/interned_strings.h> 8 #include <unordered_map> 10 #include <ATen/core/interned_strings_class.h> 11 #include <c10/util/Exception.h> 12 #include <c10/util/Optional.h> 16 const std::string& domain_prefix() {
// Returns a reference to the constant string "org.pytorch.".
// The string is a function-local static, so a single instance is shared by
// every caller and the reference is not tied to any one call.
17 static const std::string _domain_prefix =
"org.pytorch.";
18 return _domain_prefix;
// Takes a string `s` and returns a Symbol, holding mutex_ through a
// lock_guard for the rest of the scope.
// NOTE(review): the statement that produces the return value is not visible
// in this excerpt -- presumably it delegates to _symbol(s); confirm.
21 Symbol InternedStrings::symbol(
const std::string& s) {
22 std::lock_guard<std::mutex> guard(mutex_);
// Returns a pair of C strings for `sym`: first the "<ns>::<s>" spelling,
// second the bare "<s>" spelling (see the DEFINE_CASE expansion below).
26 std::pair<const char*, const char*> InternedStrings::string(Symbol sym) {
// DEFINE_CASE emits one `case` per (ns, s) entry of FORALL_NS_SYMBOLS; each
// case returns string literals assembled by the preprocessor (#ns "::" #s).
// NOTE(review): the enclosing switch header is not visible in this excerpt.
32 #define DEFINE_CASE(ns, s) \ 33 case static_cast<unique_t>(ns::s): \ 34 return {#ns "::" #s, #s}; 35 FORALL_NS_SYMBOLS(DEFINE_CASE)
// Symbols not answered by a generated case are forwarded to customString().
// NOTE(review): presumably this sits under a `default:` label -- confirm.
38 return customString(sym);
// Returns the namespace Symbol associated with `sym`.
42 Symbol InternedStrings::ns(Symbol sym) {
// DEFINE_CASE emits one `case` per (ns, s) entry of FORALL_NS_SYMBOLS; each
// case returns the matching namespaces::<ns> value.
// NOTE(review): the enclosing switch header is not visible in this excerpt.
44 #define DEFINE_CASE(ns, s) \ 45 case static_cast<unique_t>(ns::s): \ 46 return namespaces::ns; 47 FORALL_NS_SYMBOLS(DEFINE_CASE)
// Otherwise read the `ns` field of the sym_to_info_ entry while holding
// mutex_.
50 std::lock_guard<std::mutex> guard(mutex_);
51 return sym_to_info_.at(sym).ns;
// Finds or creates the Symbol for the qualified string `s`
// ("<namespace>::<name>"). This function takes no lock itself.
// Throws std::runtime_error when `s` contains no "::" separator.
// NOTE(review): several statements of this body are not visible in this
// excerpt (the branch taken when `s` is already known, the declaration of
// `ss`, and the final return) -- confirm against the full file.
56 Symbol InternedStrings::_symbol(
const std::string& s) {
// Check whether `s` already has an entry in string_to_sym_.
57 auto it = string_to_sym_.find(s);
58 if (it != string_to_sym_.end())
// Locate the namespace separator; everything before it is the namespace.
61 auto pos = s.find(
"::");
62 if (pos == std::string::npos) {
64 ss <<
"all symbols must have a namespace, <namespace>::<string>, but found: " << s;
65 throw std::runtime_error(ss.str());
// The namespace is itself a symbol, spelled "namespaces::<namespace>";
// obtain it through a recursive call.
67 Symbol ns = _symbol(
"namespaces::" + s.substr(0, pos));
// The new symbol's value is the current size of sym_to_info_, i.e. the index
// at which its info record is appended below.
69 Symbol sym(sym_to_info_.size());
70 string_to_sym_[s] = sym;
// Record: namespace symbol, the full string, and the part after "::".
71 sym_to_info_.push_back({ns, s, s.substr(pos + strlen(
"::"))});
75 std::pair<const char*, const char*> InternedStrings::customString(Symbol sym) {
76 std::lock_guard<std::mutex> guard(mutex_);
77 SymbolInfo& s = sym_to_info_.at(sym);
78 return {s.qual_name.c_str(), s.unqual_name.c_str()};
// File-local accessor for the InternedStrings instance used by the Symbol
// member functions in this file. The instance is a function-local static, so
// it is constructed on first use.
// NOTE(review): the return statement is not visible in this excerpt --
// presumably it returns `s`; confirm.
81 static InternedStrings & globalStrings() {
82 static InternedStrings s;
86 Symbol Symbol::fromQualString(
const std::string & s) {
87 return globalStrings().symbol(s);
90 const char * Symbol::toUnqualString()
const {
91 return globalStrings().string(*this).second;
94 const char * Symbol::toQualString()
const {
95 return globalStrings().string(*this).first;
// Returns a C string for displaying this symbol; the visible code returns
// the same pointer as toQualString().
// NOTE(review): four source lines between the signature and the return are
// not visible in this excerpt -- confirm nothing else happens there.
98 const char * Symbol::toDisplayString()
const {
103 return toQualString();
106 Symbol Symbol::ns()
const {
107 return globalStrings().ns(*
this);
110 std::string Symbol::domainString()
const {
111 return domain_prefix() + ns().toUnqualString();
114 Symbol Symbol::fromDomainAndUnqualString(
const std::string & d,
const std::string & s) {
115 if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
116 std::ostringstream ss;
117 ss <<
"Symbol: domain string is expected to be prefixed with '" 118 << domain_prefix() <<
"', e.g. 'org.pytorch.aten'";
119 throw std::runtime_error(ss.str());
121 std::string qualString = d.substr(domain_prefix().size()) +
"::" + s;
122 return fromQualString(qualString);
To register your own kernel for an operator, do the following in exactly one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...