Caffe2 - C++ API
A deep learning, cross platform ML framework
function_schema.h
1 #pragma once
2 
3 #include <ATen/core/jit_type.h>
4 #include <ATen/core/interned_strings.h>
5 #include <ATen/core/ivalue.h>
6 #include <ATen/core/alias_info.h>
7 
8 namespace c10 {
9 
10 // schema as used in the compiler for resolving function calls and reporting
11 // errors. These objects should be constructed from C10 schema once those
12 // are available.
13 
14 struct Argument {
15  Argument(
16  std::string name = "",
17  TypePtr type = nullptr,
18  c10::optional<int32_t> N = c10::nullopt,
19  c10::optional<IValue> default_value = c10::nullopt,
20  bool kwarg_only = false,
21  c10::optional<AliasInfo> alias_info = c10::nullopt)
22  : name_(std::move(name)),
23  type_(type ? type : TensorType::get()),
24  N_(std::move(N)),
25  default_value_(std::move(default_value)),
26  kwarg_only_(kwarg_only),
27  alias_info_(std::move(alias_info)) {
28  if (default_value_ && default_value_->isTensor()) {
29  auto t = default_value_->toTensor();
30  AT_ASSERT(!t.defined() || t.is_variable());
31  }
32  }
33  const std::string& name() const {
34  return name_;
35  }
36  TypePtr type() const {
37  return type_;
38  }
39  c10::optional<int32_t> N() const {
40  return N_;
41  }
42  const c10::optional<IValue>& default_value() const {
43  return default_value_;
44  }
45  bool kwarg_only() const {
46  return kwarg_only_;
47  }
48  const c10::optional<AliasInfo>& alias_info() const {
49  return alias_info_;
50  }
51 
52 private:
53  std::string name_;
54  TypePtr type_;
55  // for list types, an optional statically known length for the list
56  // e.g. for int[3]: type = ListType::ofInts(), N = 3
57  // If present, this will allow scalars to be broadcast to this length to
58  // become a list.
60 
61  c10::optional<IValue> default_value_;
62  // is this only specifyable as a keyword argument?
63  bool kwarg_only_;
64  c10::optional<AliasInfo> alias_info_;
65 };
66 
69  std::string name,
70  std::string overload_name,
71  std::vector<Argument> arguments,
72  std::vector<Argument> returns,
73  bool is_vararg = false,
74  bool is_varret = false)
75  : name_(std::move(name)),
76  overload_name_(std::move(overload_name)),
77  arguments_(std::move(arguments)),
78  returns_(std::move(returns)),
79  is_vararg_(is_vararg),
80  is_varret_(is_varret) {}
81 
83  Symbol name,
84  std::string overload_name,
85  std::vector<Argument> arguments,
86  std::vector<Argument> returns,
87  bool is_vararg = false,
88  bool is_varret = false,
89  std::vector<std::string> writes = {})
91  name.toQualString(),
92  std::move(overload_name),
93  std::move(std::move(arguments)),
94  std::move(std::move(returns)),
95  is_vararg,
96  is_varret) {}
97 
98 private:
99  const std::string name_;
100  const std::string overload_name_;
101  const std::vector<Argument> arguments_;
102  const std::vector<Argument> returns_;
103  // if true then this schema takes an arbitrary number of additional arguments
104  // after the argument specified in arguments
105  // currently this is used primarily to represent 'primtive' operators whose
106  // arguments are not checked by schema
107  const bool is_vararg_;
108  const bool is_varret_;
109 
110 public:
111  const std::string& name() const {
112  return name_;
113  }
114  const std::string& overload_name() const {
115  return overload_name_;
116  }
117  const std::vector<Argument>& arguments() const {
118  return arguments_;
119  }
120  const std::vector<Argument>& returns() const {
121  return returns_;
122  }
123  bool is_vararg() const {
124  return is_vararg_;
125  }
126  bool is_varret() const {
127  return is_varret_;
128  }
129  bool is_mutable() const {
130  return std::any_of(
131  arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
132  const auto& aliasInfo = arg.alias_info();
133  return aliasInfo && aliasInfo.value().isWrite();
134  });
135  }
136  c10::optional<int> argumentIndexWithName(const std::string& name) const {
137  for(size_t i = 0; i < arguments().size(); ++i) {
138  if(name == arguments()[i].name())
139  return i;
140  }
141  return c10::nullopt;
142  }
143 };
144 
145 // for debugging, make sure we can describe the call site
146 inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
147  return out << arg.type()->str() << " " << arg.name() << (arg.default_value() ? "=<default>" : "");
148 }
149 
150 inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
151  // eventually this should look almost identical to python arg parser, but
152  // it is simpler for now to work directly on this schema
153 
154  out << schema.name();
155  out << "(";
156 
157  bool seen_kwarg_only = false;
158  for(size_t i = 0; i < schema.arguments().size(); ++i) {
159  if (i > 0) out << ", ";
160  if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
161  out << "*, ";
162  seen_kwarg_only = true;
163  }
164  out << schema.arguments()[i];
165  }
166 
167  if(schema.is_vararg()) {
168  if(schema.arguments().size() > 0)
169  out << ", ";
170  out << "...";
171  }
172 
173  out << ") -> ";
174  if (schema.returns().size() == 1) {
175  out << schema.returns().at(0).type()->str();
176  } else if (schema.returns().size() > 1) {
177  out << "(";
178  for (size_t i = 0; i < schema.returns().size(); ++i) {
179  if (i > 0) out << ", ";
180  out << schema.returns()[i].type()->str();
181  }
182  out << ")";
183  }
184  return out;
185 }
186 
187 } // namespace c10
bool is_variable() const noexcept
Returns true if the Tensor is actually a torch::autograd::Variable.
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7