Caffe2 - C++ API
A deep learning, cross-platform ML framework
operator.h
// in-memory description of all ATen Ops, similar to the Caffe2 schema
// once C10 exists this can be removed, or stubbed out, but we need
// it now to implement correct semantic checking for script
#pragma once

#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir.h>
#include <ATen/core/stack.h>

#include <ATen/ATen.h>
#include <ATen/core/function_schema.h>

#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

using ::c10::FunctionSchema;

TORCH_API FunctionSchema parseSchema(const std::string& schema);

using OperationCreator = std::function<Operation(const Node*)>;

/*
 * Note: JIT relies on Operator instances having static lifetime because,
 * for example, it stores a non-owning FunctionSchema* pointer in the Node
 * class, which points to the function schema stored in the Operator instance.
 * Also, jit::Operator is meant to store more operator-related information like
 * symbolic derivatives, which also requires them to have static lifetime
 * so that changes to symbolic derivatives are remembered.
 *
 * Currently, the c10 operator library doesn't store jit::Operator instances;
 * instead, we use a listener pattern that notifies JIT about changes in the
 * c10 operator library and then registers jit::Operator instances to the JIT
 * operator registry, acting as wrappers to the c10 operators.
 *
 * However, that results in code duplication, as JIT and c10 will likely get
 * their own mechanisms for storing derivatives and other operator-related
 * information, and all of this would have to be wrapped from c10 into JIT.
 *
 * We should consider merging the JIT and c10 registries, moving jit::Operator
 * to c10 and storing these jit::Operator instances in the c10 operator library
 * instead, allowing us to have these mechanisms implemented only once.
 * However, the current jit::Operator implementation has additional features
 * like OperationCreator that aren't needed in c10 (they're only used for
 * prim ops like If/Else or While, which wouldn't be in the c10 operator
 * library), and which depend on other JIT features that we don't want to move
 * to c10 (notably jit/ir.h). We might, however, be able to split jit::Operator
 * into a c10::Operator with the core features and a jit::Operator that adds
 * the JIT-only features like OperationCreator, and then use c10::Operator in
 * the c10 operator library.
 */

struct TORCH_API Operator {
  Operator(FunctionSchema schema, OperationCreator op_creator)
      : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
        op_creator_(std::move(op_creator)) {}

  Operator(const std::string& schema, OperationCreator op_creator)
      : schema_string_(schema), op_creator_(std::move(op_creator)) {}

  // Helper constructor to register `op` to run for _every_ IR Node where
  // n.kind() == name, regardless of arguments. This is accomplished by
  // marking the schema varargs and having no required arguments. This is
  // used for things like prim::While or prim::If that can take a number of
  // different valid input types and lengths.
  Operator(Symbol name, OperationCreator op_creator)
      : Operator(
            FunctionSchema(
                name,
                "",
                {},
                {},
                /*is_vararg*/ true,
                /*is_varret*/ true),
            std::move(op_creator)) {}
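
  // Example use of the varargs constructor above (illustrative sketch, not
  // part of the original header). The Symbol name is hypothetical, and
  // Operation is assumed to be the stack-based functor from
  // ATen/core/stack.h, which at this point in the codebase returns an int:
  //
  //   registerOperator(Operator(
  //       Symbol::prim("MyLoop"),  // hypothetical prim op kind
  //       [](const Node* node) -> Operation {
  //         // the creator may inspect `node` (attributes, input count, ...)
  //         size_t num_inputs = node->inputs().size();
  //         return [num_inputs](Stack& stack) {
  //           drop(stack, num_inputs);  // consume the node's inputs
  //           return 0;
  //         };
  //       }));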

  Operator(FunctionSchema schema, Operation op)
      : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
        op_(std::make_shared<Operation>(std::move(op))) {}

  Operator(const std::string& schema, Operation op)
      : schema_string_(schema),
        op_(std::make_shared<Operation>(std::move(op))) {}

  bool matches(const Node* node) const;

  Operation getOperation(const Node* node = nullptr) const {
    if (op_) {
      return *op_;
    }
    AT_ASSERT(node != nullptr);
    return op_creator_(node);
  }

  const FunctionSchema& schema() const {
    // We lazily parse schemas that were initialized from strings so that we
    // do less work during static operator registration.
    if (!schema_) {
      schema_ = std::make_shared<FunctionSchema>(
          parseSchema(schema_string_.value()));
      schema_string_ = c10::nullopt;
    }
    return *schema_;
  }
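
  // Example of the lazy parsing above (illustrative sketch, not part of the
  // original header): an Operator built from a schema string parses that
  // string only when schema() is first called, keeping static registration
  // cheap. The schema text and the no-op Operation below are made up for
  // illustration:
  //
  //   Operator op(
  //       "aten::my_identity(Tensor self) -> Tensor",  // hypothetical schema
  //       [](Stack& stack) { return 0; });  // leaves `self` on the stack
  //   const FunctionSchema& s = op.schema();  // string is parsed here, lazily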

 private:
  mutable c10::optional<std::string> schema_string_;
  // Cannot use c10::optional because Windows has issues that require an
  // assignment operator to be generated. Cannot use std::unique_ptr because
  // initializer lists of Operators end up copying the Operator.
  mutable std::shared_ptr<FunctionSchema> schema_;

  // Essentially a variant<Operation, OperationCreator>.
  // NB: std::function has a default state (where it == nullptr).
  std::shared_ptr<Operation> op_;
  OperationCreator op_creator_;
};

TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);

TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
    Symbol name);

std::shared_ptr<Operator> findOperatorFor(const Node* node);
const Operator& getOperatorFor(const Node* node);

inline Operation getOperation(const Node* node) {
  // note: getOperatorFor ensures that getOperatorFor(node).matches(node) ==
  // true, so the call to getOperation(node) is always valid.
  return getOperatorFor(node).getOperation(node);
}
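
// Example (illustrative sketch, not part of the original header): listing the
// registered overloads for a name and dispatching a node. The symbol and the
// surrounding Node/Stack are assumed to exist in the caller's context:
//
//   for (const auto& op : getAllOperatorsFor(Symbol::aten("add"))) {
//     std::cout << canonicalSchemaString(op->schema()) << "\n";
//   }
//
//   // For an existing Node* n whose inputs are already on `stack`:
//   //   Operation fn = getOperation(n);
//   //   fn(stack);  // pops n's inputs, pushes n's outputs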

TORCH_API std::vector<Symbol> findSimilarOperators(Symbol input_op);

TORCH_API void registerOperator(Operator&& op);

// XXX: this function is meant to be used with string literals only!
Operator& sig(const char* signature_literal);

struct OperatorSet {
  OperatorSet(std::initializer_list<const char*> sig_literals);
  // XXX: Returns a nullptr if no Operator in the set matches n
  Operator* find(const Node* n) const;

 private:
  std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;
};
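
// Example (illustrative sketch, not part of the original header): an
// OperatorSet is built from schema string literals and used to check whether
// a Node belongs to a group of operators, e.g. inside a graph pass. The exact
// strings must match registered schemas; the ones below are abbreviated for
// illustration, and the Node* comes from the caller:
//
//   static const OperatorSet arithmetic_ops{
//       "aten::add(Tensor self, Tensor other, Scalar alpha) -> Tensor",
//       "aten::mul(Tensor self, Tensor other) -> Tensor",
//   };
//   if (const Operator* op = arithmetic_ops.find(node)) {
//     // `node` matches one of the schemas above; op->schema() tells us which.
//   }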

} // namespace jit
} // namespace torch