1 #include <torch/csrc/jit/script/sugared_value.h> 2 #include <torch/csrc/jit/ir.h> 3 #include <torch/csrc/jit/script/schema_matching.h> 4 #include <torch/csrc/jit/script/tree_views.h> 12 std::string kind()
const override {
// Lowers a call to `print(...)`: inserts a prim::Print node into graph `g`
// and returns a NoneValue.
// NOTE(review): the parameter list, the definition of `g`, and several
// closing braces are on lines not visible in this excerpt.
17 std::shared_ptr<SugaredValue> PrintValue::call(
// Keyword arguments are rejected with an ErrorReport at `loc`.
24 if (!attributes.
empty())
25 throw ErrorReport(loc) <<
"print doesn't accept any keyword arguments";
30 std::vector<Value*> lowered_inputs = toValues(*m.graph(), inputs);
// If the single lowered input was produced by a prim::TupleConstruct node,
// splice that node's inputs in right after it (preserving their order) and
// then erase the tuple value itself, so the elements are passed to
// prim::Print individually.
31 if (lowered_inputs.size() == 1 &&
32 lowered_inputs.at(0)->node()->kind() == prim::TupleConstruct) {
33 auto input = lowered_inputs[0];
34 for (
size_t j = 0; j < input->node()->inputs().size(); ++j) {
35 lowered_inputs.insert(
36 lowered_inputs.begin() + 1 + j, input->node()->inputs().at(j));
38 lowered_inputs.erase(lowered_inputs.begin());
// Create the Print node and attach the call's source range to it.
// NOTE(review): the trailing `0` is presumably the node's output count --
// confirm against Graph::create.
40 g.insertNode(g.create(prim::Print, lowered_inputs, 0)
41 ->setSourceLocation(std::make_shared<SourceRange>(loc)));
42 return std::make_shared<NoneValue>();
// Returns a reference to a function-local static table that maps a method
// name (e.g. "byte") to the name of the matching `_cast_*` operator.
// SimpleValue::attr looks names up here and wraps the result in Symbol::aten.
// NOTE(review): this view skips at least one original line between the
// "float" and "long" rows, so the table shown here may be incomplete.
45 static const std::unordered_map<std::string, std::string>&
46 builtin_cast_methods() {
47 static std::unordered_map<std::string, std::string> builtin_cast_methods = {
48 {
"byte",
"_cast_Byte"},
49 {
"char",
"_cast_Char"},
50 {
"double",
"_cast_Double"},
51 {
"float",
"_cast_Float"},
53 {
"long",
"_cast_Long"},
54 {
"short",
"_cast_Short"},
55 {
"half",
"_cast_Half"}};
56 return builtin_cast_methods;
// Emits a call to the builtin named by `symbol` into m's graph through
// emitBuiltinCall, forwarding `self`, `inputs` and `attributes`, and wraps
// the result in a SimpleValue.
// NOTE(review): the parameter list is not visible in this excerpt, and the
// meaning of the trailing `true` argument should be confirmed against
// emitBuiltinCall's declaration.
59 std::shared_ptr<SugaredValue> BuiltinFunction::call(
65 return std::make_shared<SimpleValue>(
66 emitBuiltinCall(loc, *m.graph(), symbol,
self, inputs, attributes,
true));
// Resolves attribute access `value_.field` on a plain IR value, dispatching
// on the static type of value_ (tensor, number, tuple, class) and falling
// back to an aten builtin bound to value_.
// NOTE(review): several original lines (part of the parameter list, the
// contents of `fields`, the definition of `r`, closing braces) are not
// visible in this excerpt.
71 std::shared_ptr<SugaredValue> SimpleValue::attr(
74 const std::string& field) {
// Tensor-typed values: names found in builtin_cast_methods() resolve to a
// BuiltinFunction for the corresponding aten symbol.
76 if (value_->type()->isSubtypeOf(TensorType::get())) {
77 if (builtin_cast_methods().count(field)) {
78 return std::make_shared<BuiltinFunction>(
79 Symbol::aten(builtin_cast_methods().
at(field)),
// Names contained in `fields` are lowered to a `prim::<field>` node that
// takes value_ as its only input.
84 static const std::unordered_set<std::string> fields = {
91 if (fields.count(field)) {
93 m.graph()->insert(Symbol::fromQualString(
"prim::" + field), {value_});
94 return std::make_shared<SimpleValue>(r);
// Number-typed values do not support attribute/method access.
97 if (value_->type()->isSubtypeOf(NumberType::get())) {
98 throw ErrorReport(loc) <<
"Cannot call methods on numbers";
// Tuples: only named tuples support attribute access. The field name is
// matched against the tuple's names and lowered to a tuple-index node at the
// matching position.
100 if (
auto tuple_type = value_->type()->cast<
TupleType>()) {
101 if (!tuple_type->hasNames()) {
102 throw ErrorReport(loc) <<
"Getting attributes of tuples is not supported";
104 auto names = tuple_type->names();
105 for (
size_t i = 0; i < names.size(); i++) {
106 if (names[i] == field) {
108 ->insertNode(m.graph()->createTupleIndex(value_, i))
110 return std::make_shared<SimpleValue>(r);
// NOTE(review): presumably reached only when no name matched; the loop's
// closing braces are not visible here.
113 throw ErrorReport(loc) <<
"Unknown attribute to named tuple";
// Class instances: a method named `field` is checked first and returned as a
// MethodValue bound to this value; otherwise the attribute must exist on the
// class and is read through a GetAttr node.
116 if (
auto classType = value_->type()->cast<
ClassType>()) {
118 if (
auto method = classType->getMethod(field)) {
119 return std::make_shared<MethodValue>(shared_from_this(), *method);
122 if (!classType->hasAttribute(field)) {
124 <<
"Tried to access to nonexistent attribute " << field
125 <<
". Did you forget to initialize it in __init__()?";
127 auto& g = *m.graph();
128 auto n = g.insertNode(g.createGetAttr(value_, field));
129 return std::make_shared<SimpleValue>(n->output());
// Fallback: treat `field` as an aten builtin with value_ passed as the named
// argument "self".
132 return std::make_shared<BuiltinFunction>(
133 Symbol::aten(field),
NamedValue(loc,
"self", value_));
// Expands value_ into one SimpleValue per element. Tuple-typed values are
// unpacked with createTupleUnpack; list-typed values are unpacked with a
// ListUnpack node sized by `*size_hint`.
// NOTE(review): the parameter list, the `throw` heads of both error messages
// and the condition guarding the list-size error are not visible in this
// excerpt.
136 std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
// Helper that wraps a raw Value* in a SimpleValue, used with fmap below.
140 static const auto make_simple_value =
141 [](
Value* v) -> std::shared_ptr<SugaredValue> {
142 return std::make_shared<SimpleValue>(v);
144 if (value_->type()->kind() == TypeKind::TupleType) {
145 auto outputs = createTupleUnpack(value_);
146 return fmap(outputs, make_simple_value);
147 }
else if (value_->type()->kind() == TypeKind::ListType) {
// NOTE(review): presumably raised when no size hint is available -- the
// guarding condition is not shown; confirm.
150 <<
"cannot statically infer the expected size of a " 151 <<
"list in this context";
153 auto graph = value_->owningGraph();
// The number of unpacked outputs comes from the dereferenced size_hint.
155 graph->insertNode(graph->createListUnpack(value_, *size_hint));
156 return fmap(unpack->outputs(), make_simple_value);
// Error message for values that are neither handled case above
// (NOTE(review): the start of this statement is not visible).
159 <<
" cannot be used as a tuple";
// Emits `value_.field = newValue` as a SetAttr node, after checking that
// value_ is a class instance and that newValue's type is a subtype of the
// attribute's expected type.
// NOTE(review): parts of the parameter list and several guarding conditions
// and closing braces are on lines not visible in this excerpt.
162 void SimpleValue::setAttr(
165 const std::string& field,
167 const auto classType = value_->type()->cast<
ClassType>();
// NOTE(review): presumably guarded by a null check on classType (not shown).
169 throw ErrorReport(loc) <<
"Tried to set an attribute: " << field
170 <<
" on a non-class: " << value_->type()->str();
172 auto expectedType = classType->getAttribute(field);
// Visible conditions for "initializing": the method being compiled is named
// "__init__", its graph has at least one input, and the first input's type
// is this class type.
178 const auto isInitializing =
180 m.name() ==
"__init__" &&
183 !m.graph()->inputs().empty() &&
184 m.graph()->inputs().at(0)->type() == classType;
// While initializing, the field is registered on the class with newValue's
// type, and that type becomes the expected type for the check below.
// NOTE(review): any enclosing condition (e.g. attribute not yet defined) is
// on lines not shown.
186 if (isInitializing) {
187 classType->addAttribute(field, newValue->type());
188 expectedType = newValue->type();
// The first assignment must be emitted directly in the graph's top-level
// block, not inside a nested (control-flow) block.
190 const auto insertPoint = m.graph()->insertPoint();
191 const auto topLevelBlock = m.graph()->block();
192 if (insertPoint->owningBlock() != topLevelBlock) {
194 <<
"First assignment cannot be in a control-flow block. " 195 <<
"Initialize the field at the top level first.";
// NOTE(review): presumably the non-initializing branch for an unknown
// attribute; the branch structure is not visible.
199 <<
"Tried to set nonexistent attribute: " << field
200 <<
". Did you forget to initialize it in __init__()?";
204 AT_ASSERT(expectedType);
// The assigned value's type must be a subtype of the expected type.
207 const auto newType = newValue->type();
208 if (!newType->isSubtypeOf(expectedType)) {
209 throw ErrorReport(loc) <<
"Wrong type for attribute assignment. Expected " 210 << expectedType->str() <<
" but got " 214 auto& g = *m.graph();
215 g.insertNode(g.createSetAttr(value_, field, newValue));
// Emits construction of an instance of type_: inserts a create-object node,
// wraps its output as `self`, and calls the class's "__init__" method on it
// with the caller's inputs, attributes and n_binders.
// NOTE(review): the parameter list and the return statement are not visible
// in this excerpt.
218 std::shared_ptr<SugaredValue> ClassValue::call(
// At most one binder is accepted for the result of the call.
225 AT_ASSERT(n_binders <= 1);
228 auto& g = *m.graph();
229 auto createNode = g.insertNode(g.createObject(type_));
230 auto self = std::make_shared<SimpleValue>(createNode->output());
// The class is asserted to define "__init__".
232 auto initMethod = type_->getMethod(
"__init__");
233 AT_ASSERT(initMethod);
// The value returned by the __init__ call is discarded here.
236 MethodValue(
self, *initMethod).call(loc, m, inputs, attributes, n_binders);
// Resolves attribute access on a class value. Only "__new__" is supported
// and yields a ClassNewMethod for type_; any other name raises an
// ErrorReport at `loc`.
// NOTE(review): part of the parameter list and the closing braces are not
// visible in this excerpt.
241 std::shared_ptr<SugaredValue> ClassValue::attr(
244 const std::string& field) {
245 if (field !=
"__new__") {
246 throw ErrorReport(loc) <<
"Tried to lookup unknown attribute on class";
248 return std::make_shared<ClassNewMethod>(type_);
constexpr bool empty() const
empty - Check if the array is empty.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Flush-To-Zero and Denormals-Are-Zero mode.