Caffe2 - C++ API
A deep learning, cross-platform ML framework
builtin_functions.cpp
#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/code_template.h>
#include <torch/csrc/jit/script/builtin_functions.h>

namespace torch {
namespace jit {
namespace script {

auto scalar_operators_source = CodeTemplate(
    R"SCRIPT(
def mul(a : ${Scalar}, b : Tensor) -> Tensor:
  return b * a
def add(a : ${Scalar}, b : Tensor) -> Tensor:
  return b + a
def ne(a : ${Scalar}, b : Tensor) -> Tensor:
  return b != a
def eq(a : ${Scalar}, b : Tensor) -> Tensor:
  return b == a
def lt(a : ${Scalar}, b : Tensor) -> Tensor:
  return b > a
def le(a : ${Scalar}, b : Tensor) -> Tensor:
  return b >= a
def gt(a : ${Scalar}, b : Tensor) -> Tensor:
  return b < a
def ge(a : ${Scalar}, b : Tensor) -> Tensor:
  return b <= a
def sub(a : ${Scalar}, b : Tensor) -> Tensor:
  return torch.neg(b) + a
def div(a : ${Scalar}, b : Tensor) -> Tensor:
  return torch.reciprocal(b) * a
)SCRIPT");

auto _ntuple_ops = CodeTemplate(
    R"SCRIPT(
def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
  return x
)SCRIPT");

struct BuiltinFunctionRegistry {
  const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name) {
    const static std::vector<Method*> empty;
    // When initializing the builtin function library, we will re-enter
    // getAllBuiltinFunctionsFor, since it is called in the compiler to
    // look up builtins, and initializing the builtin functions calls the
    // compiler. To avoid deadlocking, we use a recursive mutex (the same
    // thread can re-lock the mutex without waiting) and report no loaded
    // builtins during init.
    std::lock_guard<std::recursive_mutex> guard(mutex);
    if (state == INITIALIZING) {
      return empty;
    } else if (state == UNINITIALIZED) {
      state = INITIALIZING;
      loadBuiltinFunctions();
      state = INITIALIZED;
    }
    AT_ASSERT(state == INITIALIZED);
    auto it = builtins_by_name.find(name);
    if (it == builtins_by_name.end())
      return empty;
    return it->second;
  }

 private:
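  // Compile `source` into a fresh module, keep the module alive, and index
  // each of its methods in builtins_by_name under the "aten::" namespace.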
  void loadSource(const std::string& source) {
    auto module = std::make_shared<script::Module>();
    defineMethodsInModule(
        module, source, script::nativeResolver, /*self=*/c10::nullopt);
    modules.push_back(module);
    for (auto& method : module->get_methods()) {
      builtins_by_name[Symbol::fromQualString("aten::" + method.key())]
          .push_back(method->get());
    }
  }
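  // Instantiate the two templates above for each supported scalar type,
  // generating float and int overloads of every builtin.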
  void loadBuiltinFunctions() {
    for (auto scalar : {"float", "int"}) {
      TemplateEnv env;
      env.s("Scalar", scalar);
      loadSource(scalar_operators_source.format(env));
    }

    using str_pair = std::pair<std::string, std::string>;
    const std::vector<str_pair> name_len = {
        str_pair("single", "1"),
        str_pair("pair", "2"),
        str_pair("triple", "3"),
        str_pair("quadruple", "4"),
    };
    for (auto scalar : {"float", "int"}) {
      for (auto pair : name_len) {
        TemplateEnv env;
        env.s("Scalar", scalar);
        env.s("name", pair.first);
        env.s("Length", pair.second);
        loadSource(_ntuple_ops.format(env));
      }
    }
  }
  enum { UNINITIALIZED, INITIALIZING, INITIALIZED } state = UNINITIALIZED;
  std::recursive_mutex mutex;
  std::vector<std::shared_ptr<Module>> modules;
  std::unordered_map<Symbol, std::vector<Method*>> builtins_by_name;
};

TORCH_API const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name) {
  static BuiltinFunctionRegistry registry;
  return registry.getAllBuiltinFunctionsFor(name);
}

} // namespace script
} // namespace jit
} // namespace torch
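
A minimal caller sketch (hypothetical example, not part of this file; it
assumes only the public header shown above):

#include <torch/csrc/jit/script/builtin_functions.h>

size_t countMulOverloads() {
  using namespace torch::jit;
  // The first call lazily builds the registry behind the recursive mutex;
  // subsequent calls only read the cached builtins_by_name map.
  const auto& overloads = script::getAllBuiltinFunctionsFor(
      c10::Symbol::fromQualString("aten::mul"));
  // With the templates above loaded, this includes the float and int
  // scalar-Tensor variants of mul.
  return overloads.size();
}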