Caffe2 - C++ API
A deep learning, cross platform ML framework
symbolic_variable.h
1 #pragma once
2 
3 #include <torch/csrc/jit/constants.h>
4 #include <torch/csrc/jit/ir.h>
5 
6 namespace torch {
7 namespace jit {
8 
10  SymbolicVariable() : v(nullptr) {}
11  /* implicit */ SymbolicVariable(Value* v) : v(v) {}
12  // we allow implicit conversions to/from Value since
13  // this type truly just provides more methods for value
14  operator Value*() const {
15  return v;
16  }
17  static SymbolicVariable asNewInput(Graph& g, std::string name = "") {
18  return g.addInput(std::move(name));
19  }
20  static SymbolicVariable asNewInput(Graph& g, TypePtr type) {
21  return g.addInput()->setType(std::move(type));
22  }
23  const std::vector<int64_t>& sizes() const {
24  return v->type()->expect<CompleteTensorType>()->sizes();
25  }
26  void addAsOutput() const {
27  v->owningGraph()->registerOutput(v);
28  }
29  static std::vector<SymbolicVariable> create(
30  Symbol kind,
32  int num_outputs = 1,
33  Node** created_node = nullptr,
34  Graph* g = nullptr) {
35  if (g == nullptr) {
36  g = inputs.at(0).value()->owningGraph();
37  }
38  Node* n = g->insertNode(g->create(kind, num_outputs));
39  size_t max_depth = 0;
40  ScopePtr s;
41  for (auto n : inputs) {
42  size_t d = n.value()->node()->scope()->getDepth();
43  if (d > max_depth) {
44  max_depth = d;
45  s = n.value()->node()->scope();
46  }
47  }
48  n->setScope(s);
49 
50  for (auto i : inputs) {
51  n->addInput(i.value());
52  }
53  if (created_node) {
54  *created_node = n;
55  }
56  std::vector<SymbolicVariable> out;
57  for (auto v : n->outputs()) {
58  out.emplace_back(v);
59  }
60  return out;
61  }
62  static bool isConstInt(at::Scalar s, int32_t i) {
63  // int32_t is safely convertible to both double and int64_t
64  if (s.isFloatingPoint()) {
65  return (double)i == s.toDouble();
66  } else {
67  return (int64_t)i == s.toLong();
68  }
69  }
70  SymbolicVariable operator*(const SymbolicVariable rhs) const {
71  return create(aten::mul, {*this, rhs})[0].typeLike(*this);
72  }
73  SymbolicVariable operator/(const SymbolicVariable rhs) const {
74  return create(aten::div, {*this, rhs})[0].typeLike(*this);
75  }
76  SymbolicVariable operator*(at::Scalar rhs) const {
77  if (isConstInt(rhs, 1))
78  return *this;
79  return (*this) * insertConstant(rhs);
80  }
81  SymbolicVariable operator>(at::Scalar rhs) const {
82  return create(aten::gt, {*this, insertConstant(rhs)})[0]
83  .typeLikeWithScalarType(*this, at::kByte);
84  }
85  SymbolicVariable operator>(const SymbolicVariable rhs) const {
86  return create(aten::gt, {*this, rhs})[0].typeLikeWithScalarType(
87  *this, at::kByte);
88  }
89  SymbolicVariable operator<(at::Scalar rhs) const {
90  return create(aten::lt, {*this, insertConstant(rhs)})[0]
91  .typeLikeWithScalarType(*this, at::kByte);
92  }
93  SymbolicVariable operator<(const SymbolicVariable rhs) const {
94  return create(aten::lt, {*this, rhs})[0].typeLikeWithScalarType(
95  *this, at::kByte);
96  }
97  SymbolicVariable operator>=(at::Scalar rhs) const {
98  return create(aten::ge, {*this, insertConstant(rhs)})[0]
99  .typeLikeWithScalarType(*this, at::kByte);
100  }
101  SymbolicVariable operator>=(const SymbolicVariable rhs) const {
102  return create(aten::ge, {*this, rhs})[0].typeLikeWithScalarType(
103  *this, at::kByte);
104  }
105  SymbolicVariable operator<=(at::Scalar rhs) const {
106  return create(aten::le, {*this, insertConstant(rhs)})[0]
107  .typeLikeWithScalarType(*this, at::kByte);
108  }
109  SymbolicVariable operator<=(const SymbolicVariable rhs) const {
110  return create(aten::le, {*this, rhs})[0].typeLikeWithScalarType(
111  *this, at::kByte);
112  }
113  SymbolicVariable operator==(at::Scalar rhs) const {
114  return create(aten::eq, {*this, insertConstant(rhs)})[0]
115  .typeLikeWithScalarType(*this, at::kByte);
116  }
117  SymbolicVariable operator!=(at::Scalar rhs) const {
118  return create(aten::ne, {*this, insertConstant(rhs)})[0]
119  .typeLikeWithScalarType(*this, at::kByte);
120  }
121  SymbolicVariable operator+(const SymbolicVariable rhs) const {
122  return create(aten::add, {*this, rhs, insertConstant(1)})[0].typeLike(
123  *this);
124  }
125  SymbolicVariable operator+(at::Scalar rhs) const {
126  return (*this) + insertConstant(rhs);
127  }
128  SymbolicVariable operator-() const {
129  return create(aten::neg, {*this})[0].typeLike(*this);
130  }
131  SymbolicVariable operator-(const SymbolicVariable rhs) const {
132  return create(aten::sub, {*this, rhs, insertConstant(1)})[0].typeLike(
133  *this);
134  }
135  SymbolicVariable operator/(at::Scalar rhs) const {
136  return create(aten::div, {*this, insertConstant(rhs)})[0].typeLike(*this);
137  }
138  SymbolicVariable operator%(at::Scalar rhs) const {
139  return create(aten::remainder, {*this, insertConstant(rhs)})[0].typeLike(
140  *this);
141  }
142  Value* size() const {
143  return v->owningGraph()->insert(aten::size, {v});
144  }
145  SymbolicVariable gradSumToSize(Value* size) const {
146  return create(aten::_grad_sum_to_size, {*this, size})[0];
147  }
148  SymbolicVariable expand(Value* size) const {
149  return v->owningGraph()->insert(aten::expand, {v, size});
150  }
151  SymbolicVariable isnan() const {
152  return create(aten::ne, {*this, *this})[0].typeLikeWithScalarType(
153  *this, at::kByte);
154  }
155  SymbolicVariable mm(const SymbolicVariable rhs) const {
156  return create(t("mm"), {*this, rhs})[0];
157  }
158  SymbolicVariable t() const {
159  return create(t("t"), {*this})[0];
160  }
161  SymbolicVariable sigmoid() const {
162  return create(aten::sigmoid, {*this})[0].typeLike(*this);
163  }
164  SymbolicVariable tanh() const {
165  return create(aten::tanh, {*this})[0].typeLike(*this);
166  }
167  std::vector<SymbolicVariable> chunk(int64_t chunks, int dim) const {
168  Node* chunk;
169  auto outputs = create(prim::ConstantChunk, {value()}, chunks, &chunk);
170  chunk->i_(attr::chunks, chunks)->i_(attr::dim, dim);
171  return outputs;
172  }
173  SymbolicVariable type_as(const SymbolicVariable rhs) const {
174  return create(aten::type_as, {*this, rhs})[0].typeLikeWithRhsScalarType(
175  *this, rhs);
176  }
177  SymbolicVariable narrow(int dim, int64_t start, int64_t length) const {
178  return create(
179  t("narrow"),
180  {*this,
181  insertConstant(dim),
182  insertConstant(start),
183  insertConstant(length)},
184  1)[0];
185  }
186  static SymbolicVariable cat(ArrayRef<SymbolicVariable> inputs, Value* dim) {
187  Graph* g = dim->owningGraph();
188  Value* input_list;
189  if (inputs.size() == 1 &&
190  inputs[0].value()->type()->isSubtypeOf(ListType::ofTensors())) {
191  input_list = inputs[0];
192  } else {
193  auto value_inputs =
194  fmap(inputs, [](const SymbolicVariable& v) { return v.value(); });
195  input_list =
196  g->insertNode(g->createList(TensorType::get(), value_inputs))
197  ->output();
198  }
199  return create(aten::cat, {input_list, dim})[0];
200  }
201  static SymbolicVariable cat(ArrayRef<SymbolicVariable> inputs, int dim) {
202  AT_ASSERT(inputs.size() > 0);
203  return SymbolicVariable::cat(inputs, inputs[0].insertConstant(dim));
204  }
205  static SymbolicVariable stack(ArrayRef<SymbolicVariable> inputs, Value* dim) {
206  Graph* g = dim->owningGraph();
207  auto value_inputs =
208  fmap(inputs, [](const SymbolicVariable& v) { return v.value(); });
209  Value* input_list =
210  g->insertNode(g->createList(TensorType::get(), value_inputs))
211  ->output();
212  return create(aten::stack, {input_list, dim})[0];
213  }
214  static SymbolicVariable stack(ArrayRef<SymbolicVariable> inputs, int dim) {
215  AT_ASSERT(inputs.size() > 0);
216  return SymbolicVariable::stack(inputs, inputs[0].insertConstant(dim));
217  }
218  static std::vector<SymbolicVariable> broadcast_tensors(
220  AT_ASSERT(inputs.size() > 0);
221  Graph* g = inputs[0].value()->owningGraph();
222  auto value_inputs =
223  fmap(inputs, [](const SymbolicVariable& v) { return v.value(); });
224  Value* input_list =
225  g->insertNode(g->createList(TensorType::get(), value_inputs))
226  ->output();
227  Value* output_list = g->insert(aten::broadcast_tensors, {input_list});
228  Node* unpack = g->insertNode(
229  g->create(prim::ListUnpack, {output_list}, inputs.size()));
230  return fmap<SymbolicVariable>(unpack->outputs());
231  }
232  static SymbolicVariable zeros_like(const SymbolicVariable input) {
233  return create(t("zeros_like"), {input})[0];
234  }
235  SymbolicVariable cos() const {
236  return create(t("cos"), {*this})[0];
237  }
238  SymbolicVariable cosh() const {
239  return create(t("cosh"), {*this})[0];
240  }
241  SymbolicVariable exp() const {
242  return create(t("exp"), {*this})[0];
243  }
244  SymbolicVariable pow(at::Scalar other) const {
245  return create(t("pow"), {*this, insertConstant(other)})[0];
246  }
247  SymbolicVariable rsqrt() const {
248  return create(t("rsqrt"), {*this})[0];
249  }
250  SymbolicVariable sign() const {
251  return create(t("sign"), {*this})[0];
252  }
253  SymbolicVariable sin() const {
254  return create(t("sin"), {*this})[0];
255  }
256  SymbolicVariable sinh() const {
257  return create(t("sinh"), {*this})[0];
258  }
259  SymbolicVariable sum() const {
260  return create(t("sum"), {*this})[0];
261  }
262  SymbolicVariable sum(int dim, bool keepdim) const {
263  return create(
264  t("sum"),
265  {*this, insertConstant(at::IntArrayRef{dim}), insertConstant(keepdim)})[0];
266  }
267  SymbolicVariable squeeze(Value* dim) const {
268  return create(t("squeeze"), {*this, dim})[0];
269  }
270  SymbolicVariable squeeze(int dim) const {
271  return squeeze(insertConstant(dim));
272  }
273  SymbolicVariable unsqueeze(Value* dim) const {
274  return create(t("unsqueeze"), {*this, dim})[0];
275  }
276  SymbolicVariable unsqueeze(int dim) const {
277  return unsqueeze(insertConstant(dim));
278  }
279  SymbolicVariable view(Value* sizes) const {
280  return create(aten::view, {*this, sizes})[0];
281  }
282  SymbolicVariable view(std::vector<std::int64_t> sizes) const {
283  return view(insertConstant(std::move(sizes)));
284  }
285  SymbolicVariable reshape(Value* sizes) const {
286  return create(aten::reshape, {*this, sizes})[0];
287  }
288  SymbolicVariable reshape(std::vector<std::int64_t> sizes) const {
289  return reshape(insertConstant(std::move(sizes)));
290  }
291  SymbolicVariable addmm(SymbolicVariable mat1, SymbolicVariable mat2) const {
292  return create(
293  aten::addmm,
294  {*this, mat1, mat2, insertConstant(1), insertConstant(1)})[0];
295  }
296  Value* value() const {
297  return v;
298  }
299 
300  private:
301  Value* insertConstant(IValue value) const {
302  return v->owningGraph()->insertConstant(std::move(value));
303  }
304  SymbolicVariable typeLike(SymbolicVariable other) const {
305  if (auto other_type = other.v->type()->cast<CompleteTensorType>())
306  v->setType(other_type->contiguous());
307  return *this;
308  }
309  SymbolicVariable typeLikeWithScalarType(
310  SymbolicVariable other,
311  at::ScalarType type) const {
312  if (auto other_type = other.v->type()->cast<CompleteTensorType>()) {
313  auto new_type = other_type->toScalarType(type)->contiguous();
314  v->setType(new_type);
315  }
316  return *this;
317  }
318  SymbolicVariable typeLikeWithRhsScalarType(
319  SymbolicVariable other,
320  SymbolicVariable rhs) const {
321  auto other_type = other.v->type()->cast<CompleteTensorType>();
322  auto rhs_type = rhs.v->type()->cast<CompleteTensorType>();
323  if (other_type && rhs_type) {
324  auto new_type =
325  other_type->toScalarType(rhs_type->scalarType())->contiguous();
326  v->setType(new_type);
327  }
328  return *this;
329  }
330  static Symbol a(const char* s_) {
331  return Symbol::attr(s_);
332  }
333  static Symbol t(const char* s_) {
334  return Symbol::aten(s_);
335  }
336  Value* v;
337 };
338 
339 // shorter method so that toVar(v) + toVar(c) is short.
340 static inline SymbolicVariable toVar(Value* v) {
341  return {v};
342 }
343 
344 template <
345  typename T,
346  typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
347 inline SymbolicVariable operator+(T lhs, SymbolicVariable rhs) {
348  return rhs + at::Scalar(lhs);
349 }
350 
351 inline SymbolicVariable operator+(at::Scalar lhs, SymbolicVariable rhs) {
352  return rhs + lhs;
353 }
354 
355 inline SymbolicVariable operator-(at::Scalar lhs, SymbolicVariable rhs) {
356  return (lhs + (-rhs));
357 }
358 
359 } // namespace jit
360 } // namespace torch
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41
AT_CPP14_CONSTEXPR const T & at(size_t Index) const
Vector compatibility.
Definition: ArrayRef.h:186