// Builtin TorchScript functions exposed to the script compiler.
//
// Two CodeTemplate sources follow. Their `${...}` placeholders are filled via
// `env.s(...)` and `.format(env)` in loadBuiltinFunctions():
//   * scalar_operators_source: mul/add/ne/eq/lt/le/gt/ge/sub/div taking a
//     `${Scalar}` first operand and a Tensor second operand. Only the sub and
//     div bodies are visible here: sub(a, b) is written as `torch.neg(b) + a`
//     and div(a, b) as `torch.reciprocal(b) * a`, so the Tensor is the
//     left-hand operand of the underlying op.
//   * _ntuple_ops: `_${name}` accepting a `BroadcastingList${Length}[${Scalar}]`
//     and returning `List[${Scalar}]`.
// NOTE(review): the return lines of the other scalar overloads and the
// raw-string delimiters around both templates are not visible in this view —
// confirm they are intact.
1 #include <torch/csrc/api/include/torch/jit.h> 2 #include <torch/csrc/jit/code_template.h> 3 #include <torch/csrc/jit/script/builtin_functions.h> 9 auto scalar_operators_source = CodeTemplate(
11 def mul(a : ${Scalar}, b : Tensor) -> Tensor: 13 def add(a : ${Scalar}, b : Tensor) -> Tensor: 15 def ne(a : ${Scalar}, b : Tensor) -> Tensor: 17 def eq(a : ${Scalar}, b : Tensor) -> Tensor: 19 def lt(a : ${Scalar}, b : Tensor) -> Tensor: 21 def le(a : ${Scalar}, b : Tensor) -> Tensor: 23 def gt(a : ${Scalar}, b : Tensor) -> Tensor: 25 def ge(a : ${Scalar}, b : Tensor) -> Tensor: 27 def sub(a : ${Scalar}, b : Tensor) -> Tensor: 28 return torch.neg(b) + a 29 def div(a : ${Scalar}, b : Tensor) -> Tensor: 30 return torch.reciprocal(b) * a 33 auto _ntuple_ops = CodeTemplate(
35 def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]: 40 const std::vector<Method*>& getAllBuiltinFunctionsFor(
Symbol name) {
// Returns the builtin Methods registered under `name`.
// The first call loads the whole builtin library via loadBuiltinFunctions().
// Later calls only look `name` up in builtins_by_name.
// NOTE(review): the return statements are not visible in this view. `empty`
// is presumably what is returned for unknown names and during loading —
// confirm.
41 const static std::vector<Method*> empty;
// A recursive mutex is used, and loadBuiltinFunctions() runs while it is held.
// Presumably compiling the builtin sources re-enters this function on the
// same thread to resolve builtins, which a plain std::mutex would deadlock
// on — TODO confirm the re-entrancy path.
48 std::lock_guard<std::recursive_mutex> guard(mutex);
// Presumably a re-entrant call made while the library is still loading.
// NOTE(review): the body of this branch is not visible in this view.
// `INTIIALIZING` is a misspelling of "INITIALIZING". It matches the enum
// declaration, so it compiles.
49 if (state == INTIIALIZING) {
51 }
else if (state == UNINITIALIZED) {
// First lookup: compile and register every builtin source.
53 loadBuiltinFunctions();
// NOTE(review): the `state` transitions around the load are not visible in
// this view. This assertion requires `state` to be INITIALIZED afterwards.
56 AT_ASSERT(state == INITIALIZED);
57 auto it = builtins_by_name.find(name);
// No builtin overloads were registered under `name`.
58 if (it == builtins_by_name.end())
// Compiles one TorchScript `source` string into a fresh script::Module.
// It then registers each of the module's methods as a builtin under the
// symbol `aten::<method name>`.
// loadBuiltinFunctions() stamps out the same template once per Scalar type,
// so a symbol can accumulate several Method overloads. operator[]
// default-inserts an empty list the first time a symbol is seen.
64 void loadSource(
const std::string& source) {
65 auto module = std::make_shared<script::Module>();
66 defineMethodsInModule(
67 module, source, script::nativeResolver, c10::nullopt);
// Retain the module: builtins_by_name stores raw Method pointers taken from
// its methods. Those pointers presumably must not outlive the module.
68 modules.push_back(module);
69 for (
auto& method : module->get_methods()) {
70 builtins_by_name[Symbol::fromQualString(
"aten::" + method.key())]
71 .push_back(method->get());
// NOTE(review): the closing braces of the loop and of loadSource are not
// visible in this view — confirm they are present.
// Populates builtins_by_name by instantiating both builtin templates:
//   * scalar_operators_source once per Scalar type ("float", "int");
//   * _ntuple_ops once per (Scalar, arity) pair, producing `_single`,
//     `_pair`, `_triple` and `_quadruple` over BroadcastingList1..4.
// loadSource() compiles and registers each instantiated source.
74 void loadBuiltinFunctions() {
75 for (
auto scalar : {
"float",
"int"}) {
// NOTE(review): no declaration of `env` is visible in this view. Presumably
// a fresh template environment is declared for each iteration — confirm.
77 env.s(
"Scalar", scalar);
78 loadSource(scalar_operators_source.format(env));
// NOTE(review): the brace closing the loop above is not visible in this view.
// (name, Length) substitutions for the n-tuple helpers in _ntuple_ops.
81 using str_pair = std::pair<std::string, std::string>;
82 const std::vector<str_pair> name_len = {
83 str_pair(
"single",
"1"),
84 str_pair(
"pair",
"2"),
85 str_pair(
"triple",
"3"),
86 str_pair(
"quadruple",
"4"),
// NOTE(review): the closing `};` of `name_len` is not visible in this view.
88 for (
auto scalar : {
"float",
"int"}) {
// NOTE(review): `pair` is taken by value, which copies two std::strings per
// iteration. `const auto&` would avoid the copies.
89 for (
auto pair : name_len) {
// NOTE(review): as above, the declaration of `env` is not visible here.
91 env.s(
"Scalar", scalar);
92 env.s(
"name", pair.first);
93 env.s(
"Length", pair.second);
94 loadSource(_ntuple_ops.format(env));
// NOTE(review): the closing braces of both loops and of loadBuiltinFunctions
// are not visible in this view.
98 enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED;
99 std::recursive_mutex mutex;
100 std::vector<std::shared_ptr<Module>> modules;
101 std::unordered_map<Symbol, std::vector<Method*>> builtins_by_name;
104 TORCH_API
const std::vector<Method*>& getAllBuiltinFunctionsFor(
Symbol name) {
106 return registry.getAllBuiltinFunctionsFor(name);