3 #include <torch/csrc/jit/operator.h> 4 #include <ATen/core/stack.h> 5 #include <torch/csrc/jit/tracer.h> 6 #include <torch/csrc/utils/variadic.h> 8 #include <ATen/core/function_schema.h> 9 #include <c10/util/Metaprogramming.h> 10 #include <c10/util/TypeList.h> 16 using ::c10::Argument;
17 using ::c10::FunctionSchema;
// Validates a single C++ type `T` used as an operator argument or return type.
// NOTE(review): the `template <typename T>` header, the `static_assert(`
// openers that the conditions/messages below belong to, and the closing brace
// are not visible in this view (presumably dropped by extraction) — confirm.
// This overload handles exactly one type; there is no visible overload for an
// empty type list.
21 void checkStaticTypes() {
// Reject every integral type except int64_t.
// NOTE(review): `bool` satisfies std::is_integral, so as written this
// condition also rejects `bool` — confirm that is intended.
25 !std::is_integral<T>::value || std::is_same<T, int64_t>::value,
26 "INVALID TYPE: Only int64_t is supported as an integral argument type");
// Reject `float`; the message points users at `double` instead.
28 !std::is_same<T, float>::value,
29 "INVALID TYPE: float is not supported as an argument type, use double instead");
/// Applies the single-type `checkStaticTypes<T>()` validation to each of two
/// or more types, in the order they are listed.
template <typename First, typename Second, typename... Rest>
void checkStaticTypes() {
  // Expand one single-type check per listed type inside a braced list; the
  // trailing `, 0` turns each void call into an int element of the array.
  using Expander = int[];
  (void)Expander{(checkStaticTypes<First>(), 0),
                 (checkStaticTypes<Second>(), 0),
                 (checkStaticTypes<Rest>(), 0)...};
}
38 template <
typename... Ts,
size_t... Is>
39 ::std::vector<Argument> createArgumentVectorFromTypes(Indices<Is...> indices) {
40 checkStaticTypes<decay_t<Ts>...>();
42 return {Argument(
"_" + std::to_string(Is), getTypePtr<decay_t<Ts>>())...};
45 template <
typename... Ts,
size_t... Is>
46 ::std::vector<Argument> createReturns(Indices<Is...> indices) {
47 return createArgumentVectorFromTypes<Ts..., Is...>();
52 template <
typename... Ts>
53 ::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
55 return createReturns<Ts...>(
typename MakeIndices<
sizeof...(Ts)>::indices{});
59 template <
typename ReturnType>
60 ::std::vector<Argument> createReturns(ReturnType*) {
61 checkStaticTypes<decay_t<ReturnType>>();
62 return {Argument(
"_1", getTypePtr<decay_t<ReturnType>>())};
67 template <
typename FunctionTraits,
size_t... Is>
68 ::std::vector<Argument> createArgumentVectorFromTraits(Indices<Is...> indices) {
69 using ArgumentTypes =
typename FunctionTraits::parameter_types;
70 return createArgumentVectorFromTypes<
71 c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
76 template <
typename FunctionTraits>
77 FunctionSchema createFunctionSchemaFromTraits(
const std::string& name) {
78 using ReturnType =
typename FunctionTraits::return_type;
80 auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
81 typename MakeIndices<FunctionTraits::number_of_parameters>::indices{});
82 auto returns = createReturns(static_cast<ReturnType*>(
nullptr));
84 return {name,
"", arguments, returns};
// Builds a node in the current tracing graph for a call to the operator
// described by `schema`, with inputs taken from the elements of `tuple`.
// NOTE(review): the name/return-type line (89) and lines 96-104 and 106+
// are not visible here. The call site `node = getTracedNode<Is...>(schema,
// arguments)` suggests this is `getTracedNode` returning `Node*` — confirm.
88 template <
size_t... Is,
typename... Types>
90 const FunctionSchema& schema,
91 const std::tuple<Types...>& tuple) {
// Resolve the operator's qualified name into a graph Symbol.
92 auto symbol = Symbol::fromQualString(schema.name());
93 const auto& graph = tracer::getTracingState()->graph;
// Create the node; the second argument (0) is presumably the initial output
// count — confirm against Graph::create.
94 Node* node = graph->create(symbol, 0);
95 tracer::recordSourceLocation(node);
// For each index Is: pass `node`, the schema's name for argument Is and
// `std::get<Is>(tuple)` to a call whose head is on a line not visible here
// (presumably a tracer input-recording function).
101 node, schema.arguments()[Is].name().c_str(), std::get<Is>(tuple)),
// Append the finished node to the graph.
105 graph->insertNode(node);
// Runs `implementation` on values popped from the stack into `arguments`,
// then pushes the result back onto the stack as an IValue. When tracing is
// active, the call is also recorded as a node in the tracing graph.
// NOTE(review): lines 119 and 121 (presumably the `Stack& stack` and
// `Indices<Is...>` parameters), 130-133 (presumably binding the call result
// to `return_value`) and the closing braces are not visible here — confirm.
115 template <
typename Implementation,
typename... Types,
size_t... Is>
116 void callOperatorWithTuple(
117 const FunctionSchema& schema,
118 Implementation&& implementation,
120 std::tuple<Types...>& arguments,
// The stack must hold exactly one value per operator parameter.
122 AT_ASSERT(stack.size() ==
sizeof...(Is));
// Pop the stack values into the tuple elements.
125 pop(stack, std::get<Is>(arguments)...);
// When tracing, create the traced node before running the implementation.
127 Node* node =
nullptr;
128 if (jit::tracer::isTracing()) {
129 node = getTracedNode<Is...>(schema, arguments);
// Invoke the implementation with the unpacked tuple elements.
134 std::forward<Implementation>(implementation)(std::get<Is>(arguments)...);
// When tracing, register the result as the traced node's output.
136 if (jit::tracer::isTracing()) {
137 jit::tracer::addOutput(node, return_value);
// Push the result onto the stack, wrapped in an IValue.
141 push(stack, IValue(std::move(return_value)));
// Verifies that the `provided` arguments (from a user-supplied schema) agree
// with the `inferred` ones (from the C++ signature). The counts must match,
// and at each position the provided type must be a subtype of the inferred
// type. Both full schemas are included in the error messages.
// NOTE(review): line 145 (presumably a `what` label parameter used in the
// messages) and the check openers on lines 150-151 and 159-160 (presumably
// AT_CHECK-style assertions) are not visible here — confirm.
144 inline void checkArgumentVector(
146 const std::vector<Argument>& inferred,
147 const std::vector<Argument>& provided,
148 const FunctionSchema& inferredSchema,
149 const FunctionSchema& providedSchema) {
// Condition and message for the count check.
152 inferred.size() == provided.size(),
153 "Inferred ", inferred.size(),
" ", what,
154 "(s) for operator implementation, but the provided schema specified ",
155 provided.size(),
" ", what,
"(s). Inferred schema: ", inferredSchema,
156 " | Provided schema: ", providedSchema);
// Per-position type check. Indexing `inferred[i]` is only in range if the
// count check above aborts on mismatch — confirm the hidden check throws.
158 for (
size_t i = 0; i < provided.size(); ++i) {
161 provided[i].type()->isSubtypeOf(inferred[i].type()),
162 "Inferred type for ", what,
" #", i,
" was ", *inferred[i].type(),
163 ", but the provided schema specified type ", *provided[i].type(),
164 " for the ", what,
" in that position. Inferred schema: ",
165 inferredSchema,
" | Provided schema: ", providedSchema);
// Produces the FunctionSchema for an operator from `schemaOrName`.
// - If the string contains no '(', it is treated as a bare name and a schema
//   inferred from `Traits` is returned.
// - Otherwise it is parsed as a full schema, a schema is inferred from
//   `Traits` under the parsed name, the argument and return vectors of both
//   are compared, and the parsed (provided) schema is returned.
// NOTE(review): the heads of the comparison calls (lines 194-195, 200-201)
// are not visible; presumably they are calls to checkArgumentVector — confirm.
174 template <
typename Traits>
175 FunctionSchema inferAndCheckSchema(
const std::string& schemaOrName) {
// No opening parenthesis means only an operator name was given.
178 const auto bracketIndex = schemaOrName.find(
'(');
179 if (bracketIndex == std::string::npos) {
181 return torch::jit::detail::createFunctionSchemaFromTraits<Traits>(
// Full schema string: parse it and infer a schema under the same name.
189 auto providedSchema = parseSchema(schemaOrName);
191 const auto inferredSchema =
192 torch::jit::detail::createFunctionSchemaFromTraits<Traits>(
193 providedSchema.name());
// Compare arguments, then returns (call heads not visible).
196 inferredSchema.arguments(),
197 providedSchema.arguments(),
202 inferredSchema.returns(),
203 providedSchema.returns(),
// The user's schema is returned, not the inferred one.
206 return providedSchema;
// Creates an Operator from a schema string (or bare name) and an
// implementation callable. The schema is inferred and checked with
// inferAndCheckSchema. The returned Operator's body calls
// callOperatorWithTuple with an index pack sized to the number of arguments.
// NOTE(review): lines 243 (the ArgumentTuple definition), 248, 250, 252 and
// 254-255 are not visible here — confirm what is passed to
// callOperatorWithTuple beyond what is shown.
235 template <
typename Implementation>
236 Operator createOperator(
237 const std::string& schemaOrName,
238 Implementation&& implementation) {
// Function traits of the callable, and its parameter types with each type
// mapped through decay_t.
239 using Traits = c10::guts::infer_function_traits_t<Implementation>;
240 using ArgumentTypes =
241 c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
242 using ArgumentTuple =
244 static constexpr
auto kNumberOfArguments =
245 std::tuple_size<ArgumentTuple>::value;
247 auto schema = torch::jit::detail::inferAndCheckSchema<Traits>(schemaOrName);
// The lambda captures `implementation` and `schema` by copy.
// NOTE(review): the lambda is not declared `mutable`, so the captured
// `implementation` is const and `std::move(implementation)` below yields a
// const rvalue (no actual move). A functor with a non-const operator() would
// not be callable through it — confirm intended.
249 return Operator(schema, [implementation, schema](Stack& stack) {
251 torch::jit::detail::callOperatorWithTuple(
253 std::move(implementation),
256 typename MakeIndices<kNumberOfArguments>::indices{});
// Fragments of the RegisterOperators registration class; the class header
// and most member signatures are not visible in this view.
// Constructor taking already-built Operators: each one is moved into
// registerOperator (presumably inside a loop over the vector — confirm).
272 registerOperator(std::move(o));
// Constructor taking (name, implementation): forwards both to op(...).
277 template <
typename Implementation>
279 op(name, std::forward<Implementation>(implementation));
// op(name, implementation): builds the Operator with createOperator and
// passes it to a call whose head (line 289) is not visible — presumably
// registerOperator; confirm.
285 template <
typename Implementation>
287 const std::string& name,
288 Implementation&& implementation) {
290 createOperator(name, std::forward<Implementation>(implementation)));
RegisterOperators & op(const std::string &name, Implementation &&implementation)
Creates a new operator from a name and implementation function (function pointer or function object/lambda).
RegisterOperators(const std::string &name, Implementation &&implementation)
Calls op(...) with the given operator name and implementation.
Registration class for new operators.
RegisterOperators(std::vector< Operator > operators)
Registers a vector of already created Operators.
Transforms a list of types into a tuple holding these types.