6 #include <c10/util/Exception.h> 7 #include <torch/csrc/jit/ir.h> 8 #include <ATen/core/stack.h> 10 #include <ATen/ATen.h> 11 #include <ATen/core/function_schema.h> 14 #include <initializer_list> 17 #include <unordered_map> 24 using ::c10::FunctionSchema;
// Declaration only: takes a schema string and returns a FunctionSchema by
// value. The accepted string syntax is defined by the implementation, which
// is not part of this excerpt.
26 TORCH_API FunctionSchema parseSchema(
const std::string& schema);
// Callable that, given a Node pointer, produces the Operation to use for it.
28 using OperationCreator = std::function<Operation(const Node*)>;
// Constructs from an already-built FunctionSchema and an OperationCreator.
// The schema is moved into a freshly allocated shared_ptr (schema_); the
// creator is moved into op_creator_.
61 Operator(FunctionSchema schema, OperationCreator op_creator)
62 : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
63 op_creator_(std::move(op_creator)) {}
// Constructs from a schema string and an OperationCreator. The string is only
// copied into schema_string_ here; no parsing happens in this constructor.
65 Operator(
const std::string& schema, OperationCreator op_creator)
66 : schema_string_(schema), op_creator_(std::move(op_creator)) {}
// Constructs from a Symbol name and an OperationCreator.
// NOTE(review): most of the initializer list is not visible in this excerpt;
// only the final std::move(op_creator) argument is shown. Confirm how `name`
// is turned into a schema against the full file.
73 Operator(
Symbol name, OperationCreator op_creator)
82 std::move(op_creator)) {}
// Constructs from an already-built FunctionSchema and a concrete Operation.
// Both arguments are moved into freshly allocated shared_ptrs (schema_, op_).
84 Operator(FunctionSchema schema, Operation op)
85 : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
86 op_(std::make_shared<Operation>(std::move(op))) {}
// Constructs from a schema string and a concrete Operation. The string is
// copied into schema_string_ without being parsed here; the Operation is
// moved into a shared_ptr held by op_.
88 Operator(
const std::string& schema, Operation op)
89 : schema_string_(schema),
90 op_(std::make_shared<Operation>(std::move(op))) {}
// Const member taking a Node pointer and returning bool. Declaration only.
// Presumably reports whether this operator applies to `node` — confirm
// against the definition, which is not in this excerpt.
92 bool matches(
const Node* node)
const;
// Returns an Operation for this operator. `node` defaults to nullptr.
// On the path visible here, `node` is asserted to be non-null and then passed
// to op_creator_, whose result is returned.
// NOTE(review): the lines between the signature and the assert, and the
// closing of the body, are missing from this excerpt; there may be an earlier
// return path that does not need `node` — confirm against the full file.
94 Operation getOperation(
const Node* node =
nullptr)
const {
98 AT_ASSERT(node !=
nullptr);
99 return op_creator_(node);
// Const accessor returning a reference to this operator's FunctionSchema.
// The visible lines build a FunctionSchema by running parseSchema on the
// stored schema_string_ and then reset schema_string_ to c10::nullopt.
// NOTE(review): the guarding condition, the assignment target of the
// make_shared expression, and the return statement are not visible in this
// excerpt — confirm against the full file.
102 const FunctionSchema& schema()
const {
107 std::make_shared<FunctionSchema>(parseSchema(schema_string_.value()));
108 schema_string_ = c10::nullopt;
// Shared FunctionSchema; declared mutable (presumably so the const schema()
// accessor can fill it in on demand — confirm against the full file).
118 mutable std::shared_ptr<FunctionSchema> schema_;
// Concrete Operation, set by the constructors that receive an Operation.
122 std::shared_ptr<Operation> op_;
// Per-node factory, set by the constructors that receive an OperationCreator.
123 OperationCreator op_creator_;
// Declaration only: takes a FunctionSchema by const reference and returns a
// std::string. The exact output format is defined by the implementation,
// which is not in this excerpt.
126 TORCH_API std::string canonicalSchemaString(
const FunctionSchema& schema);
// Returns a const reference to a vector of shared Operator pointers.
// NOTE(review): the parameter list is cut off in this excerpt — confirm the
// lookup key against the full file.
128 TORCH_API
const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
// Declaration only: takes a Node pointer and returns a shared_ptr<Operator>.
// Presumably the result is null when no operator matches — confirm against
// the definition.
131 std::shared_ptr<Operator> findOperatorFor(
const Node* node);
// Free-function convenience wrapper: looks up the operator for `node` with
// getOperatorFor and returns that operator's getOperation(node).
// NOTE(review): some lines of this body, including its closing brace, are
// missing from this excerpt, and getOperatorFor is declared outside it.
134 inline Operation getOperation(
const Node* node) {
137 return getOperatorFor(node).getOperation(node);
// Declaration only: takes a Symbol and returns a vector of Symbols.
// The similarity criterion is defined by the implementation, which is not in
// this excerpt.
141 TORCH_API std::vector<Symbol> findSimilarOperators(
Symbol input_op);
// Declaration only: accepts an Operator by rvalue reference, so callers pass
// a temporary or std::move an existing one. Where the operator is stored is
// determined by the implementation, which is not in this excerpt.
143 TORCH_API
void registerOperator(
Operator&& op);
// Declaration only: takes a C-string signature literal and returns a
// reference to an Operator.
// NOTE(review): the lifetime and ownership of the referenced Operator are not
// visible here — confirm against the definition.
146 Operator& sig(
const char* signature_literal);
// Constructor taking a list of C-string signature literals.
// NOTE(review): the enclosing class header and the other members are not
// visible in this excerpt.
149 OperatorSet(std::initializer_list<const char*> sig_literals);
// Maps a Symbol to the shared Operator pointers stored under it.
154 std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;