1 #include "import_source.h" 3 #include <torch/csrc/jit/script/parser.h> 13 : module(std::move(module)) {}
14 std::string kind()
const override {
18 std::shared_ptr<SugaredValue> attr(
21 const std::string& field)
override {
23 return std::make_shared<ModuleAccessorValue>(v->module);
24 }
else if (
NamedIValue* v = module->find_parameter(field)) {
25 return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
26 }
else if (
NamedIValue* v = module->find_buffer(field)) {
27 return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
29 return std::make_shared<script::SimpleValue>(
30 m.get_or_add_attribute(v->type, v->slot()));
31 }
else if (
Method* m = module->find_method(field)) {
32 return std::make_shared<MethodValue>(shared_from_this(), *m);
34 throw ErrorReport(loc) <<
"unknown attr: " << field;
39 std::shared_ptr<Module> module;
43 OpsValue(
size_t version) : version_(version) {}
44 std::string kind()
const override {
47 std::shared_ptr<SugaredValue> attr(
50 const std::string& field)
override {
51 return std::make_shared<BuiltinModule>(field, version_);
59 std::string kind()
const override {
63 return m.graph()->insertConstant(value_);
72 std::string kind()
const override {
76 std::shared_ptr<SugaredValue> attr(
79 const std::string& field)
override {
80 const char* field_s = field.c_str();
82 int64_t offset = std::strtoll(field_s + 1, &end, 10);
83 if (field.size() < 2 || *end != 0)
84 throw ErrorReport(loc) <<
"invalid constant specifier: " << field;
85 if (offset < 0 ||
size_t(offset) >= constants_.size()) {
86 throw ErrorReport(loc) <<
"constant index " << offset
87 <<
" is out of bounds (constant table has " 88 << constants_.size() <<
" entries).";
90 Value* value = m.graph()->insertConstant(constants_[offset],
nullptr, loc);
91 return std::make_shared<SimpleValue>(value);
101 const std::string& src,
102 const std::vector<at::Tensor>& constant_table)
103 : parser_(src), constant_table_(constant_table) {
104 const auto version = parseVersionNumber();
106 {
"torch", std::make_shared<BuiltinModule>(
"aten", version)},
107 {
"ops", std::make_shared<OpsValue>(version)},
108 {
"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
109 {
"fork", std::make_shared<ForkValue>()},
110 {
"annotate", std::make_shared<AnnotateValue>()},
112 std::make_shared<ConstantValue>(
113 std::numeric_limits<double>::infinity())},
115 std::make_shared<ConstantValue>(
116 std::numeric_limits<double>::quiet_NaN())},
119 resolver_ = [&](
const std::string& name,
121 const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
122 auto it = env_.find(name);
123 if (it == env_.end()) {
133 const std::vector<at::Tensor>& constant_table_;
134 std::unordered_map<std::string, std::shared_ptr<SugaredValue>> env_;
135 std::function<std::shared_ptr<
139 size_t parseVersionNumber() {
140 auto& L = parser_.lexer();
141 auto range = L.cur().range;
142 auto name = L.expect(TK_IDENT).text();
144 std::string version_text = L.expect(TK_NUMBER).text();
145 L.expect(TK_NEWLINE);
146 auto version = Const::create(L.cur().range, version_text);
147 if (name !=
"op_version_set")
148 throw ErrorReport(range) <<
"expected an assignment to op_version_set";
149 if (!version.isIntegral())
151 <<
"expected an integral version but found " << version.text();
152 return size_t(version.asIntegral());
157 const std::shared_ptr<Module>& mod,
158 const std::string& src,
159 const std::vector<at::Tensor>& constant_table) {
161 auto& p = importer.parser_;
163 std::vector<Def> definitions;
164 std::vector<Resolver> resolvers;
165 while (p.lexer().cur().kind != TK_EOF) {
166 auto def =
Def(p.parseFunction(
true));
167 definitions.emplace_back(def);
168 resolvers.emplace_back(importer.resolver_);
170 auto self = std::make_shared<ModuleAccessorValue>(mod);
171 defineMethodsInModule(mod, definitions, resolvers,
Self(
self));
175 const std::string& src,
176 const std::vector<at::Tensor>& constant_table) {
178 auto& p = importer.parser_;
180 while (p.lexer().cur().kind != TK_EOF) {
181 std::vector<Def> definitions;
182 std::vector<Resolver> resolvers;
183 auto class_def =
ClassDef(p.parseClass());
184 for (
const auto& method_def : class_def.defs()) {
185 definitions.emplace_back(method_def);
186 resolvers.emplace_back(importer.resolver_);
189 auto mod = std::make_shared<Module>();
190 Self self(ClassType::create(class_def.name().name(), mod));
191 defineMethodsInModule(mod, definitions, resolvers,
self);
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...