Caffe2 - C++ API
A deep learning, cross-platform ML framework
attributes.h
1 #pragma once
2 #include <ATen/ATen.h>
3 #include <string>
4 #include <vector>
5 
6 #include <ATen/core/interned_strings.h>
7 
8 #include <torch/csrc/WindowsTorchApiMacro.h>
9 
10 namespace torch {
11 namespace jit {
12 
13 using ::c10::Symbol;
14 
// NOTE(review): by its name this presumably caps how many tensor elements are
// shown when an attribute is printed; the printing code is not visible here.
constexpr int max_tensor_display_size{10};

// Tag identifying which concrete payload an AttributeValue carries.
// toString() indexes a name table by the enumerator's integer value, so the
// explicit values below spell out that contract: keep them dense, starting
// at 0, and in the same order as that table.
// `g` / `gs` are used by GraphAttr / GraphsAttr (one graph / a list of graphs).
// The other pairs appear to follow the same singular / list ("s" suffix)
// convention -- TODO confirm against the setters that create them.
enum class AttributeKind : int {
  f = 0,
  fs = 1,
  i = 2,
  is = 3,
  s = 4,
  ss = 5,
  t = 6,
  ts = 7,
  g = 8,
  gs = 9
};
18 static inline const char* toString(AttributeKind kind) {
19  static const char* names[] = {
20  "f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs"};
21  AT_ASSERT(size_t(kind) < sizeof(names) / sizeof(AttributeKind));
22  return names[int(kind)];
23 }
24 
26  AttributeValue(Symbol name) : name(name) {}
27  using Ptr = std::unique_ptr<AttributeValue>;
28  Symbol name;
29  virtual AttributeKind kind() const = 0;
30  virtual Ptr clone() const = 0;
31  virtual ~AttributeValue() = default;
32 };
33 
34 template <typename T, AttributeKind Kind>
36  using ConstructorType = T;
37  using ValueType = T;
39  : AttributeValue(name), value_(std::move(value_)) {}
40  ValueType& value() {
41  return value_;
42  }
43  Ptr clone() const override {
44  return Ptr(new ScalarAttributeValue(name, value_));
45  }
46  AttributeKind kind() const override {
47  return Kind;
48  }
49 
50  private:
51  ValueType value_;
52 };
53 
54 template <typename T, AttributeKind Kind>
56  using ConstructorType = std::vector<T>;
57  using ValueType = std::vector<T>;
58  VectorAttributeValue(Symbol name, ConstructorType value_)
59  : AttributeValue(name), value_(std::move(value_)) {}
60  ValueType& value() {
61  return value_;
62  }
63  AttributeKind kind() const override {
64  return Kind;
65  }
66  std::unique_ptr<AttributeValue> clone() const override {
67  auto copy = value_;
68  return Ptr(new VectorAttributeValue(name, std::move(copy)));
69  }
70 
71  private:
72  ValueType value_;
73 };
74 
83 struct Graph;
84 
85 // We special case Graph attributes like this because we want to ensure that
86 // Graph::copy() is called when we clone() these attributes.
87 struct GraphAttr : public AttributeValue {
88  using ConstructorType = std::shared_ptr<Graph>;
89  using ValueType = std::shared_ptr<Graph>;
90  GraphAttr(Symbol name, ConstructorType value_)
91  : AttributeValue(name), value_(value_) {}
92  ValueType& value() {
93  return value_;
94  }
95  TORCH_API Ptr clone() const override;
96  AttributeKind kind() const override {
97  return AttributeKind::g;
98  }
99 
100  private:
101  std::shared_ptr<Graph> value_;
102 };
103 
104 struct GraphsAttr : public AttributeValue {
105  using ConstructorType = std::vector<std::shared_ptr<Graph>>;
106  using ValueType = std::vector<std::shared_ptr<Graph>>;
107  GraphsAttr(Symbol name, ConstructorType value_)
108  : AttributeValue(name), value_(std::move(value_)) {}
109  ValueType& value() {
110  return value_;
111  }
112  AttributeKind kind() const override {
113  return AttributeKind::gs;
114  }
115  TORCH_API std::unique_ptr<AttributeValue> clone() const override;
116 
117  private:
118  ValueType value_;
119 };
120 
121 struct AttributeError : public std::exception {
122  AttributeError(Symbol name, bool defined) {
123  std::stringstream ss;
124  if (!defined) {
125  ss << "required keyword attribute '" << name.toUnqualString()
126  << "' is undefined.";
127  } else {
128  ss << "required keyword attribute '" << name.toUnqualString()
129  << "' has the wrong type";
130  }
131  msg = ss.str();
132  }
133  const char* what() const noexcept override {
134  return msg.c_str();
135  }
136 
137  private:
138  std::string msg;
139 };
140 } // namespace jit
141 } // namespace torch
Definition: jit_type.h:17