Caffe2 - C++ API
A deep-learning, cross-platform ML framework
import_source.cpp
1 #include "import_source.h"
2 
3 #include <torch/csrc/jit/script/parser.h>
4 
5 namespace torch {
6 namespace jit {
7 namespace script {
8 
// This is a much simpler accessor that only handles submodules, parameters,
// buffers, attributes, and methods. It does not depend on Python to work.
12  ModuleAccessorValue(std::shared_ptr<Module> module)
13  : module(std::move(module)) {}
14  std::string kind() const override {
15  return "module";
16  }
17  // select an attribute on it, e.g. `this.field`
18  std::shared_ptr<SugaredValue> attr(
19  const SourceRange& loc,
20  Method& m,
21  const std::string& field) override {
22  if (NamedModule* v = module->find_module(field)) {
23  return std::make_shared<ModuleAccessorValue>(v->module);
24  } else if (NamedIValue* v = module->find_parameter(field)) {
25  return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
26  } else if (NamedIValue* v = module->find_buffer(field)) {
27  return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
28  } else if (script::NamedIValue* v = module->find_attribute(field)) {
29  return std::make_shared<script::SimpleValue>(
30  m.get_or_add_attribute(v->type, v->slot()));
31  } else if (Method* m = module->find_method(field)) {
32  return std::make_shared<MethodValue>(shared_from_this(), *m);
33  } else {
34  throw ErrorReport(loc) << "unknown attr: " << field;
35  }
36  }
37 
38  private:
39  std::shared_ptr<Module> module;
40 };
41 
42 struct OpsValue : public SugaredValue {
43  OpsValue(size_t version) : version_(version) {}
44  std::string kind() const override {
45  return "ops";
46  }
47  std::shared_ptr<SugaredValue> attr(
48  const SourceRange& loc,
49  Method& m,
50  const std::string& field) override {
51  return std::make_shared<BuiltinModule>(field, version_);
52  }
53  size_t version_;
54 };
55 
56 struct ConstantValue : public SugaredValue {
57  ConstantValue(IValue value) : value_(std::move(value)) {}
58  IValue value_;
59  std::string kind() const override {
60  return "constant";
61  }
62  Value* asValue(const SourceRange& loc, Method& m) override {
63  return m.graph()->insertConstant(value_);
64  }
65 };
66 
// This value maps the attributes CONSTANTS.c0, CONSTANTS.c1, ... to entries
// in the 'constants' vector. This table will be stored in a container format
// and given to import_methods / import_libs when restoring the code.
71  ConstantTableValue(ArrayRef<at::Tensor> constants) : constants_(constants) {}
72  std::string kind() const override {
73  return "CONSTANTS";
74  }
75  // select an attribute on it, e.g. `this.field`
76  std::shared_ptr<SugaredValue> attr(
77  const SourceRange& loc,
78  Method& m,
79  const std::string& field) override {
80  const char* field_s = field.c_str();
81  char* end;
82  int64_t offset = std::strtoll(field_s + 1, &end, 10);
83  if (field.size() < 2 || *end != 0)
84  throw ErrorReport(loc) << "invalid constant specifier: " << field;
85  if (offset < 0 || size_t(offset) >= constants_.size()) {
86  throw ErrorReport(loc) << "constant index " << offset
87  << " is out of bounds (constant table has "
88  << constants_.size() << " entries).";
89  }
90  Value* value = m.graph()->insertConstant(constants_[offset], nullptr, loc);
91  return std::make_shared<SimpleValue>(value);
92  }
93 
94  private:
95  ArrayRef<at::Tensor> constants_;
96 };
97 
// Helper that holds the state for parsing a TorchScript source string.
101  const std::string& src,
102  const std::vector<at::Tensor>& constant_table)
103  : parser_(src), constant_table_(constant_table) {
104  const auto version = parseVersionNumber();
105  env_ = {
106  {"torch", std::make_shared<BuiltinModule>("aten", version)},
107  {"ops", std::make_shared<OpsValue>(version)},
108  {"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
109  {"fork", std::make_shared<ForkValue>()},
110  {"annotate", std::make_shared<AnnotateValue>()},
111  {"inf",
112  std::make_shared<ConstantValue>(
113  std::numeric_limits<double>::infinity())},
114  {"nan",
115  std::make_shared<ConstantValue>(
116  std::numeric_limits<double>::quiet_NaN())},
117  };
118 
119  resolver_ = [&](const std::string& name,
120  Method& m,
121  const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
122  auto it = env_.find(name);
123  if (it == env_.end()) {
124  return nullptr;
125  }
126  return it->second;
127  };
128  }
129 
130  Parser parser_;
131  // Constants present in the model. Used to resolve "CONSTANTS.n" to the actual
132  // value
133  const std::vector<at::Tensor>& constant_table_;
134  std::unordered_map<std::string, std::shared_ptr<SugaredValue>> env_;
135  std::function<std::shared_ptr<
136  SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>
137  resolver_;
138 
139  size_t parseVersionNumber() {
140  auto& L = parser_.lexer();
141  auto range = L.cur().range;
142  auto name = L.expect(TK_IDENT).text();
143  L.expect('=');
144  std::string version_text = L.expect(TK_NUMBER).text();
145  L.expect(TK_NEWLINE);
146  auto version = Const::create(L.cur().range, version_text);
147  if (name != "op_version_set")
148  throw ErrorReport(range) << "expected an assignment to op_version_set";
149  if (!version.isIntegral())
150  throw ErrorReport(range)
151  << "expected an integral version but found " << version.text();
152  return size_t(version.asIntegral());
153  }
154 };
155 
156 void import_methods(
157  const std::shared_ptr<Module>& mod,
158  const std::string& src,
159  const std::vector<at::Tensor>& constant_table) {
160  SourceImporter importer(src, constant_table);
161  auto& p = importer.parser_;
162 
163  std::vector<Def> definitions;
164  std::vector<Resolver> resolvers;
165  while (p.lexer().cur().kind != TK_EOF) {
166  auto def = Def(p.parseFunction(/*is_method=*/true));
167  definitions.emplace_back(def);
168  resolvers.emplace_back(importer.resolver_);
169  }
170  auto self = std::make_shared<ModuleAccessorValue>(mod);
171  defineMethodsInModule(mod, definitions, resolvers, Self(self));
172 }
173 
174 void import_libs(
175  const std::string& src,
176  const std::vector<at::Tensor>& constant_table) {
177  SourceImporter importer(src, constant_table);
178  auto& p = importer.parser_;
179 
180  while (p.lexer().cur().kind != TK_EOF) {
181  std::vector<Def> definitions;
182  std::vector<Resolver> resolvers;
183  auto class_def = ClassDef(p.parseClass());
184  for (const auto& method_def : class_def.defs()) {
185  definitions.emplace_back(method_def);
186  resolvers.emplace_back(importer.resolver_);
187  }
188 
189  auto mod = std::make_shared<Module>();
190  Self self(ClassType::create(class_def.name().name(), mod));
191  defineMethodsInModule(mod, definitions, resolvers, self);
192  }
193 }
194 
195 } // namespace script
196 } // namespace jit
197 } // namespace torch
Definition: jit_type.h:17
ArrayRef - Represents a constant reference to an array (zero or more elements stored consecutively in memory).
Definition: ArrayRef.h:41