Caffe2 - C++ API
A deep learning, cross platform ML framework
sugared_value.h
1 #pragma once
2 #include <functional>
3 #include <memory>
4 #include <string>
5 
6 #include <torch/csrc/jit/ir.h>
7 #include <torch/csrc/jit/script/error_report.h>
8 #include <torch/csrc/jit/script/module.h>
9 #include <torch/csrc/jit/script/tree_views.h>
10 
11 namespace torch {
12 namespace jit {
13 namespace script {
14 
15 // The AST can contain nodes like `self`, `self.b` or `python_fn` that
16 // are not first-class values in the graph representation, but instead
17 // will be desugared based on how they are used in the AST.
18 
19 // SugaredValue is used to temporarily represent these values in a way
20 // that separates their behavior from the AST -> IR converter itself.
21 // This allows us to keep dependencies on python minimal.
22 
// Tri-state answer returned by SugaredValue::isNone(): whether a value is
// statically known to be None.
enum NoneStatus {
  ALWAYS, // the value is always None
  MAYBE, // the value may be None (e.g. a value of Optional type)
  NEVER // the value is never None
};
24 
25 struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
26  // what is this node? for error reporting (e.g. Module, python function)
27  virtual std::string kind() const = 0;
28 
29  // what can we do with this thing?
30  // use it as a value e.g. `this + 4`
31  virtual Value* asValue(const SourceRange& loc, Method& m) {
32  throw ErrorReport(loc) << kind() << " cannot be used as a value";
33  }
34 
35  // select an attribute on it, e.g. `this.field`
36  virtual std::shared_ptr<SugaredValue> attr(
37  const SourceRange& loc,
38  Method& m,
39  const std::string& field) {
40  throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
41  }
42 
43  // assign an attribute on it, e.g. `this.field = newValue`
44  virtual void setAttr(
45  const SourceRange& loc,
46  Method& m,
47  const std::string& field,
48  Value* newValue) {
49  throw ErrorReport(loc) << "attribute assignment is not defined on "
50  << kind();
51  }
52  virtual NoneStatus isNone() {
53  return NEVER;
54  }
55 
56  // use it as a vector of values, e.g. a tuple of values as return value from
57  // a method invocation
58  virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
59  const SourceRange& loc,
60  Method& m,
61  const c10::optional<size_t>& size_hint = {}) {
62  throw ErrorReport(loc) << kind() << " cannot be used as a tuple";
63  }
64 
65  // call it like a function, e.g. `outputs = this(inputs)`
66  virtual std::shared_ptr<SugaredValue> call(
67  const SourceRange& loc,
68  Method& m,
69  // note: names for args will be 'argument 0', 'argument 1', etc..
71  at::ArrayRef<NamedValue> attributes,
72  size_t n_binders) {
73  // n_binders is always set to the number of variables an expression is
74  // syntactically bound to:
75  // a = foo() # 1 binder (note in this case the single binder might be a
76  // tuple) a, * b = foo() # 1 binder a, b = foo() # 2 binders foo() # 0
77  // binders
78  //
79  // In subexpressions, like bar() in foo(bar()), n_binders is always set to
80  // 1. n_binders is used as a hint to subexpressions to determine how many
81  // values they should return when that number is ambiguous statically. In
82  // particular it is currently used to decide how many tensors a call to a
83  // python function will return. It is only a hint, functions do not have to
84  // check that n_binders match the number of things they are returning, the
85  // assignment logic will do that anyway.
86 
87  throw ErrorReport(loc) << "cannot call a " << kind();
88  }
89 
90  virtual ~SugaredValue() = default;
91 };
92 
93 // most things in the environment are just simple value types
94 // and not special python syntax sugar types
95 struct TORCH_API SimpleValue : public SugaredValue {
96  SimpleValue(Value* value) : value_(value) {}
97  std::string kind() const override {
98  return "value";
99  }
100  Value* asValue(const SourceRange& range, Method& m) override {
101  return value_;
102  }
103  NoneStatus isNone() override {
104  if (value_->mustBeNone())
105  return ALWAYS;
106  else if (value_->type()->cast<OptionalType>())
107  return MAYBE;
108  else
109  return NEVER;
110  }
111  std::vector<std::shared_ptr<SugaredValue>> asTuple(
112  const SourceRange& loc,
113  Method& m,
114  const c10::optional<size_t>& size_hint = {}) override;
115  std::shared_ptr<SugaredValue> attr(
116  const SourceRange& loc,
117  Method& m,
118  const std::string& field) override;
119 
120  void setAttr(
121  const SourceRange& loc,
122  Method& m,
123  const std::string& field,
124  Value* newValue) override;
125 
126  Value* getValue() const {
127  return value_;
128  }
129 
130  private:
131  Value* value_;
132 };
133 
134 struct TORCH_API BuiltinFunction : public SugaredValue {
136  : symbol(symbol), self(std::move(self)) {}
137 
138  // The symbol of the function (e.g. `aten::relu`).
139  Symbol symbol;
140 
141  // if this is method, then this is the self argument.
143 
144  std::string kind() const override {
145  return "builtin";
146  }
147  std::shared_ptr<SugaredValue> call(
148  const SourceRange& loc,
149  Method& m,
150  at::ArrayRef<NamedValue> attributes,
152  size_t n_binders) override;
153 };
154 
155 struct TORCH_API BuiltinModule : public SugaredValue {
156  BuiltinModule(std::string name, c10::optional<int64_t> version = at::nullopt)
157  : name(std::move(name)), version(std::move(version)) {}
158 
159  std::string kind() const override {
160  return "builtin module";
161  }
162  std::shared_ptr<SugaredValue> attr(
163  const SourceRange& loc,
164  Method& m,
165  const std::string& field) override {
166  return std::make_shared<BuiltinFunction>(
167  Symbol::fromQualString(name + "::" + field), c10::nullopt);
168  }
169 
170  private:
171  std::string name;
172  // when we add operator versioning, emit this op as it exising at 'version'
173  // if not set, use the latest version
174  c10::optional<int64_t> version;
175 };
176 
177 // Represents a class, analagous to `int` or `dict`. Instances of classes,
178 // like `1` or `{"foo": 5}`, are represented as SimpleValues
179 struct TORCH_API ClassValue : public SugaredValue {
180  explicit ClassValue(ClassTypePtr type) : type_(std::move(type)) {}
181 
182  // Call the type's constructor, as in:
183  // n = Foo(constructor_arg)
184  std::shared_ptr<SugaredValue> call(
185  const SourceRange& loc,
186  Method& m,
188  at::ArrayRef<NamedValue> attributes,
189  size_t n_binders) override;
190 
191  std::shared_ptr<SugaredValue> attr(
192  const SourceRange& loc,
193  Method& m,
194  const std::string& field) override;
195 
196  std::string kind() const override {
197  return type_->str();
198  }
199 
200  ClassTypePtr type_;
201 };
202 
203 // defines how a method obtained from a module behaves in script
204 struct MethodValue : public SugaredValue {
205  MethodValue(std::shared_ptr<SugaredValue> self, Method& method)
206  : self_(std::move(self)), method(method) {}
207  std::string kind() const override {
208  return "method";
209  }
210  std::shared_ptr<SugaredValue> call(
211  const SourceRange& loc,
212  Method& caller,
214  at::ArrayRef<NamedValue> attributes,
215  size_t n_binders) override {
216  if (auto classType = dynamic_cast<SimpleValue*>(self_.get())) {
217  // If self_ is a class, then it will be expected as part of
218  // the schema. Add it to the front of the inputs.
219  std::vector<NamedValue> inputsWithSelf;
220  inputsWithSelf.emplace_back(loc, classType->getValue());
221  inputsWithSelf.insert(inputsWithSelf.end(), inputs.begin(), inputs.end());
222  return std::make_shared<SimpleValue>(
223  caller.emit_call_to(loc, method, inputsWithSelf, attributes));
224  }
225 
226  return std::make_shared<SimpleValue>(
227  caller.emit_call_to(loc, method, inputs, attributes));
228  }
229 
230  private:
231  std::shared_ptr<SugaredValue> self_;
232  Method& method;
233 };
234 
235 struct TORCH_API PrintValue : public SugaredValue {
236  std::string kind() const override {
237  return "print";
238  }
239  std::shared_ptr<SugaredValue> call(
240  const SourceRange& loc,
241  Method& m,
243  at::ArrayRef<NamedValue> attributes,
244  size_t n_binders) override;
245 };
246 
247 // expressions like int(x)
248 // these are the same as call prim::Int or equivalent except it
249 // is a noop when the input is a subtype of 'type'
250 struct TORCH_API CastValue : public BuiltinFunction {
251  CastValue(TypePtr type, c10::Symbol method)
252  : BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {}
253  std::shared_ptr<SugaredValue> call(
254  const SourceRange& loc,
255  Method& m,
257  at::ArrayRef<NamedValue> attributes,
258  size_t n_binders) override {
259  if (inputs.size() == 1 && attributes.size() == 0) {
260  auto v = inputs[0].value(*m.graph());
261  if (v->type()->isSubtypeOf(type_)) {
262  return std::make_shared<SimpleValue>(v);
263  }
264  }
265  return BuiltinFunction::call(loc, m, inputs, attributes, n_binders);
266  }
267 
268  private:
269  TypePtr type_;
270 };
271 
// These SugaredValues have special handling in the compiler because they
// change the normal evaluation order of the expression they participate in.
// They are exposed here so that the python frontend can inject them
// when it sees the equivalent thing in python
276 
277 struct TORCH_API ForkValue : public SugaredValue {
278  ForkValue() = default;
279  std::string kind() const override {
280  return "fork";
281  }
282 };
283 struct TORCH_API AnnotateValue : public SugaredValue {
284  AnnotateValue() = default;
285  std::string kind() const override {
286  return "annotate";
287  }
288 };
289 
290 // matched against for special handling of getattr expressions
291 struct TORCH_API GetAttrValue : SugaredValue {
292  GetAttrValue() = default;
293  std::string kind() const override {
294  return "getattr";
295  }
296 };
297 
298 // matched against for special handling of isinstance expressions
299 struct TORCH_API IsInstanceValue : SugaredValue {
300  IsInstanceValue() = default;
301  std::string kind() const override {
302  return "isinstance";
303  }
304 };
305 
306 // This represents the "__new__" method on classes, which can't be a MethodValue
307 // because it takes a ClassValue as input.
308 // So if we see:
309 // Foo.__new__(Foo)
310 // Foo is a ClassValue, calling `attr("__new__")` will return a ClassNewMethod.
311 struct TORCH_API ClassNewMethod : public SugaredValue {
312  ClassNewMethod(ClassTypePtr type) : type_(type) {}
313  std::string kind() const override {
314  return "class.__new__";
315  }
316 
317  std::shared_ptr<SugaredValue> createObject(
318  const SourceRange& loc,
319  Method& m,
320  const std::string& classname) {
321  if (classname != type_->name()) {
322  throw ErrorReport(loc)
323  << "Argument to __new__() must match the class "
324  << "you are calling __new__() on. "
325  << "Got: " << classname << ", expected: " << type_->name();
326  }
327  auto& g = *m.graph();
328  auto createNode = g.insertNode(g.createObject(type_));
329  return std::make_shared<SimpleValue>(createNode->output());
330  }
331 
332  ClassTypePtr type_;
333 };
334 
335 static inline std::vector<Value*> toValues(
336  Graph& g,
338  return fmap(nvs, [&](const NamedValue& v) { return v.value(g); });
339 }
340 
341 } // namespace script
342 } // namespace jit
343 } // namespace torch
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41