1 #include "import_source.h"     3 #include <torch/csrc/jit/script/parser.h>    13       : module(std::move(module)) {}
    14   std::string kind()
 const override {
    18   std::shared_ptr<SugaredValue> attr(
    21       const std::string& field)
 override {
    23       return std::make_shared<ModuleAccessorValue>(v->module);
    24     } 
else if (
NamedIValue* v = module->find_parameter(field)) {
    25       return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
    26     } 
else if (
NamedIValue* v = module->find_buffer(field)) {
    27       return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
    29       return std::make_shared<script::SimpleValue>(
    30           m.get_or_add_attribute(v->type, v->slot()));
    31     } 
else if (
Method* m = module->find_method(field)) {
    32       return std::make_shared<MethodValue>(shared_from_this(), *m);
    34       throw ErrorReport(loc) << 
"unknown attr: " << field;
    39   std::shared_ptr<Module> module;
    43   OpsValue(
size_t version) : version_(version) {}
    44   std::string kind()
 const override {
    47   std::shared_ptr<SugaredValue> attr(
    50       const std::string& field)
 override {
    51     return std::make_shared<BuiltinModule>(field, version_);
    59   std::string kind()
 const override {
    63     return m.graph()->insertConstant(value_);
    72   std::string kind()
 const override {
    76   std::shared_ptr<SugaredValue> attr(
    79       const std::string& field)
 override {
    80     const char* field_s = field.c_str();
    82     int64_t offset = std::strtoll(field_s + 1, &end, 10);
    83     if (field.size() < 2 || *end != 0)
    84       throw ErrorReport(loc) << 
"invalid constant specifier: " << field;
    85     if (offset < 0 || 
size_t(offset) >= constants_.size()) {
    86       throw ErrorReport(loc) << 
"constant index " << offset
    87                              << 
" is out of bounds (constant table has "    88                              << constants_.size() << 
" entries).";
    90     Value* value = m.graph()->insertConstant(constants_[offset], 
nullptr, loc);
    91     return std::make_shared<SimpleValue>(value);
   101       const std::string& src,
   102       const std::vector<at::Tensor>& constant_table)
   103       : parser_(src), constant_table_(constant_table) {
   104     const auto version = parseVersionNumber();
   106         {
"torch", std::make_shared<BuiltinModule>(
"aten", version)},
   107         {
"ops", std::make_shared<OpsValue>(version)},
   108         {
"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
   109         {
"fork", std::make_shared<ForkValue>()},
   110         {
"annotate", std::make_shared<AnnotateValue>()},
   112          std::make_shared<ConstantValue>(
   113              std::numeric_limits<double>::infinity())},
   115          std::make_shared<ConstantValue>(
   116              std::numeric_limits<double>::quiet_NaN())},
   119     resolver_ = [&](
const std::string& name,
   121                     const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
   122       auto it = env_.find(name);
   123       if (it == env_.end()) {
   133   const std::vector<at::Tensor>& constant_table_;
   134   std::unordered_map<std::string, std::shared_ptr<SugaredValue>> env_;
   135   std::function<std::shared_ptr<
   139   size_t parseVersionNumber() {
   140     auto& L = parser_.lexer();
   141     auto range = L.cur().range;
   142     auto name = L.expect(TK_IDENT).text();
   144     std::string version_text = L.expect(TK_NUMBER).text();
   145     L.expect(TK_NEWLINE);
   146     auto version = Const::create(L.cur().range, version_text);
   147     if (name != 
"op_version_set")
   148       throw ErrorReport(range) << 
"expected an assignment to op_version_set";
   149     if (!version.isIntegral())
   151           << 
"expected an integral version but found " << version.text();
   152     return size_t(version.asIntegral());
   157     const std::shared_ptr<Module>& mod,
   158     const std::string& src,
   159     const std::vector<at::Tensor>& constant_table) {
   161   auto& p = importer.parser_;
   163   std::vector<Def> definitions;
   164   std::vector<Resolver> resolvers;
   165   while (p.lexer().cur().kind != TK_EOF) {
   166     auto def = 
Def(p.parseFunction(
true));
   167     definitions.emplace_back(def);
   168     resolvers.emplace_back(importer.resolver_);
   170   auto self = std::make_shared<ModuleAccessorValue>(mod);
   171   defineMethodsInModule(mod, definitions, resolvers, 
Self(
self));
   175     const std::string& src,
   176     const std::vector<at::Tensor>& constant_table) {
   178   auto& p = importer.parser_;
   180   while (p.lexer().cur().kind != TK_EOF) {
   181     std::vector<Def> definitions;
   182     std::vector<Resolver> resolvers;
   183     auto class_def = 
ClassDef(p.parseClass());
   184     for (
const auto& method_def : class_def.defs()) {
   185       definitions.emplace_back(method_def);
   186       resolvers.emplace_back(importer.resolver_);
   189     auto mod = std::make_shared<Module>();
   190     Self self(ClassType::create(class_def.name().name(), mod));
   191     defineMethodsInModule(mod, definitions, resolvers, 
self);
 
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...