Caffe2 - C++ API
A deep learning, cross platform ML framework
sugared_value.cpp
1 #include <torch/csrc/jit/script/sugared_value.h>
2 #include <torch/csrc/jit/ir.h>
3 #include <torch/csrc/jit/script/schema_matching.h>
4 #include <torch/csrc/jit/script/tree_views.h>
5 
6 namespace torch {
7 namespace jit {
8 namespace script {
9 
11  NoneValue() = default;
12  std::string kind() const override {
13  return "None";
14  }
15 };
16 
17 std::shared_ptr<SugaredValue> PrintValue::call(
18  const SourceRange& loc,
19  Method& m,
21  at::ArrayRef<NamedValue> attributes,
22  size_t n_binders) {
23  auto& g = *m.graph();
24  if (!attributes.empty())
25  throw ErrorReport(loc) << "print doesn't accept any keyword arguments";
26 
27  // temporary hack to allow print statements to work in python 2, where
28  // print(a, b) is treated as a (a, b) tuple input.
29 
30  std::vector<Value*> lowered_inputs = toValues(*m.graph(), inputs);
31  if (lowered_inputs.size() == 1 &&
32  lowered_inputs.at(0)->node()->kind() == prim::TupleConstruct) {
33  auto input = lowered_inputs[0];
34  for (size_t j = 0; j < input->node()->inputs().size(); ++j) {
35  lowered_inputs.insert(
36  lowered_inputs.begin() + 1 + j, input->node()->inputs().at(j));
37  }
38  lowered_inputs.erase(lowered_inputs.begin());
39  }
40  g.insertNode(g.create(prim::Print, lowered_inputs, 0)
41  ->setSourceLocation(std::make_shared<SourceRange>(loc)));
42  return std::make_shared<NoneValue>();
43 }
44 
// Lookup table from Tensor method-style cast names (as in `x.int()`) to the
// name of the ATen operator that performs the cast. The table is built once,
// on first call, and the same instance is returned every time.
static const std::unordered_map<std::string, std::string>&
builtin_cast_methods() {
  static const std::unordered_map<std::string, std::string> kCastOps{
      {"byte", "_cast_Byte"},
      {"char", "_cast_Char"},
      {"double", "_cast_Double"},
      {"float", "_cast_Float"},
      {"int", "_cast_Int"},
      {"long", "_cast_Long"},
      {"short", "_cast_Short"},
      {"half", "_cast_Half"},
  };
  return kCastOps;
}
58 
59 std::shared_ptr<SugaredValue> BuiltinFunction::call(
60  const SourceRange& loc,
61  Method& m,
63  at::ArrayRef<NamedValue> attributes,
64  size_t n_binders) {
65  return std::make_shared<SimpleValue>(
66  emitBuiltinCall(loc, *m.graph(), symbol, self, inputs, attributes, true));
67 }
68 
69 // support syntax sugar for x.foo(y, z) by allowing x.foo to return a
70 // callable value that will resolve to foo(x, y, z) when called.
71 std::shared_ptr<SugaredValue> SimpleValue::attr(
72  const SourceRange& loc,
73  Method& m,
74  const std::string& field) {
75  // Allow method-style casts on Tensor types. e.g. x.int()
76  if (value_->type()->isSubtypeOf(TensorType::get())) {
77  if (builtin_cast_methods().count(field)) {
78  return std::make_shared<BuiltinFunction>(
79  Symbol::aten(builtin_cast_methods().at(field)),
80  NamedValue(loc, "self", value_));
81  }
82  // functions that are just direct property lookups on tensor
83  // must be registered as prim::<name>(Tensor t) -> <return_type>
84  static const std::unordered_set<std::string> fields = {
85  "dtype",
86  "device",
87  "shape",
88  "is_cuda",
89  "requires_grad",
90  };
91  if (fields.count(field)) {
92  auto r =
93  m.graph()->insert(Symbol::fromQualString("prim::" + field), {value_});
94  return std::make_shared<SimpleValue>(r);
95  }
96  }
97  if (value_->type()->isSubtypeOf(NumberType::get())) {
98  throw ErrorReport(loc) << "Cannot call methods on numbers";
99  }
100  if (auto tuple_type = value_->type()->cast<TupleType>()) {
101  if (!tuple_type->hasNames()) {
102  throw ErrorReport(loc) << "Getting attributes of tuples is not supported";
103  }
104  auto names = tuple_type->names();
105  for (size_t i = 0; i < names.size(); i++) {
106  if (names[i] == field) {
107  auto r = m.graph()
108  ->insertNode(m.graph()->createTupleIndex(value_, i))
109  ->output();
110  return std::make_shared<SimpleValue>(r);
111  }
112  }
113  throw ErrorReport(loc) << "Unknown attribute to named tuple";
114  }
115 
116  if (auto classType = value_->type()->cast<ClassType>()) {
117  // This is a class, emit the proper attribute lookup
118  if (auto method = classType->getMethod(field)) {
119  return std::make_shared<MethodValue>(shared_from_this(), *method);
120  }
121 
122  if (!classType->hasAttribute(field)) {
123  throw ErrorReport(loc)
124  << "Tried to access to nonexistent attribute " << field
125  << ". Did you forget to initialize it in __init__()?";
126  }
127  auto& g = *m.graph();
128  auto n = g.insertNode(g.createGetAttr(value_, field));
129  return std::make_shared<SimpleValue>(n->output());
130  }
131 
132  return std::make_shared<BuiltinFunction>(
133  Symbol::aten(field), NamedValue(loc, "self", value_));
134 }
135 
136 std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
137  const SourceRange& loc,
138  Method& m,
139  const c10::optional<size_t>& size_hint) {
140  static const auto make_simple_value =
141  [](Value* v) -> std::shared_ptr<SugaredValue> {
142  return std::make_shared<SimpleValue>(v);
143  };
144  if (value_->type()->kind() == TypeKind::TupleType) {
145  auto outputs = createTupleUnpack(value_);
146  return fmap(outputs, make_simple_value);
147  } else if (value_->type()->kind() == TypeKind::ListType) {
148  if (!size_hint) {
149  throw ErrorReport(loc)
150  << "cannot statically infer the expected size of a "
151  << "list in this context";
152  }
153  auto graph = value_->owningGraph();
154  Node* unpack =
155  graph->insertNode(graph->createListUnpack(value_, *size_hint));
156  return fmap(unpack->outputs(), make_simple_value);
157  }
158  throw ErrorReport(loc) << value_->type()->str()
159  << " cannot be used as a tuple";
160 }
161 
162 void SimpleValue::setAttr(
163  const SourceRange& loc,
164  Method& m,
165  const std::string& field,
166  Value* newValue) {
167  const auto classType = value_->type()->cast<ClassType>();
168  if (!classType) {
169  throw ErrorReport(loc) << "Tried to set an attribute: " << field
170  << " on a non-class: " << value_->type()->str();
171  }
172  auto expectedType = classType->getAttribute(field);
173  if (!expectedType) {
174  // If we are still compiling the __init__ method for this class, then
175  // setting an unknown attribute adds it to the class's definition.
176 
177  // We are initializing if:
178  const auto isInitializing =
179  // 1. The method we're currently inserting into is an init method
180  m.name() == "__init__" &&
181  // 2. The `self` arg matches this value's type (i.e. we are in the init
182  // method for this class, not some other class)
183  !m.graph()->inputs().empty() &&
184  m.graph()->inputs().at(0)->type() == classType;
185 
186  if (isInitializing) {
187  classType->addAttribute(field, newValue->type());
188  expectedType = newValue->type();
189 
190  const auto insertPoint = m.graph()->insertPoint();
191  const auto topLevelBlock = m.graph()->block();
192  if (insertPoint->owningBlock() != topLevelBlock) {
193  throw ErrorReport(loc)
194  << "First assignment cannot be in a control-flow block. "
195  << "Initialize the field at the top level first.";
196  }
197  } else {
198  throw ErrorReport(loc)
199  << "Tried to set nonexistent attribute: " << field
200  << ". Did you forget to initialize it in __init__()?";
201  }
202  }
203 
204  AT_ASSERT(expectedType);
205 
206  // Check type correctness
207  const auto newType = newValue->type();
208  if (!newType->isSubtypeOf(expectedType)) {
209  throw ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
210  << expectedType->str() << " but got "
211  << newType->str();
212  }
213 
214  auto& g = *m.graph();
215  g.insertNode(g.createSetAttr(value_, field, newValue));
216 }
217 
218 std::shared_ptr<SugaredValue> ClassValue::call(
219  const SourceRange& loc,
220  Method& m,
221  // note: names for args will be 'argument 0', 'argument 1', etc..
223  at::ArrayRef<NamedValue> attributes,
224  size_t n_binders) {
225  AT_ASSERT(n_binders <= 1);
226 
227  // Generate a new object of the right type, then call `__init__` on it
228  auto& g = *m.graph();
229  auto createNode = g.insertNode(g.createObject(type_));
230  auto self = std::make_shared<SimpleValue>(createNode->output());
231 
232  auto initMethod = type_->getMethod("__init__");
233  AT_ASSERT(initMethod);
234 
235  // Call the init function
236  MethodValue(self, *initMethod).call(loc, m, inputs, attributes, n_binders);
237 
238  return self;
239 }
240 
241 std::shared_ptr<SugaredValue> ClassValue::attr(
242  const SourceRange& loc,
243  Method& m,
244  const std::string& field) {
245  if (field != "__new__") {
246  throw ErrorReport(loc) << "Tried to lookup unknown attribute on class";
247  }
248  return std::make_shared<ClassNewMethod>(type_);
249 }
250 } // namespace script
251 } // namespace jit
252 } // namespace torch
constexpr bool empty() const
empty - Check if the array is empty.
Definition: ArrayRef.h:129
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41
Flush-To-Zero and Denormals-Are-Zero mode.