Caffe2 - C++ API
A deep learning, cross platform ML framework
custom_operator.h
1 #pragma once
2 
3 #include <torch/csrc/jit/operator.h>
4 #include <ATen/core/stack.h>
5 #include <torch/csrc/jit/tracer.h>
6 #include <torch/csrc/utils/variadic.h>
7 
8 #include <ATen/core/function_schema.h>
9 #include <c10/util/Metaprogramming.h>
10 #include <c10/util/TypeList.h>
11 
12 namespace torch {
13 namespace jit {
14 namespace detail {
15 
16 using ::c10::Argument;
17 using ::c10::FunctionSchema;
18 
/// Compile-time guard that rejects C++ types users commonly (and wrongly)
/// reach for when writing custom operators:
///   - any integral type other than `int64_t` (e.g. `int`, `int32_t`, `char`);
///   - `float` (use `double` instead).
/// NOTE(review): `bool` satisfies `std::is_integral`, so it is rejected here
/// as well — confirm that is intended.
template <typename T>
void checkStaticTypes() {
  constexpr bool kIsNonInt64Integral =
      std::is_integral<T>::value && !std::is_same<T, int64_t>::value;
  constexpr bool kIsFloat = std::is_same<T, float>::value;

  // Give nice error messages for some of the common error cases.
  // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
  static_assert(
      !kIsNonInt64Integral,
      "INVALID TYPE: Only int64_t is supported as an integral argument type");
  static_assert(
      !kIsFloat,
      "INVALID TYPE: float is not supported as an argument type, use double instead");
}
31 
/// Empty-pack overload: nothing to check. Without it, `checkStaticTypes<>()`
/// (reached for operators that take no arguments, or return an empty tuple)
/// matches neither the single-type overload nor the `First, Second, Rest...`
/// overload and fails to compile. The `enable_if` keeps this overload out of
/// overload resolution for any non-empty pack.
template <typename... Ts>
typename std::enable_if<sizeof...(Ts) == 0>::type checkStaticTypes() {}

/// Applies the single-type `checkStaticTypes<T>()` guard to each of two or
/// more types, recursing on the tail of the pack.
template <typename First, typename Second, typename... Rest>
void checkStaticTypes() {
  checkStaticTypes<First>();
  checkStaticTypes<Second, Rest...>();
}
37 
38 template <typename... Ts, size_t... Is>
39 ::std::vector<Argument> createArgumentVectorFromTypes(Indices<Is...> indices) {
40  checkStaticTypes<decay_t<Ts>...>();
41  // Arguments are named "_<index>"
42  return {Argument("_" + std::to_string(Is), getTypePtr<decay_t<Ts>>())...};
43 }
44 
45 template <typename... Ts, size_t... Is>
46 ::std::vector<Argument> createReturns(Indices<Is...> indices) {
47  return createArgumentVectorFromTypes<Ts..., Is...>();
48 }
49 
52 template <typename... Ts>
53 ::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
54  // Create an index pack so we can call `get<Indices>` on the tuple next.
55  return createReturns<Ts...>(typename MakeIndices<sizeof...(Ts)>::indices{});
56 }
57 
59 template <typename ReturnType>
60 ::std::vector<Argument> createReturns(ReturnType*) {
61  checkStaticTypes<decay_t<ReturnType>>();
62  return {Argument("_1", getTypePtr<decay_t<ReturnType>>())};
63 }
64 
67 template <typename FunctionTraits, size_t... Is>
68 ::std::vector<Argument> createArgumentVectorFromTraits(Indices<Is...> indices) {
69  using ArgumentTypes = typename FunctionTraits::parameter_types;
70  return createArgumentVectorFromTypes<
71  c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
72 }
73 
76 template <typename FunctionTraits>
77 FunctionSchema createFunctionSchemaFromTraits(const std::string& name) {
78  using ReturnType = typename FunctionTraits::return_type;
79 
80  auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
81  typename MakeIndices<FunctionTraits::number_of_parameters>::indices{});
82  auto returns = createReturns(static_cast<ReturnType*>(nullptr));
83 
84  return {name, "", arguments, returns};
85 }
86 
88 template <size_t... Is, typename... Types>
89 Node* getTracedNode(
90  const FunctionSchema& schema,
91  const std::tuple<Types...>& tuple) {
92  auto symbol = Symbol::fromQualString(schema.name());
93  const auto& graph = tracer::getTracingState()->graph;
94  Node* node = graph->create(symbol, /*num_outputs=*/0);
95  tracer::recordSourceLocation(node);
96 
97  // Hack to call addInputs for the parameter pack in a sequenced fashion.
98  // https://stackoverflow.com/questions/12030538/calling-a-function-for-each-variadic-template-argument-and-an-array
99  int _[] = {
100  (tracer::addInputs(
101  node, schema.arguments()[Is].name().c_str(), std::get<Is>(tuple)),
102  0)...};
103  (void)_; // ignore
104 
105  graph->insertNode(node);
106 
107  return node;
108 }
109 
115 template <typename Implementation, typename... Types, size_t... Is>
116 void callOperatorWithTuple(
117  const FunctionSchema& schema,
118  Implementation&& implementation,
119  Stack& stack,
120  std::tuple<Types...>& arguments,
121  Indices<Is...>) {
122  AT_ASSERT(stack.size() == sizeof...(Is));
123 
124  // Pop values from the stack into the elements of the tuple.
125  pop(stack, std::get<Is>(arguments)...);
126 
127  Node* node = nullptr;
128  if (jit::tracer::isTracing()) {
129  node = getTracedNode<Is...>(schema, arguments);
130  }
131 
132  // Call into the actual, original, user-supplied function.
133  auto return_value =
134  std::forward<Implementation>(implementation)(std::get<Is>(arguments)...);
135 
136  if (jit::tracer::isTracing()) {
137  jit::tracer::addOutput(node, return_value);
138  }
139 
140  // Push the return value back onto the stack.
141  push(stack, IValue(std::move(return_value)));
142 }
143 
144 inline void checkArgumentVector(
145  const char* what,
146  const std::vector<Argument>& inferred,
147  const std::vector<Argument>& provided,
148  const FunctionSchema& inferredSchema,
149  const FunctionSchema& providedSchema) {
150  // clang-format off
151  AT_CHECK(
152  inferred.size() == provided.size(),
153  "Inferred ", inferred.size(), " ", what,
154  "(s) for operator implementation, but the provided schema specified ",
155  provided.size(), " ", what, "(s). Inferred schema: ", inferredSchema,
156  " | Provided schema: ", providedSchema);
157  // clang-format on
158  for (size_t i = 0; i < provided.size(); ++i) {
159  // clang-format off
160  AT_CHECK(
161  provided[i].type()->isSubtypeOf(inferred[i].type()),
162  "Inferred type for ", what, " #", i, " was ", *inferred[i].type(),
163  ", but the provided schema specified type ", *provided[i].type(),
164  " for the ", what, " in that position. Inferred schema: ",
165  inferredSchema, " | Provided schema: ", providedSchema);
166  // clang-format on
167  }
168 }
169 
174 template <typename Traits>
175 FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
176  // If there is no '(' in the schema, we assume this is only the name (e.g.
177  // "foo::bar").
178  const auto bracketIndex = schemaOrName.find('(');
179  if (bracketIndex == std::string::npos) {
180  // Infer the full schema and we're good.
181  return torch::jit::detail::createFunctionSchemaFromTraits<Traits>(
182  /*name=*/schemaOrName);
183  }
184 
185  // If the user provided her own schema, we need to infer it nevertheless and
186  // check that it's correct. We return the user provided schema in the end
187  // because it has proper argument names.
188 
189  auto providedSchema = parseSchema(schemaOrName);
190 
191  const auto inferredSchema =
192  torch::jit::detail::createFunctionSchemaFromTraits<Traits>(
193  providedSchema.name());
194  checkArgumentVector(
195  "argument",
196  inferredSchema.arguments(),
197  providedSchema.arguments(),
198  inferredSchema,
199  providedSchema);
200  checkArgumentVector(
201  "return value",
202  inferredSchema.returns(),
203  providedSchema.returns(),
204  inferredSchema,
205  providedSchema);
206  return providedSchema;
207 }
208 } // namespace detail
209 
235 template <typename Implementation>
236 Operator createOperator(
237  const std::string& schemaOrName,
238  Implementation&& implementation) {
239  using Traits = c10::guts::infer_function_traits_t<Implementation>;
240  using ArgumentTypes =
241  c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
242  using ArgumentTuple =
244  static constexpr auto kNumberOfArguments =
245  std::tuple_size<ArgumentTuple>::value;
246 
247  auto schema = torch::jit::detail::inferAndCheckSchema<Traits>(schemaOrName);
248 
249  return Operator(schema, [implementation, schema](Stack& stack) {
250  ArgumentTuple tuple;
251  torch::jit::detail::callOperatorWithTuple(
252  schema,
253  std::move(implementation), // NOLINT(bugprone-move-forwarding-reference)
254  stack,
255  tuple,
256  typename MakeIndices<kNumberOfArguments>::indices{});
257  return 0;
258  });
259 }
260 
266 struct TORCH_API RegisterOperators {
267  RegisterOperators() = default;
268 
270  RegisterOperators(std::vector<Operator> operators) {
271  for (Operator& o : operators) {
272  registerOperator(std::move(o));
273  }
274  }
275 
277  template <typename Implementation>
278  RegisterOperators(const std::string& name, Implementation&& implementation) {
279  op(name, std::forward<Implementation>(implementation));
280  }
281 
285  template <typename Implementation>
287  const std::string& name,
288  Implementation&& implementation) {
289  registerOperator(
290  createOperator(name, std::forward<Implementation>(implementation)));
291  return *this;
292  }
293 };
294 
295 } // namespace jit
296 } // namespace torch
RegisterOperators & op(const std::string &name, Implementation &&implementation)
Creates a new operator from a name and implementation function (function pointer or function object/lambda).
RegisterOperators(const std::string &name, Implementation &&implementation)
Calls op(...) with the given operator name and implementation.
Registration class for new operators.
RegisterOperators(std::vector< Operator > operators)
Registers a vector of already created Operators.
Definition: jit_type.h:17
Transforms a list of types into a tuple holding these types.
Definition: TypeList.h:44