Caffe2 - C++ API
A deep-learning, cross-platform ML framework
register_prim_ops.cpp
1 #include <aten/src/ATen/Context.h>
2 #include <torch/csrc/autograd/edge.h>
3 #include <torch/csrc/autograd/function.h>
4 #include <torch/csrc/autograd/generated/variable_factories.h>
5 #include <torch/csrc/autograd/profiler.h>
6 #include <torch/csrc/autograd/variable.h>
7 #include <torch/csrc/jit/custom_operator.h>
8 #include <torch/csrc/jit/fuser/interface.h>
9 #include <torch/csrc/jit/graph_executor.h>
10 #include <torch/csrc/jit/ir.h>
11 #include <torch/csrc/jit/operator.h>
12 #include <torch/csrc/jit/script/jit_exception.h>
13 
14 #include <ATen/ExpandUtils.h>
15 #include <ATen/WrapDimUtils.h>
16 #include <ATen/core/ivalue.h>
17 #include <c10/core/thread_pool.h>
18 #include <c10/util/SmallVector.h>
19 
20 #include <algorithm>
21 #include <exception>
22 #include <iostream>
23 #include <limits>
24 #include <memory>
25 #include <mutex>
26 #include <ostream>
27 #include <stdexcept>
28 #include <string>
29 #include <typeinfo>
30 #include <unordered_map>
31 #include <unordered_set>
32 #include <utility>
33 #include <vector>
34 
35 namespace torch {
36 namespace jit {
37 
38 namespace {
39 
40 Operation noop(const Node* n) {
41  return [](Stack& stack) { return 0; };
42 }
43 
// using the rules from python_arg_parser FunctionParameter::check
// tensor cannot have grad set, tensor must be 0 dim,
// and if the dest is an int the source must be integral type
void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
  // An implicit scalar conversion would silently detach the tensor from the
  // autograd graph, so reject tensors that require grad.
  if (autograd::as_variable_ref(t).requires_grad()) {
    throw std::runtime_error(
        "Cannot input a tensor that requires grad as a scalar argument");
  }
  // Only 0-dim (scalar) tensors may be used as number arguments.
  if (t.sizes().size() != 0) {
    throw std::runtime_error(
        "Cannot input a tensor of dimension other than 0 as a scalar argument");
  }
  // When converting to int, the source dtype must itself be integral; a
  // float -> int narrowing would lose information silently.
  if (toInt &&
      !isIntegralType(autograd::as_variable_ref(t).data().scalar_type())) {
    std::stringstream ss;
    ss << "Cannot input a tensor of type " << t.scalar_type()
       << " as an integral argument";
    throw std::runtime_error(ss.str());
  }
}
64 
// Builds an Operation that pops `num_inputs` primitive values off the stack
// and pushes a std::vector<dtype> holding them in their original order.
template <typename dtype> // int64_t, bool, double
Operation listConstruct(int64_t num_inputs) {
  return [=](Stack& stack) {
    // Look at the top num_inputs entries without popping them yet.
    auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
    std::vector<dtype> vals =
        fmap(inputs, [](const IValue& v) { return v.to<dtype>(); });
    drop(stack, num_inputs);
    push(stack, std::move(vals));
    return 0;
  };
}
76 
// Python-style floor division: rounds the quotient toward negative infinity,
// whereas C++ integer division truncates toward zero.
static int64_t floordiv(int64_t a, int64_t b) {
  if (b == 0) {
    throw std::runtime_error("division by 0");
  }
  const int64_t quot = a / b;
  const int64_t rem = a % b;
  // If the operands disagree in sign and the division is inexact, C++
  // truncation rounded the wrong way; step down by one to get the floor.
  return (rem != 0 && (rem < 0) != (b < 0)) ? quot - 1 : quot;
}
90 
91 // reference function THPVariable_to in python_variable_methods.cpp
92 static at::Tensor to_dispatch(
93  at::Tensor self,
96  bool non_blocking,
97  bool copy) {
98  if (device && device->is_cuda()) {
99  at::globalContext().lazyInitCUDA();
100  }
101  if (!device && !scalarType && !copy) {
102  return self;
103  } else if (!device) {
104  return self.to(*scalarType, non_blocking, copy);
105  } else if (!scalarType) {
106  return self.to(*device, non_blocking, copy);
107  } else {
108  return self.to(*device, *scalarType, non_blocking, copy);
109  }
110 }
111 
// Registration of the primitive interpreter operators. Each callback pops
// its inputs off the interpreter stack and pushes its outputs back on.
RegisterOperators reg(
    {Operator(
         prim::FusionGroup,
         [](const Node* node) {
           // Compile the fusion group once at op-creation time; the returned
           // closure only runs the cached kernel.
           const auto key = registerFusion(node);
           return [key](Stack& stack) {
             autograd::profiler::RecordFunction record("FusionGroup");
             runFusion(key, stack);
             return 0;
           };
         }),
     Operator(
         "prim::rangelist(int n) -> int[]",
         [](Stack& stack) {
           int64_t n;
           pop(stack, n);
           // Build [0, 1, ..., n-1].
           // NOTE(review): the loop index is `int` while n is int64_t;
           // presumably n is never > INT_MAX in practice — verify.
           std::vector<int64_t> elems(n);
           for (int i = 0; i < n; i++) {
             elems[i] = i;
           }
           push(stack, jit::IntList::create(elems));
           return 0;
         }),
     // bool(tensor): true iff the 0-dim tensor's value is nonzero.
     Operator(
         "prim::Bool(Tensor a) -> bool",
         [](Stack& stack) {
           at::Tensor a;
           pop(stack, a);
           push(stack, a.item<int64_t>() != 0);
           return 0;
         }),
     Operator(
         "prim::Bool(int a) -> bool",
         [](Stack& stack) {
           int64_t i;
           pop(stack, i);
           push(stack, (bool)i);
           return 0;
         }),
     Operator(
         "prim::Bool(float a) -> bool",
         [](Stack& stack) {
           double d;
           pop(stack, d);
           push(stack, (bool)d);
           return 0;
         }),
     Operator(
         "prim::Int(Tensor a) -> int",
         [](Stack& stack) {
           at::Tensor a;
           pop(stack, a);
           push(stack, a.item<int64_t>());
           return 0;
         }),
     Operator(
         "prim::Float(Tensor a) -> float",
         [](Stack& stack) {
           at::Tensor a;
           pop(stack, a);
           push(stack, a.item<double>());
           return 0;
         }),
     // Implicit tensor -> number conversion; the static output type of the
     // node decides whether an int or a float is produced, and
     // checkImplicitTensorToNum enforces the python_arg_parser rules.
     Operator(
         "prim::ImplicitTensorToNum(Tensor a) -> Scalar",
         [](const Node* node) -> Operation {
           if (node->output()->type() == IntType::get()) {
             return [](Stack& stack) {
               at::Tensor a;
               pop(stack, a);
               checkImplicitTensorToNum(a, /*to int*/ true);
               push(stack, a.item<int64_t>());
               return 0;
             };
           } else {
             return [](Stack& stack) {
               at::Tensor a;
               pop(stack, a);
               checkImplicitTensorToNum(a, /*to int*/ false);
               push(stack, a.item<double>());
               return 0;
             };
           }
         }),
     Operator(
         "prim::NumToTensor(Scalar a) -> Tensor",
         [](Stack& stack) {
           at::Scalar s;
           pop(stack, s);
           // Wrap as a variable so the result is usable in autograd graphs.
           push(stack, autograd::make_variable(at::scalar_to_tensor(s)));
           return 0;
         }),
     // note: this op needs to share a name with the Scalar -> Tensor
     // conversion because all _to_tensor conversions have to have the same
     // operator name
     Operator(
         "prim::NumToTensor(bool a) -> Tensor",
         [](Stack& stack) {
           bool b;
           pop(stack, b);
           push(stack, autograd::make_variable(at::scalar_to_tensor(b)));
           return 0;
         }),
     Operator(
         "prim::Float(Scalar a) -> float",
         [](Stack& stack) {
           IValue scalar;
           pop(stack, scalar);
           if (scalar.isDouble()) {
             // Already a float: pass through unchanged.
             push(stack, scalar);
           } else {
             push(stack, static_cast<double>(scalar.toInt()));
           }
           return 0;
         }),
     Operator(
         "prim::Float(int a) -> float",
         [](Stack& stack) {
           int64_t i;
           pop(stack, i);
           // NOTE(review): casting through `float` loses precision for
           // |i| > 2^24 even though the IValue stores a double — confirm
           // this narrowing is intentional.
           push(stack, (float)i);
           return 0;
         }),
     Operator(
         "prim::Int(float a) -> int",
         [](Stack& stack) {
           double d;
           pop(stack, d);
           // Truncates toward zero, like Python's int(float).
           push(stack, (int64_t)d);
           return 0;
         }),
     Operator(
         "prim::Float(bool a) -> float",
         [](Stack& stack) {
           bool b;
           pop(stack, b);
           push(stack, (float)b);
           return 0;
         }),
     Operator(
         "prim::Int(bool a) -> int",
         [](Stack& stack) {
           bool b;
           pop(stack, b);
           push(stack, (int)b);
           return 0;
         }),
258  Operator(
259  "prim::Int(Scalar a) -> float",
260  [](Stack& stack) {
261  IValue scalar;
262  pop(stack, scalar);
263  if (scalar.isInt()) {
264  push(stack, scalar);
265  } else {
266  push(stack, static_cast<int64_t>(scalar.toDouble()));
267  }
268  return 0;
269  }),
     // float(str): only the special values "inf" / "-inf" are accepted;
     // every other string is an error.
     Operator(
         "prim::Float(str a) -> float",
         [](Stack& stack) {
           auto s = pop(stack).toString();
           if (s->string() == "inf")
             push(stack, std::numeric_limits<double>::infinity());
           else if (s->string() == "-inf")
             push(stack, -std::numeric_limits<double>::infinity());
           else
             AT_ERROR(
                 "Only 'inf' or '-inf' can be cast to a float, but got '",
                 s->string(),
                 "'");
           return 0;
         }),
     // Parse a device string such as "cuda:0" into a Device object.
     Operator(
         "aten::device(str a) -> Device",
         [](Stack& stack) {
           push(stack, c10::Device(pop(stack).toStringRef()));
           return 0;
         }),
291  // reference function parse_to_conversion in python_arg_parsing.h
292  Operator(
293  "aten::to(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
294  [](Stack& stack) {
295  bool non_blocking;
296  bool copy;
297  pop(stack, non_blocking, copy);
298  c10::optional<at::ScalarType> scalarType =
299  pop(stack).toOptional<at::ScalarType>();
301  pop(stack).toOptional<c10::Device>();
302  at::Tensor self = pop(stack).toTensor();
303  push(
304  stack,
305  to_dispatch(self, device, scalarType, non_blocking, copy));
306  return 0;
307  }),
     // to(dtype) overload: no device argument, so device stays nullopt.
     Operator(
         "aten::to(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
         [](Stack& stack) {
           bool non_blocking;
           bool copy;
           pop(stack, non_blocking, copy);
           c10::optional<at::ScalarType> scalarType =
               pop(stack).toOptional<at::ScalarType>();
           c10::optional<c10::Device> device = c10::nullopt;
           at::Tensor self = pop(stack).toTensor();
           push(
               stack,
               to_dispatch(self, device, scalarType, non_blocking, copy));
           return 0;
         }),
     // to() overload with neither device nor dtype.
     Operator(
         "aten::to(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
         [](Stack& stack) {
           at::Tensor self;
           bool non_blocking;
           bool copy;
           pop(stack, self, non_blocking, copy);
           c10::optional<c10::Device> device = c10::nullopt;
           c10::optional<at::ScalarType> scalarType = c10::nullopt;
           push(
               stack,
               to_dispatch(self, device, scalarType, non_blocking, copy));
           return 0;
         }),
     Operator(
         "aten::eq(Device a, Device b) -> bool",
         [](Stack& stack) {
           // The first pop yields the second argument; harmless here since
           // equality is symmetric.
           auto a = pop(stack).toDevice();
           auto b = pop(stack).toDevice();
           push(stack, a == b);
           return 0;
         }),
     // Tensor property accessors: each pops one tensor and pushes the
     // requested attribute.
     Operator(
         "prim::device(Tensor a) -> Device",
         [](Stack& stack) {
           push(stack, pop(stack).toTensor().device());
           return 0;
         }),
     Operator(
         "prim::dtype(Tensor a) -> int",
         [](Stack& stack) {
           at::Tensor a;
           pop(stack, a);
           // Expose the ScalarType enum as its integer value.
           push(stack, static_cast<int64_t>(a.scalar_type()));
           return 0;
         }),
     Operator(
         "prim::requires_grad(Tensor a) -> bool",
         [](Stack& stack) {
           at::Tensor a;
           pop(stack, a);
           push(stack, a.requires_grad());
           return 0;
         }),
     Operator(
         "prim::shape(Tensor a) -> int[]",
         [](Stack& stack) {
           at::Tensor a;
           pop(stack, a);
           push(stack, a.sizes());
           return 0;
         }),
     Operator(
         "prim::is_cuda(Tensor a) -> bool",
         [](Stack& stack) {
           at::Tensor a;
           pop(stack, a);
           push(stack, a.is_cuda());
           return 0;
         }),
     Operator(
         "aten::cpu(Tensor(a) self) -> Tensor(a|b)",
         [](Stack& stack) {
           at::Tensor a;
           pop(stack, a);
           push(stack, a.cpu());
           return 0;
         }),
     Operator(
         "aten::cuda(Tensor(a) self) -> Tensor(a|b)",
         [](Stack& stack) {
           at::Tensor a;
           pop(stack, a);
           push(stack, a.cuda());
           return 0;
         }),
     // Pushes an undefined tensor, which stands in for an autograd zero.
     Operator(
         "prim::AutogradZero() -> Tensor",
         [](const Node* node) {
           return [](Stack& stack) {
             stack.emplace_back(at::Tensor());
             return 0;
           };
         }),
     Operator(
         prim::Print,
         [](const Node* node) {
           size_t num_inputs = node->inputs().size();
           return [num_inputs](Stack& stack) {
             bool first = true;
             // Print the inputs space-separated, like Python's print().
             for (const IValue& i : last(stack, num_inputs)) {
               if (!first)
                 std::cout << " ";
               first = false;
               std::cout << i;
             }
             drop(stack, num_inputs);
             std::cout << std::endl;
             return 0;
           };
         }),
     // Computes the broadcast of all input shapes (each an int list).
     Operator(
         prim::BroadcastSizes,
         [](const Node* node) -> Operation {
           size_t num_inputs = node->inputs().size();
           return [num_inputs](Stack& stack) {
             std::vector<int64_t> size;
             size.reserve(8);
             for (size_t i = 0; i < num_inputs; ++i) {
               size = at::infer_size(
                   size, peek(stack, i, num_inputs).toIntList()->elements());
             }
             drop(stack, num_inputs);
             push(stack, std::move(size));
             return 0;
           };
         }),
     // Given an input shape, computes the shape of the regular chunks and of
     // the (possibly smaller) final chunk produced by chunk(chunks, dim).
     Operator(
         prim::ChunkSizes,
         [](const Node* node) -> Operation {
           int64_t raw_dim = node->i(attr::dim);
           int64_t chunks = node->i(attr::chunks);
           return [raw_dim, chunks](Stack& stack) {
             Shared<IntList> sizes_l;
             pop(stack, sizes_l);
             const auto& shape = sizes_l->elements();
             std::vector<int64_t> regular_shape = shape;
             std::vector<int64_t> last_shape = shape;
             // Normalize a possibly-negative dim.
             int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size());
             AT_CHECK(
                 dim < (int64_t)regular_shape.size(),
                 "Dimension out of range for chunk");
             // Ceiling division: size of each full chunk.
             int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks;
             regular_shape[dim] = split_size;
             if (shape[dim] % chunks == 0) {
               // Even split: the last chunk is full-sized too.
               last_shape[dim] = split_size;
             } else {
               int64_t num_splits = std::max<int64_t>(
                   (shape[dim] + split_size - 1) / split_size, 1);
               last_shape[dim] =
                   split_size - (split_size * num_splits - shape[dim]);
               AT_ASSERT(last_shape[dim] >= 0);
             }
             push(stack, std::move(regular_shape));
             push(stack, std::move(last_shape));
             return 0;
           };
         }),
     // aten::warn(message, stacklevel=2): emits a warning. The stacklevel
     // argument (top of stack) is dropped — only the message is used here.
     Operator(
         FunctionSchema(
             "aten::warn",
             "",
             {Argument("message", StringType::get()),
              Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)},
             {}),
         [](const Node* node) {
           return [](Stack& stack) {
             drop(stack, 1);
             AT_WARN(pop(stack).toStringRef());
             return 0;
           };
         }),
     // raise: surfaces a script-level exception to the caller.
     Operator(
         "prim::RaiseException(str msg) -> ()",
         [](Stack& stack) {
           throw JITException(pop(stack).toStringRef());
           return 0;
         }),

     Operator(
         "prim::IgnoredPythonOp(...) -> ()",
         [](Stack& stack) {
           throw JITException(
               "This Python function is annotated to be ignored"
               " and cannot be and has not been included in the exported"
               " binary, meaning that it cannot be executed now."
               " Make sure that ignored operations are never executed after"
               " import");
           return 0;
         }),
     // Load x, y
     // loads values from registers onto the stack, the actual callback does
     // nothing since the stack manipulation is already encoded in inst.inputs
     // and inst.outputs
     Operator(prim::Load, noop),
     // x, y = Store
     // stores values from stack into registers, the actual callback does
     // nothing since the stack manipulation is already encoded in inst.inputs
     // and inst.outputs
     Operator(prim::Store, noop),
     // Discards as many values as the node has inputs.
     Operator(
         prim::Drop,
         [](const Node* node) {
           auto N = node->inputs().size();
           return [=](Stack& stack) {
             drop(stack, N);
             return 0;
           };
         }),
     // ONNX Reshape: the target shape arrives as a 1-D int64 tensor.
     Operator(
         c10::onnx::Reshape,
         [](const Node* node) {
           return [=](Stack& stack) {
             at::Tensor input, shape;
             pop(stack, input, shape);
             // contiguous() so data<int64_t>() walks the values in order.
             shape = shape.contiguous();
             AT_ASSERT(shape.ndimension() == 1);
             at::IntArrayRef shape_list(shape.data<int64_t>(), shape.size(0));
             push(stack, input.reshape(shape_list));
             return 0;
           };
         }),
     // ONNX Shape: returns the input's sizes as a 1-D int64 tensor.
     Operator(
         c10::onnx::Shape,
         [](const Node* node) {
           return [=](Stack& stack) {
             auto t = pop(stack).toTensor();
             at::IntArrayRef sizes = t.sizes();
             auto sizes_tensor = torch::empty(
                 {static_cast<int64_t>(sizes.size())}, at::dtype(at::kLong));
             auto accessor = sizes_tensor.accessor<int64_t, 1>();
             for (size_t i = 0; i < sizes.size(); ++i) {
               accessor[i] = sizes[i];
             }
             stack.emplace_back(sizes_tensor);
             return 0;
           };
         }),
     // True iff any of the input tensors is defined (non-autograd-zero).
     Operator(
         prim::AutogradAnyNonZero,
         [](const Node* node) {
           size_t num_inputs = node->inputs().size();
           return [=](Stack& stack) {
             bool result = false;
             for (const IValue& t : last(stack, num_inputs)) {
               if (t.toTensor().defined()) {
                 result = true;
                 break;
               }
             }
             drop(stack, num_inputs);
             stack.emplace_back(result);
             return 0;
           };
         }),
     // Gradient accumulation: an undefined tensor acts as an additive zero,
     // so addition is skipped when either operand is undefined.
     Operator(
         prim::AutogradAdd,
         [](const Node* node) {
           return [=](Stack& stack) {
             at::Tensor a, b;
             pop(stack, a, b);
             if (!a.defined())
               stack.emplace_back(b);
             else if (!b.defined())
               stack.emplace_back(a);
             else
               stack.emplace_back(a + b);
             return 0;
           };
         }),
     // Sums a gradient down to the given (broadcast-source) size.
     Operator(
         "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)",
         [](Stack& stack) {
           at::Tensor self;
           Shared<IntList> desired_sizes;
           pop(stack, self, desired_sizes);
           push(stack, at::sum_to(std::move(self), desired_sizes->elements()));
           return 0;
         }),
     // Spreads a tuple's elements onto the stack; errors if the runtime
     // tuple arity differs from the node's static output count.
     Operator(
         prim::TupleUnpack,
         [](const Node* node) {
           size_t num_elems = node->outputs().size();
           return [=](Stack& stack) {
             auto t = pop(stack).toTuple();
             const auto& elems = t->elements();
             if (elems.size() != num_elems) {
               AT_ERROR(
                   "Expected a tuple of ",
                   num_elems,
                   " elements, but got ",
                   elems.size());
             }
             stack.insert(stack.end(), elems.begin(), elems.end());
             return 0;
           };
         }),
     // tuple[beg:end] with bounds fixed at compile time.
     Operator(
         prim::TupleSlice,
         [](const Node* node) {
           int64_t beg_ind = node->i(attr::beg);
           int64_t end_ind = node->i(attr::end);
           return [=](Stack& stack) {
             auto t = pop(stack).toTuple();
             const auto& elems = t->elements();
             std::vector<IValue> output_elems;
             for (int64_t i = beg_ind; i < end_ind; ++i) {
               output_elems.emplace_back(elems.at(i));
             }
             push(stack, Tuple::create(std::move(output_elems)));
             return 0;
           };
         }),
     Operator(
         prim::TupleIndex,
         [](const Node* node) {
           auto index = node->i(attr::index);
           return [=](Stack& stack) {
             auto tup = pop(stack).toTuple();
             const auto& elems = tup->elements();
             // index is normalized to be positive at compile time
             stack.emplace_back(elems.at(index));
             return 0;
           };
         }),
     // Gathers the top num_inputs stack values into a tuple (moved, not
     // copied).
     Operator(
         prim::TupleConstruct,
         [](const Node* node) {
           size_t num_inputs = node->inputs().size();
           return [=](Stack& stack) {
             std::vector<IValue> elems{
                 std::make_move_iterator(stack.end() - num_inputs),
                 std::make_move_iterator(stack.end())};
             drop(stack, num_inputs);
             push(stack, Tuple::create(std::move(elems)));
             return 0;
           };
         }),
     // chunk() with statically-known chunk count and dim. at::chunk may
     // return fewer pieces than requested; that is only legal when the
     // missing outputs are unused, in which case placeholders are pushed.
     Operator(
         prim::ConstantChunk,
         [](const Node* node) {
           int64_t chunks = node->i(attr::chunks);
           int64_t dim = node->i(attr::dim);
           // Precompute which outputs are actually consumed downstream.
           auto outputs_used = fmap(node->outputs(), [](const Value* v) {
             return v->uses().size() > 0;
           });
           return [=](Stack& stack) {
             autograd::profiler::RecordFunction record("chunk");
             at::Tensor t;
             pop(stack, t);
             auto result = at::chunk(t, chunks, dim);
             stack.insert(
                 stack.end(),
                 std::make_move_iterator(result.begin()),
                 std::make_move_iterator(result.end()));
             // NB: Chunk can sometimes return a smaller number of outputs.
             int64_t num_results = result.size();
             if (num_results != chunks) {
               if (num_results > chunks) {
                 // More results than requested is always an error.
                 AT_CHECK(
                     num_results == chunks,
                     "Expected chunk to return ",
                     chunks,
                     " outputs, but got ",
                     num_results);
               }
               for (int64_t i = num_results; i < chunks; ++i) {
                 AT_CHECK(
                     !outputs_used[i],
                     "Expected chunk to return at least ",
                     chunks,
                     " outputs, but got only ",
                     num_results);
                 // We know that the output is unused, so it's ok to push
                 // anything on the stack.
                 stack.emplace_back();
               }
             }
             return 0;
           };
         }),
     // Spreads a list's elements onto the stack. The list's element type is
     // known statically, so the right concrete unpacker is chosen up front;
     // each variant checks the runtime length against the node's output
     // count.
     Operator(
         prim::ListUnpack,
         [](const Node* node) -> Operation {
           const auto num_outputs = node->outputs().size();
           ListTypePtr lt = node->input()->type()->expect<ListType>();
           if (lt->getElementType() == IntType::get()) {
             return [=](Stack& stack) {
               auto ilist = pop(stack);
               const auto& list = ilist.toIntList()->elements();
               AT_CHECK(
                   list.size() == num_outputs,
                   "Expected ",
                   num_outputs,
                   " elements in a list but found ",
                   list.size());
               stack.insert(stack.end(), list.begin(), list.end());
               return 0;
             };
           } else if (lt->getElementType() == FloatType::get()) {
             return [=](Stack& stack) {
               auto ilist = pop(stack);
               const auto& list = ilist.toDoubleList()->elements();
               AT_CHECK(
                   list.size() == num_outputs,
                   "Expected ",
                   num_outputs,
                   " elements in a list but found ",
                   list.size());
               stack.insert(stack.end(), list.begin(), list.end());
               return 0;
             };
           } else if (lt->getElementType() == TensorType::get()) {
             return [=](Stack& stack) {
               auto ilist = pop(stack);
               const auto& list = ilist.toTensorList()->elements();
               AT_CHECK(
                   list.size() == num_outputs,
                   "Expected ",
                   num_outputs,
                   " elements in a list but found ",
                   list.size());
               stack.insert(stack.end(), list.begin(), list.end());
               return 0;
             };
           } else {
             // Fallback for every other element type: generic IValue list.
             return [=](Stack& stack) {
               auto glist = pop(stack);
               const auto& list = glist.toGenericList()->elements();
               AT_CHECK(
                   list.size() == num_outputs,
                   "Expected ",
                   num_outputs,
                   " elements in a list but found ",
                   list.size());
               stack.insert(stack.end(), list.begin(), list.end());
               return 0;
             };
           }
         }),
     // Builds a list from the top num_inputs stack values; the statically
     // known element type selects a specialized constructor.
     Operator(
         prim::ListConstruct,
         [](const Node* node) -> Operation {
           const auto num_inputs = node->inputs().size();
           ListTypePtr lt = node->output()->type()->expect<ListType>();
           if (IntType::get() == lt->getElementType()) {
             return listConstruct<int64_t>(num_inputs);
           } else if (FloatType::get() == lt->getElementType()) {
             return listConstruct<double>(num_inputs);
           } else if (lt->getElementType() == BoolType::get()) {
             return listConstruct<bool>(num_inputs);
           } else if (lt->getElementType()->isSubtypeOf(TensorType::get())) {
             return [=](Stack& stack) {
               const size_t stack_size = stack.size();
               std::vector<at::Tensor> vals;
               vals.reserve(num_inputs);
               // Move the tensors out of the stack rather than copying.
               for (size_t i = stack_size - num_inputs; i < stack_size; ++i) {
                 vals.emplace_back(std::move(stack[i]).toTensor());
               }
               drop(stack, num_inputs);
               push(stack, std::move(vals));
               return 0;
             };
           } else {
             // Generic IValue list for all remaining element types.
             return [=](Stack& stack) {
               const size_t stack_size = stack.size();
               std::vector<IValue> vals;
               vals.reserve(num_inputs);
               for (size_t i = stack_size - num_inputs; i < stack_size; ++i) {
                 vals.emplace_back(std::move(stack[i]));
               }
               drop(stack, num_inputs);
               push(stack, std::move(vals));
               return 0;
             };
           }
         }),
     // Builds a dict from alternating key/value inputs; inputs arrive on the
     // stack as key1, value1, key2, value2, ... so each value is popped
     // before its key.
     Operator(
         prim::DictConstruct,
         [](const Node* node) -> Operation {
           const auto num_inputs = node->inputs().size();
           if (num_inputs % 2 != 0) {
             throw std::runtime_error(
                 "DictConstruct must have an even number of inputs");
           }
           return [=](Stack& stack) {
             c10::ivalue::UnorderedMap vals;
             for (size_t i = 0; i < num_inputs; i += 2) {
               auto val = pop(stack);
               auto key = pop(stack);
               vals[key] = val;
             }
             push(stack, std::move(vals));
             return 0;
           };
         }),
     // Asserts the optional is non-None and passes the value through.
     Operator(
         "aten::_unwrap_optional(t(a)? optional) -> t(a)",
         [](Stack& stack) {
           auto val = pop(stack);
           AT_CHECK(!val.isNone(), "Unwrapping null optional");
           push(stack, val);
           return 0;
         }),
     // This op can be removed in preprocessing before being run in the
     // interpreter (but is currently not removed), even when it is removed it
     // needs to remain a registered op so that constant prop can run.
     Operator("prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)", noop),
     // fork: runs the subgraph asynchronously on the global work queue and
     // immediately pushes a Future for its result.
     Operator(
         prim::fork,
         [](const Node* node) {
           // Compile the forked subgraph once, at op-creation time.
           Code code(node->g(attr::Subgraph));
           int n_inputs = node->inputs().size();
           AT_ASSERT(node->blocks().size() == 0);
           AT_ASSERT(node->hasAttribute(attr::Subgraph));
           return [=](Stack& stack) {
             // Move inputs to a separate stack
             InterpreterState forked_interprester(code);
             InterpreterContinuation continuation(
                 forked_interprester,
                 Stack(stack.end() - n_inputs, stack.end()),
                 autograd::GradMode::is_enabled());
             drop(stack, n_inputs);

             push(stack, forked_interprester.getFuture());

             c10::global_work_queue().run(std::move(continuation));
             return 0;
           };
         }),
     // wait: pushes the future's value if ready, otherwise suspends the
     // interpreter until it completes.
     Operator(
         "aten::wait(Future(t) self) -> t",
         [](Stack& stack) {
           auto future = pop(stack).toFuture();
           if (future->completed()) {
             push(stack, future->value());
           } else {
             throw Suspend(future);
           }
           return 0;
         }),
     // Allocates a new script-class instance with room for its attributes.
     Operator(
         prim::CreateObject,
         [](const Node* node) {
           const auto type = node->output()->type()->expect<ClassType>();
           const auto name = Symbol::user(type->name());
           const size_t numAttrs = type->numAttributes();
           return [name, numAttrs](Stack& stack) {
             auto userObj = c10::ivalue::Object::create(name, numAttrs);
             push(stack, std::move(userObj));
             return 0;
           };
         }),
     // Attribute read: the slot index is resolved from the static class type
     // at op-creation time, so the runtime lookup is just an array access.
     Operator(
         prim::GetAttr,
         [](const Node* node) {
           const auto type = node->input()->type()->expect<ClassType>();
           const auto& field = node->s(attr::name);
           const auto slot = type->getAttributeSlot(field);
           return [slot](Stack& stack) {
             auto userObj = pop(stack).toObject();
             auto value = userObj->getSlot(slot);
             push(stack, std::move(value));
             return 0;
           };
         }),
     // Attribute write, mirroring GetAttr's slot resolution.
     Operator(prim::SetAttr, [](const Node* node) {
       const auto type = node->inputs().at(0)->type()->expect<ClassType>();
       const auto& field = node->s(attr::name);
       const auto slot = type->getAttributeSlot(field);
       return [slot](Stack& stack) {
         auto v = pop(stack);
         auto userObj = pop(stack).toObject();
         userObj->setSlot(slot, std::move(v));
         return 0;
       };
     })});
891 
// define implementations for primitive number ops

// Registers an (int, int) and a (float, float) overload of aten_op. The
// result types are passed separately so comparison ops can return bool.
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
  Operator(                                                                    \
      #aten_op "(int a, int b) -> " #int_result,                               \
      [](Stack& stack) {                                                       \
        int64_t a, b;                                                          \
        pop(stack, a, b);                                                      \
        push(stack, int_op);                                                   \
        return 0;                                                              \
      }),                                                                      \
      Operator(                                                                \
          #aten_op "(float a, float b) -> " #float_result, [](Stack& stack) {  \
            double a, b;                                                       \
            pop(stack, a, b);                                                  \
            push(stack, float_op);                                             \
            return 0;                                                          \
          })

// Registers the two mixed int/float overloads of aten_op.
#define DEFINE_INT_FLOAT_OP(aten_op, op, result)                        \
  Operator(                                                             \
      #aten_op "(int a, float b) -> " #result,                          \
      [](Stack& stack) {                                                \
        int64_t a;                                                      \
        double b;                                                       \
        pop(stack, a, b);                                               \
        push(stack, op);                                                \
        return 0;                                                       \
      }),                                                               \
      Operator(#aten_op "(float a, int b) -> " #result, [](Stack& stack) { \
        double a;                                                       \
        int64_t b;                                                      \
        pop(stack, a, b);                                               \
        push(stack, op);                                                \
        return 0;                                                       \
      })

// Registers an int-only binary op (e.g. bitwise / modular arithmetic).
#define DEFINE_INT_OP(aten_op, op)                                  \
  Operator(#aten_op "(int a, int b) -> int", [](Stack& stack) {     \
    int64_t a, b;                                                   \
    pop(stack, a, b);                                               \
    push(stack, op); /* NOLINT(hicpp-signed-bitwise) */             \
    return 0;                                                       \
  })

// Arithmetic op: all four numeric overloads, numeric result types.
#define DEFINE_BINARY_OP(aten_op, op)             \
  DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
      DEFINE_INT_FLOAT_OP(aten_op, op, float)
// Comparison op: all four numeric overloads, bool result.
#define DEFINE_COMPARISON_OP(aten_op, op)         \
  DEFINE_GENERIC_OP(aten_op, op, op, bool, bool), \
      DEFINE_INT_FLOAT_OP(aten_op, op, bool)
// Boolean logic op: (bool, bool) -> bool only.
#define DEFINE_BOOL_OP(aten_op, op)                              \
  Operator(#aten_op "(bool a, bool b) -> bool", [](Stack& stack) { \
    bool a, b;                                                   \
    pop(stack, a, b);                                            \
    push(stack, op);                                             \
    return 0;                                                    \
  })
949 
// Convert a Python-style index (which may be negative) into an index usable
// for a C++ container. Negative indices count from the end; the caller is
// responsible for range-checking the result.
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
  return idx < 0 ? list_size + idx : idx;
}
959 
// Equivalent to list.at(idx)
// Accepts Python-style negative indices; throws std::out_of_range when the
// normalized index falls outside [0, size).
template <typename TList> // something like Shared<IntList>
typename TList::element_type::ElemType& getItem(TList& list, int64_t idx) {
  const int64_t list_size = list->elements().size();
  const int64_t normalized_idx = normalizeIndex(idx, list_size);
  if (normalized_idx < 0 || normalized_idx >= list_size) {
    throw std::out_of_range("list index out of range");
  }
  return list->elements()[normalized_idx];
}
970 
971 // cannot return a reference to an element in a bool vector
972 bool getBoolItem(const std::vector<bool>& list, int64_t idx) {
973  const int64_t list_size = list.size();
974  const int64_t normalized_idx = normalizeIndex(idx, list_size);
975  if (normalized_idx < 0 || normalized_idx >= list_size) {
976  throw std::out_of_range("list index out of range");
977  }
978  return list[normalized_idx];
979 }
980 
981 template <typename TList, typename TElement>
982 int listAppend(Stack& stack) {
983  TList a;
984  TElement el;
985  pop(stack, a, el);
986 
987  a->elements().push_back(el);
988  push(stack, a);
989 
990  return 0;
991 }
992 
993 template <typename TList>
994 int listReverse(Stack& stack) {
995  TList a;
996  pop(stack, a);
997 
998  auto& elements = a->elements();
999  std::reverse(elements.begin(), elements.end());
1000 
1001  return 0;
1002 }
1003 
// list.pop(idx): pushes the element at idx (negative indices allowed) and
// removes it from the list. Errors on an empty list; getItem performs the
// range check.
template <typename TList>
int listPop(Stack& stack) {
  TList list;
  int64_t idx;
  pop(stack, list, idx);

  auto& elements = list->elements();
  const int64_t list_size = elements.size();
  const int64_t normalized_idx = normalizeIndex(idx, list_size);

  if (list_size == 0) {
    AT_ERROR("pop from empty list");
  }

  // Push the value before erasing so the reference from getItem (which
  // re-normalizes and range-checks idx) is still valid when read.
  push(stack, std::move(getItem(list, idx)));
  elements.erase(elements.begin() + normalized_idx);

  return 0;
}
1023 
// Specialization for bool lists: std::vector<bool> cannot yield references,
// so the value is fetched by copy via getBoolItem instead of getItem.
template <>
int listPop<Shared<BoolList>>(Stack& stack) {
  Shared<BoolList> list;
  int64_t idx;
  pop(stack, list, idx);

  auto& elements = list->elements();
  const int64_t list_size = elements.size();
  const int64_t normalized_idx = normalizeIndex(idx, list_size);

  if (list_size == 0) {
    AT_ERROR("pop from empty list");
  }

  // getBoolItem range-checks idx and returns the value by copy.
  push(stack, getBoolItem(elements, idx));
  elements.erase(elements.begin() + normalized_idx);

  return 0;
}
1043 
1044 template <typename TList>
1045 int listClear(Stack& stack) {
1046  TList a;
1047  pop(stack, a);
1048 
1049  a->elements().clear();
1050  return 0;
1051 }
1052 
// list.insert(idx, elem): inserts elem before position idx. Mirrors Python
// semantics — an out-of-range index is clamped (prepends when far negative,
// appends when past the end) rather than raising.
template <typename TList, typename TElement>
int listInsert(Stack& stack) {
  TList list;
  int64_t idx;
  TElement elem;
  pop(stack, list, idx, elem);

  auto& elements = list->elements();
  const int64_t list_size = elements.size();
  const int64_t normalized_idx = normalizeIndex(idx, list_size);

  if (normalized_idx < 0 || normalized_idx >= list_size) {
    if (normalized_idx < 0) {
      // Still negative after normalization: clamp to the front.
      elements.insert(elements.begin(), elem);
    } else {
      // Past the end: clamp to the back.
      elements.push_back(elem);
    }
  } else {
    elements.insert(elements.begin() + normalized_idx, elem);
  }

  return 0;
}
1076 
// list.remove(elem): erases the first element equal to elem; errors when no
// match exists, matching Python's ValueError behavior.
template <typename TList, typename TElement>
int listRemove(Stack& stack) {
  TList list;
  TElement elem;
  pop(stack, list, elem);

  auto& elements = list->elements();
  auto pos = std::find(elements.begin(), elements.end(), elem);

  if (pos != elements.end()) {
    elements.erase(pos);
  } else {
    AT_ERROR("list.remove(x): x not in list");
  }

  return 0;
}
1094 
// Tensor specialization of list.remove: tensors are matched by element-wise
// equality (eq + is_nonzero) rather than operator==.
// NOTE(review): is_nonzero presumably errors for tensors with more than one
// element, so this matches 0-dim/1-element tensors only — confirm.
template <>
int listRemove<Shared<TensorList>, at::Tensor>(Stack& stack) {
  Shared<TensorList> list;
  at::Tensor elem;
  pop(stack, list, elem);

  auto& elements = list->elements();
  auto pos = std::find_if(
      elements.begin(), elements.end(), [elem](const at::Tensor& b) {
        const auto cmp_result = elem.eq(b);
        return cmp_result.is_nonzero();
      });

  if (pos != elements.end()) {
    elements.erase(pos);
  } else {
    AT_ERROR("list.remove(x): x not in list");
  }

  return 0;
}
1116 
1117 template <typename TList, typename TElement>
1118 int listIndex(Stack& stack) {
1119  TList list;
1120  TElement elem;
1121  pop(stack, list, elem);
1122 
1123  auto& elements = list->elements();
1124  auto pos = std::find(elements.begin(), elements.end(), elem);
1125 
1126  if (pos != elements.end()) {
1127  push(stack, static_cast<int64_t>(std::distance(elements.begin(), pos)));
1128  } else {
1129  AT_ERROR("'", elem, "' is not in list");
1130  }
1131 
1132  return 0;
1133 }
1134 
1135 template <>
1136 int listIndex<Shared<TensorList>, at::Tensor>(Stack& stack) {
1137  Shared<TensorList> list;
1138  at::Tensor elem;
1139  pop(stack, list, elem);
1140 
1141  auto& elements = list->elements();
1142  auto pos = std::find_if(
1143  elements.begin(), elements.end(), [elem](const at::Tensor& b) {
1144  const auto cmp_result = elem.eq(b);
1145  return cmp_result.is_nonzero();
1146  });
1147 
1148  if (pos != elements.end()) {
1149  push(stack, static_cast<int64_t>(std::distance(elements.begin(), pos)));
1150  } else {
1151  AT_ERROR("'", elem, "' is not in list");
1152  }
1153 
1154  return 0;
1155 }
1156 
1157 template <typename TList, typename TElement>
1158 int listCount(Stack& stack) {
1159  TList list;
1160  TElement elem;
1161  pop(stack, list, elem);
1162 
1163  auto& elements = list->elements();
1164  const int64_t count = std::count(elements.begin(), elements.end(), elem);
1165  push(stack, count);
1166 
1167  return 0;
1168 }
1169 
1170 template <>
1171 int listCount<Shared<TensorList>, at::Tensor>(Stack& stack) {
1172  Shared<TensorList> list;
1173  at::Tensor elem;
1174  pop(stack, list, elem);
1175 
1176  auto& elements = list->elements();
1177  const int64_t count = std::count_if(
1178  elements.begin(), elements.end(), [elem](const at::Tensor& b) {
1179  const auto cmp_result = elem.eq(b);
1180  return cmp_result.is_nonzero();
1181  });
1182  push(stack, count);
1183 
1184  return 0;
1185 }
1186 
1187 template <typename TList>
1188 Operation listExtend(const Node* node) {
1189  return [](Stack& stack) {
1190  TList a;
1191  TList b;
1192  pop(stack, a, b);
1193 
1194  auto& vec_a = a->elements();
1195  const auto& vec_b = b->elements();
1196  vec_a.insert(vec_a.end(), vec_b.cbegin(), vec_b.cend());
1197  return 0;
1198  };
1199 }
1200 
1201 template <typename TList>
1202 Operation listCopy(const Node* node) {
1203  return [](Stack& stack) {
1204  TList list;
1205  pop(stack, list);
1206 
1207  const auto& vec = list->elements();
1208  auto out = vec;
1209  push(stack, out);
1210  return 0;
1211  };
1212 }
1213 
1214 template <typename T>
1215 int listSelect(Stack& stack) {
1216  T list;
1217  int64_t idx;
1218  pop(stack, list, idx);
1219 
1220  auto element = getItem(list, idx);
1221  push(stack, std::move(element));
1222  return 0;
1223 }
1224 
1225 // needs specialization because cannot return a pointer to a bool in an array
1226 template <>
1227 int listSelect<Shared<BoolList>>(Stack& stack) {
1228  Shared<BoolList> list;
1229  int64_t idx;
1230  pop(stack, list, idx);
1231 
1232  auto element = getBoolItem(list->elements(), idx);
1233  push(stack, element);
1234  return 0;
1235 }
1236 
1237 template <typename T>
1238 int listLen(Stack& stack) {
1239  T a;
1240  pop(stack, a);
1241 
1242  const int64_t size = a->elements().size();
1243  push(stack, size);
1244  return 0;
1245 }
1246 
1247 template <typename T>
1248 int listEq(Stack& stack) {
1249  T a;
1250  T b;
1251  pop(stack, a, b);
1252  push(stack, a->elements() == b->elements() ? true : false);
1253  return 0;
1254 }
1255 
1256 template <typename T>
1257 int listNe(Stack& stack) {
1258  T a;
1259  T b;
1260  pop(stack, a, b);
1261  push(stack, !(a->elements() == b->elements()));
1262  return 0;
1263 }
1264 
1265 inline bool tensor_list_equal(Shared<TensorList> a, Shared<TensorList> b) {
1266  if (a->elements().size() != b->elements().size()) {
1267  return false;
1268  }
1269 
1270  for (size_t i = 0; i < a->elements().size(); ++i) {
1271  const auto& a_element = a->elements()[i];
1272  const auto& b_element = b->elements()[i];
1273  // This preserves Python's semantics, which uses eq() to compare two
1274  // elements, then passes the result to bool().
1275  // see: https://docs.python.org/3.4/reference/datamodel.html#object.__ge__
1276  const auto cmp_result = a_element.eq(b_element);
1277  if (!cmp_result.is_nonzero()) {
1278  return false;
1279  }
1280  }
1281 
1282  return true;
1283 }
1284 
1285 // Specialization for at::Tensor, since it doesn't define operator==
1286 template <>
1287 int listEq<Shared<TensorList>>(Stack& stack) {
1288  Shared<TensorList> a;
1289  Shared<TensorList> b;
1290  pop(stack, a, b);
1291  push(stack, tensor_list_equal(a, b));
1292  return 0;
1293 }
1294 
1295 // Specialization for at::Tensor, since it doesn't define operator==
1296 template <>
1297 int listNe<Shared<TensorList>>(Stack& stack) {
1298  Shared<TensorList> a;
1299  Shared<TensorList> b;
1300  pop(stack, a, b);
1301  push(stack, !tensor_list_equal(a, b));
1302  return 0;
1303 }
1304 
1305 Operation listList(const Node* node) {
1306  return [=](Stack& stack) {
1307  // Intentional no-op, needed to match Python semantics for list(iterable),
1308  // but in JIT these will already be lists
1309  return 0;
1310  };
1311 }
1312 
1313 template <class TList, class TElement>
1314 int listAdd(Stack& stack) {
1315  TList a;
1316  TList b;
1317  pop(stack, a, b);
1318 
1319  std::vector<TElement> ret;
1320  const auto total_size = a->elements().size() + b->elements().size();
1321  ret.reserve(total_size);
1322  for (const auto& a_element : a->elements()) {
1323  ret.push_back(a_element);
1324  }
1325  for (const auto& b_element : b->elements()) {
1326  ret.push_back(b_element);
1327  }
1328 
1329  push(stack, ret);
1330  return 0;
1331 }
1332 
1333 template <class TList, class TElement>
1334 int listMulIntLeft(Stack& stack) {
1335  TList list;
1336  int64_t n;
1337  pop(stack, list, n);
1338 
1339  std::vector<TElement> ret;
1340  const auto size = list->elements().size() * n;
1341  ret.reserve(size);
1342 
1343  for (auto i = 0; i < n; i++) {
1344  for (const auto& e : list->elements()) {
1345  ret.push_back(e);
1346  }
1347  }
1348 
1349  push(stack, ret);
1350  return 0;
1351 }
1352 
1353 template <class TList, class TElement>
1354 int listMulIntRight(Stack& stack) {
1355  TList list;
1356  int64_t n;
1357  pop(stack, n, list);
1358 
1359  std::vector<TElement> ret;
1360  const auto size = list->elements().size() * n;
1361  ret.reserve(size);
1362 
1363  for (auto i = 0; i < n; i++) {
1364  for (const auto& e : list->elements()) {
1365  ret.push_back(e);
1366  }
1367  }
1368 
1369  push(stack, ret);
1370  return 0;
1371 }
1372 
1373 template <typename TList, typename TElement>
1374 int listSlice(Stack& stack) {
1375  TList list;
1376  int64_t start;
1377  int64_t end;
1378  int64_t step;
1379 
1380  pop(stack, list, start, end, step);
1381  const int64_t list_size = list->elements().size();
1382 
1383  // clamp start and end to the bounds of the list
1384  const auto normalized_start =
1385  std::max((int64_t)0, normalizeIndex(start, list_size));
1386  const auto normalized_end =
1387  std::min(list_size, normalizeIndex(end, list_size));
1388 
1389  std::vector<TElement> sliced_list;
1390  if (normalized_end <= normalized_start) {
1391  // early exit if the slice is trivially empty
1392  push(stack, sliced_list);
1393  return 0;
1394  }
1395 
1396  sliced_list.reserve(normalized_end - normalized_start);
1397 
1398  for (auto i = normalized_start; i < normalized_end;) {
1399  sliced_list.push_back(list->elements()[i]);
1400  i += step;
1401  }
1402 
1403  push(stack, sliced_list);
1404  return 0;
1405 }
1406 
1407 template <typename TList, typename TElement>
1408 int listSetItem(Stack& stack) {
1409  TList list;
1410  int64_t idx;
1411  TElement value;
1412 
1413  pop(stack, list, idx, value);
1414  getItem(list, idx) = value;
1415 
1416  push(stack, list);
1417  return 0;
1418 }
1419 
1420 template <>
1421 int listSetItem<Shared<BoolList>, bool>(Stack& stack) {
1422  Shared<BoolList> list;
1423  int64_t idx;
1424  bool value;
1425 
1426  pop(stack, list, idx, value);
1427 
1428  int64_t list_size = list->elements().size();
1429  auto normalized_idx = normalizeIndex(idx, list_size);
1430  if (normalized_idx < 0 || normalized_idx >= list_size) {
1431  throw std::out_of_range("list index out of range");
1432  }
1433  list->elements()[normalized_idx] = value;
1434 
1435  push(stack, list);
1436  return 0;
1437 }
1438 
1439 int dictSetItem(Stack& stack) {
1440  auto value = pop(stack);
1441  auto idx = pop(stack);
1442  auto& dict = pop(stack).toGenericDict()->elements();
1443  dict[idx] = value;
1444  push(stack, dict);
1445  return 0;
1446 }
1447 
1448 int dictLen(Stack& stack) {
1449  auto dict = pop(stack).toGenericDictRef();
1450  push(stack, int64_t(dict.size()));
1451  return 0;
1452 }
1453 
1454 int dictKeys(Stack& stack) {
1455  auto dict = pop(stack).toGenericDictRef();
1456  std::vector<IValue> keys;
1457  keys.reserve(dict.size());
1458  for (auto item : dict) {
1459  keys.push_back(item.first);
1460  }
1461  push(stack, IValue(keys));
1462  return 0;
1463 }
1464 
1465 int dictValues(Stack& stack) {
1466  auto dict = pop(stack).toGenericDictRef();
1467  std::vector<IValue> values;
1468  values.reserve(dict.size());
1469  for (auto item : dict) {
1470  values.push_back(item.second);
1471  }
1472  push(stack, IValue(values));
1473  return 0;
1474 }
1475 
1476 int dictIndex(Stack& stack) {
1477  auto index = pop(stack);
1478  auto dict = pop(stack).toGenericDict();
1479  const auto& elems = dict->elements();
1480  auto value = elems.find(index);
1481  if (value == elems.end()) {
1482  AT_ERROR("KeyError: '", index, "'");
1483  }
1484  push(stack, value->second);
1485  return 0;
1486 }
1487 
1488 int dictGet(Stack& stack) {
1489  auto index = pop(stack);
1490  auto dict = pop(stack).toGenericDict();
1491  const auto& elems = dict->elements();
1492  auto value = elems.find(index);
1493  if (value == elems.end()) {
1494  push(stack, IValue());
1495  } else {
1496  push(stack, value->second);
1497  }
1498  return 0;
1499 }
1500 
1501 int dictGetDefault(Stack& stack) {
1502  auto default_value = pop(stack);
1503  auto index = pop(stack);
1504  auto dict = pop(stack).toGenericDict();
1505  const auto& elems = dict->elements();
1506  auto value = elems.find(index);
1507  if (value == elems.end()) {
1508  push(stack, default_value);
1509  } else {
1510  push(stack, value->second);
1511  }
1512  return 0;
1513 }
1514 
// Second batch of primitive-operator registrations: string ops, list ops
// (mutable and immutable element types), dict ops, scalar arithmetic /
// comparison ops, and copy_ variants. Several schema strings below are
// deliberately spliced across physical lines with backslash-newline; the
// extra whitespace inside them is tolerated by the schema parser.
RegisterOperators reg2({

// Registers one binary string operator. Operands are popped in reverse
// order (b first, then a); `string_op` is an expression over a and b.
#define DEFINE_STRING_OP(op_name, string_op, result) \
  Operator(#op_name "(str a, str b) ->" #result, [](Stack& stack) { \
    auto b = pop(stack).toStringRef(); \
    auto a = pop(stack).toStringRef(); \
    push(stack, string_op); \
    return 0; \
  })

    DEFINE_STRING_OP(aten::eq, a == b, bool),
    DEFINE_STRING_OP(aten::ne, a != b, bool),
    DEFINE_STRING_OP(aten::add, a + b, str),
#undef DEFINE_STRING_OP

    // tensor length op (size of 1st dimension)
    Operator(
        "aten::len(Tensor t) -> int",
        [](Stack& stack) {
          at::Tensor t = pop(stack).toTensor();
          if (t.dim() == 0) {
            AT_ERROR("len() of a 0-d tensor");
          }
          push(stack, t.sizes()[0]);
          return 0;
        }),
// Mutable ops for lists containing mutable types.
#define CREATE_MUTABLE_LIST_OPS(decl_type, c_type) \
  Operator( \
      "aten::select(" decl_type "[](a) list, int idx) -> " decl_type "(*)", \
      listSelect<Shared<c_type>>), \
      Operator( \
          "aten::append( " decl_type "[](a!) self, " decl_type \
          "(c) el) -> " decl_type "[](a!)", \
          listAppend<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::reverse( " decl_type "[](a!) self) -> ()", \
          listReverse<Shared<c_type>>), \
      Operator( \
          "aten::extend(" decl_type "[](a!) self, " decl_type \
          " [] other) -> ()", \
          listExtend<Shared<c_type>>), \
      Operator( \
          "aten::copy(" decl_type \
          "[](a) self)" \
          " -> " decl_type "[]", \
          listCopy<Shared<c_type>>), \
      Operator( \
          "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
          " el) -> " decl_type "[](a!)", \
          listSetItem<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::clear( " decl_type "[](a!) self) -> ()", \
          listClear<Shared<c_type>>), \
      Operator( \
          "aten::insert( " decl_type \
          "[](a!) self, int idx, \
          " decl_type " el) -> ()", \
          listInsert<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::pop(" decl_type \
          "[](a!) self, int idx=-1) \
          -> " decl_type "(*)", \
          listPop<Shared<c_type>>)

    CREATE_MUTABLE_LIST_OPS("Tensor", TensorList),

    // Tensor-element membership ops go through the eq()-based
    // specializations defined above.
    Operator(
        "aten::remove(Tensor[](a!) self, Tensor el) -> ()",
        listRemove<Shared<TensorList>, at::Tensor>),
    Operator(
        "aten::index(Tensor[] self, Tensor el) -> int",
        listIndex<Shared<TensorList>, at::Tensor>),
    Operator(
        "aten::count(Tensor[] self, Tensor el) -> int",
        listCount<Shared<TensorList>, at::Tensor>),

// Mutable ops for lists containing immutable types.
#define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type) \
  Operator( \
      "aten::select(" decl_type "[] a, int b) -> " decl_type, \
      listSelect<Shared<c_type>>), \
      Operator( \
          "aten::append(" decl_type "[](a!) self, " decl_type \
          " el) -> " decl_type "[](a!)", \
          listAppend<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::reverse(" decl_type "[](a!) self) -> ()", \
          listReverse<Shared<c_type>>), \
      Operator( \
          "aten::extend(" decl_type "[](a!) self, " decl_type \
          " [] other) -> ()", \
          listExtend<Shared<c_type>>), \
      Operator( \
          "aten::copy(" decl_type \
          "[](a) self)" \
          " -> " decl_type "[]", \
          listCopy<Shared<c_type>>), \
      Operator( \
          "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
          " el) -> " decl_type "[](a!)", \
          listSetItem<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::clear( " decl_type "[](a!) self) -> ()", \
          listClear<Shared<c_type>>), \
      Operator( \
          "aten::insert( " decl_type \
          "[](a!) self, int idx, \
          " decl_type " el) -> ()", \
          listInsert<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::remove(" decl_type \
          "[](a!) self, \
          " decl_type " el) -> ()", \
          listRemove<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::index(" decl_type \
          "[] self, \
          " decl_type " el) -> int", \
          listIndex<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::count(" decl_type \
          "[] self, \
          " decl_type " el) -> int", \
          listCount<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::pop(" decl_type \
          "[](a!) self, int idx=-1) \
          -> " decl_type, \
          listPop<Shared<c_type>>)

    CREATE_IMMUTABLE_LIST_OPS("int", IntList),
    CREATE_IMMUTABLE_LIST_OPS("float", DoubleList),
    CREATE_IMMUTABLE_LIST_OPS("bool", BoolList),

    // NOTE: this must be after the other list specializations so that operator
    // resolution doesn't pick this up first
    CREATE_MUTABLE_LIST_OPS("t", GenericList),
#undef CREATE_IMMUTABLE_LIST_OPS
#undef CREATE_MUTABLE_LIST_OPS

// Non-mutating list ops shared by every element type: len, add, slice,
// list (identity), and mul by int (both operand orders).
#define CREATE_LIST_OPS(decl_type, c_type) \
  Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
      Operator( \
          "aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type \
          "[]", \
          listAdd<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::slice(" decl_type \
          "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
          "[]", \
          listSlice<Shared<c_type>, c_type::ElemType>), \
      Operator("aten::list(" decl_type "[] l) -> " decl_type "[]", listList), \
      Operator( \
          "aten::mul(" decl_type "[] l, int n) -> " decl_type "[]", \
          listMulIntLeft<Shared<c_type>, c_type::ElemType>), \
      Operator( \
          "aten::mul(int n, " decl_type "[] l) -> " decl_type "[]", \
          listMulIntRight<Shared<c_type>, c_type::ElemType>)

    CREATE_LIST_OPS("int", IntList),
    CREATE_LIST_OPS("float", DoubleList),
    CREATE_LIST_OPS("bool", BoolList),
    CREATE_LIST_OPS("Tensor", TensorList),
    CREATE_LIST_OPS("t", GenericList),
#undef CREATE_LIST_OPS

    Operator("aten::eq(int[] a, int[] b) -> bool", listEq<Shared<IntList>>),
    Operator(
        "aten::eq(float[] a, float[] b) -> bool",
        listEq<Shared<DoubleList>>),
    Operator(
        "aten::eq(Tensor[] a, Tensor[] b) -> bool",
        listEq<Shared<TensorList>>),
    Operator("aten::eq(bool[] a, bool[] b) -> bool", listEq<Shared<BoolList>>),
    Operator("aten::ne(int[] a, int[] b) -> bool", listNe<Shared<IntList>>),
    Operator(
        "aten::ne(float[] a, float[] b) -> bool",
        listNe<Shared<DoubleList>>),
    Operator(
        "aten::ne(Tensor[] a, Tensor[] b) -> bool",
        listNe<Shared<TensorList>>),
    Operator("aten::ne(bool[] a, bool[] b) -> bool", listNe<Shared<BoolList>>),

// In-place copy_ from a tensor or scalar into an existing tensor.
// NOTE(review): `std::move(t) = other` appears to rely on an rvalue
// overload of Tensor::operator= (hence the NOLINTs) — confirm against the
// ATen Tensor API before touching this.
#define CREATE_COPY_OP(other_type, c_type) \
  Operator( \
      "aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \
      [](Stack& stack) { \
        at::Tensor t; \
        c_type other; \
        pop(stack, t, other); \
        std::move(t) = other; /* NOLINT(bugprone-use-after-move) */ \
        push(stack, std::move(t)); /* NOLINT(bugprone-use-after-move) */ \
        return 0; \
      })

    CREATE_COPY_OP(Tensor, at::Tensor),
    CREATE_COPY_OP(int, int64_t),
    CREATE_COPY_OP(float, double),
#undef CREATE_COPY_OP

    DEFINE_BINARY_OP(aten::add, a + b),
    DEFINE_BINARY_OP(aten::sub, a - b),
    DEFINE_BINARY_OP(aten::mul, a* b),
    DEFINE_BINARY_OP(aten::pow, static_cast<decltype(a)>(pow(a, b))),
    // min and max are in prim:: because there is a difference between
    // the python builtin 'min' and 'torch.min'
    DEFINE_BINARY_OP(prim::min, a < b ? a : b),
    DEFINE_BINARY_OP(prim::max, a > b ? a : b),
    // Pass in two ops for handling int and float separately as % in C++ only
    // works for int The modulus calculation is different between C++ and Python
    // (on negative), we preserve the python behavior as it's more common and
    // match python syntax, hence the conversion.
    DEFINE_GENERIC_OP(
        aten::remainder,
        (b + (a % b)) % b,
        fmod((b + fmod(a, b)), b),
        int,
        float),
    DEFINE_INT_FLOAT_OP(aten::remainder, fmod((b + fmod(a, b)), b), float),

    DEFINE_GENERIC_OP(
        aten::floordiv,
        floordiv(a, b),
        std::floor(a / b),
        int,
        float),
    DEFINE_INT_FLOAT_OP(aten::floordiv, std::floor(a / b), float),

    // only used in loop unrolling, not exposed to end users
    DEFINE_INT_OP(aten::__round_to_zero_floordiv, a / b),

    DEFINE_INT_OP(aten::__and__, a& b),
    DEFINE_INT_OP(aten::__or__, a | b),
    DEFINE_INT_OP(aten::__xor__, a ^ b),

    // NB: This is the python truediv operation
    Operator(
        "aten::div(int a, int b) -> float",
        [](Stack& stack) {
          int64_t a, b;
          pop(stack, a, b);
          push(stack, static_cast<double>(a) / static_cast<double>(b));
          return 0;
        }),
    Operator(
        "aten::div(float a, float b) -> float",
        [](Stack& stack) {
          double a, b;
          pop(stack, a, b);
          push(stack, a / b);
          return 0;
        }),

    Operator(
        "aten::floor(float a) -> int",
        [](Stack& stack) {
          double a;
          pop(stack, a);
          push(stack, static_cast<int64_t>(std::floor(a)));
          return 0;
        }),

    DEFINE_COMPARISON_OP(aten::ne, a != b),
    DEFINE_COMPARISON_OP(aten::eq, a == b),
    DEFINE_COMPARISON_OP(aten::lt, a < b),
    DEFINE_COMPARISON_OP(aten::gt, a > b),
    DEFINE_COMPARISON_OP(aten::le, a <= b),
    DEFINE_COMPARISON_OP(aten::ge, a >= b),

    DEFINE_BOOL_OP(aten::__and__, a&& b),
    DEFINE_BOOL_OP(aten::__or__, a || b),
    DEFINE_BOOL_OP(aten::__xor__, a != b),

    Operator(
        "aten::neg(int self) -> int",
        [](Stack& stack) {
          push(stack, -pop(stack).toInt());
          return 0;
        }),
    Operator(
        "aten::neg(float self) -> float",
        [](Stack& stack) {
          push(stack, -pop(stack).toDouble());
          return 0;
        }),
    Operator(
        "aten::__not__(bool self) -> bool",
        [](Stack& stack) {
          push(stack, !pop(stack).toBool());
          return 0;
        }),
    Operator(
        "aten::__is__(t1 self, t2 obj) -> bool",
        [](Stack& stack) {
          IValue self, obj;
          pop(stack, self, obj);
          push(stack, self.isSameIdentity(obj));
          return 0;
        }),
    Operator(
        "aten::__isnot__(t1 self, t2 obj) -> bool",
        [](Stack& stack) {
          IValue self, obj;
          pop(stack, self, obj);
          push(stack, !self.isSameIdentity(obj));
          return 0;
        }),
    // NOTE(review): reads elements via data<int32_t>() even though the
    // schema advertises int[] (int64) — presumably the input tensor is
    // expected to be int32; confirm against callers.
    Operator(
        "aten::_tensor_to_list(Tensor self) -> int[]",
        [](Stack& stack) {
          at::Tensor t;
          pop(stack, t);
          std::vector<int64_t> elems;
          elems.reserve(t.size(0));
          for (int i = 0; i < t.size(0); i++) {
            elems.push_back(*t[i].data<int32_t>());
          }
          push(stack, jit::IntList::create(elems));
          return 0;
        }),
    Operator(
        "aten::_list_to_tensor(int[] self) -> Tensor",
        [](Stack& stack) {
          std::vector<int64_t> l;
          pop(stack, l);
          auto t = torch::empty(
              {static_cast<int64_t>(l.size())}, at::dtype(at::kInt));
          for (size_t i = 0; i < l.size(); i++) {
            t[i] = l[i];
          }
          push(stack, t);
          return 0;
        }),
// Dict ops, instantiated once per supported key type.
#define CREATE_DICT_OPS(key_type) \
  Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \
      Operator( \
          "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)", \
          dictKeys), \
      Operator( \
          "aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues), \
      Operator( \
          "prim::DictIndex(Dict(" key_type ", t) self, " key_type \
          " key) -> t(*)", \
          dictIndex), \
      Operator( \
          "aten::get(Dict(" key_type ", t) self, " key_type \
          " key) -> t(*)?", \
          dictGet), \
      Operator( \
          "aten::get(Dict(" key_type ", t) self, " key_type \
          " key, t default_value) -> t(*)", \
          dictGetDefault), \
      Operator( \
          "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \
          " idx, t v) -> ()", \
          dictSetItem)

    CREATE_DICT_OPS("str"),
    CREATE_DICT_OPS("int"),
    CREATE_DICT_OPS("float"),
#undef CREATE_DICT_OPS
});
1878 
1879 // reference: _output_size in torch/nn/functional.py
1880 // size can be none, int or intlist
1881 // scale_factors can be none, float, or floatlist
1882 std::vector<int64_t> _output_size(
1883  const at::Tensor& input,
1884  size_t dim,
1885  const IValue& size,
1886  const IValue& scale_factors) {
1887  if (!size.isNone()) {
1888  if (size.isInt()) {
1889  std::vector<int64_t> repeated(dim, size.toInt());
1890  return repeated;
1891  } else {
1892  return size.toIntListRef();
1893  }
1894  }
1895  std::vector<double> scale_repeated;
1896  if (scale_factors.isDouble()) {
1897  scale_repeated = std::vector<double>(dim, scale_factors.toDouble());
1898  } else {
1899  scale_repeated = scale_factors.toDoubleListRef();
1900  }
1901  std::vector<int64_t> ret;
1902  for (size_t i = 0; i < dim; ++i) {
1903  ret.push_back(std::floor(input.size(i + 2) * scale_repeated[i]));
1904  }
1905  return ret;
1906 }
1907 
// reference: interpolate in torch/nn/functional.py
// size can be none, int or intlist
// scale_factors can be none, float, or floatlist
//
// Dispatches to the concrete upsampling/pooling kernel based on the input
// rank (3-D/4-D/5-D) and the interpolation mode, and errors out for
// rank/mode combinations that do not exist.
at::Tensor interpolate(
    const at::Tensor& input,
    const IValue& size,
    const IValue& scale_factors,
    const std::string& mode,
    c10::optional<bool> align_corners) {
  // align_corners is only meaningful for the interpolating modes; for
  // nearest/area it must be left unset.
  if ((mode == "nearest" || mode == "area")) {
    if (align_corners != c10::nullopt) {
      throw std::runtime_error(
          "align_corners option can only be set with the "
          "interpolating modes: linear | bilinear | bicubic | trilinear");
    }
  } else {
    // For interpolating modes an unset align_corners defaults to false,
    // with a warning about the 0.4.0 behavior change.
    if (align_corners == c10::nullopt) {
      AT_WARN(
          "Default upsampling behavior when mode=",
          mode,
          " is changed "
          "to align_corners=False since 0.4.0. Please specify align_corners=True "
          "if the old behavior is desired. See the documentation of nn.Upsample for details");
      align_corners = false;
    }
  }

  // Dispatch table: each branch pairs one input rank with one mode. The
  // second argument to _output_size is the number of spatial dimensions.
  auto input_dim = input.dim();
  if (input_dim == 3 && mode == "nearest")
    return at::upsample_nearest1d(
        input, _output_size(input, 1, size, scale_factors));
  if (input_dim == 4 && mode == "nearest")
    return at::upsample_nearest2d(
        input, _output_size(input, 2, size, scale_factors));
  if (input_dim == 5 && mode == "nearest")
    return at::upsample_nearest3d(
        input, _output_size(input, 3, size, scale_factors));
  if (input_dim == 3 && mode == "area")
    return at::adaptive_avg_pool1d(
        input, _output_size(input, 1, size, scale_factors));
  if (input_dim == 4 && mode == "area")
    return at::adaptive_avg_pool2d(
        input, _output_size(input, 2, size, scale_factors));
  if (input_dim == 5 && mode == "area")
    return at::adaptive_avg_pool3d(
        input, _output_size(input, 3, size, scale_factors));
  if (input_dim == 3 && mode == "linear")
    return at::upsample_linear1d(
        input, _output_size(input, 1, size, scale_factors), *align_corners);
  if (input_dim == 3 && mode == "bilinear")
    throw std::runtime_error("Got 3D input, but bilinear mode needs 4D input");
  if (input_dim == 3 && mode == "bicubic")
    throw std::runtime_error("Got 3D input, but bicubic mode needs 4D input");
  if (input_dim == 3 && mode == "trilinear")
    throw std::runtime_error("Got 3D input, but trilinear mode needs 5D input");
  if (input_dim == 4 && mode == "linear")
    throw std::runtime_error("Got 4D input, but linear mode needs 3D input");
  if (input_dim == 4 && mode == "bilinear")
    return at::upsample_bilinear2d(
        input, _output_size(input, 2, size, scale_factors), *align_corners);
  if (input_dim == 4 && mode == "bicubic")
    return at::upsample_bicubic2d(
        input, _output_size(input, 2, size, scale_factors), *align_corners);
  if (input_dim == 4 && mode == "trilinear")
    throw std::runtime_error("Got 4D input, but trilinear mode needs 5D input");
  if (input_dim == 5 && mode == "linear")
    throw std::runtime_error("Got 5D input, but linear mode needs 3D input");
  if (input_dim == 5 && mode == "bilinear")
    throw std::runtime_error("Got 5D input, but bilinear mode needs 4D input");
  if (input_dim == 5 && mode == "bicubic")
    throw std::runtime_error("Got 5D input, but bicubic mode needs 4D input");
  if (input_dim == 5 && mode == "trilinear")
    return at::upsample_trilinear3d(
        input, _output_size(input, 3, size, scale_factors), *align_corners);

  // No branch matched: unsupported rank (or unknown mode).
  AT_ERROR(
      "Input Error: Only 3D, 4D and 5D input Tensors supported",
      " (got ",
      input_dim,
      "D) for the modes: nearest | linear | bilinear | trilinear",
      " (got ",
      mode,
      ") ");
}
1992 
1993 Operation interpolate_op(const Node* n) {
1994  return [](Stack& stack) {
1995  at::Tensor input;
1996  IValue size;
1997  IValue scale_factors;
1998  std::string mode;
1999  IValue align_corners;
2000  pop(stack, input, size, scale_factors, mode, align_corners);
2001  at::Tensor res = interpolate(
2002  input, size, scale_factors, mode, align_corners.toOptional<bool>());
2003  push(stack, res);
2004  return 0;
2005  };
2006 }
2007 
2008 // interpolate takes in float & float[] for scale factor
2009 // upsample takes in int & int[], so convert the ints to floats before
2010 // passing on to the interpolate op
2011 IValue convert_scale_factor_to_double(const IValue& int_ivalue) {
2012  IValue scale_factor_double;
2013  if (int_ivalue.isInt()) {
2014  scale_factor_double = static_cast<double>(int_ivalue.toInt());
2015  } else if (int_ivalue.isIntList()) {
2016  auto int_list = int_ivalue.toIntListRef();
2017  std::vector<double> double_vec(int_list.begin(), int_list.end());
2018  scale_factor_double = double_vec;
2019  } else if (int_ivalue.isNone()) {
2020  return IValue();
2021  } else {
2022  std::stringstream ss;
2023  ss << "Expecting optional int or int list arg for scale factor, got"
2024  << int_ivalue;
2025  throw std::runtime_error(ss.str());
2026  }
2027  return scale_factor_double;
2028 }
2029 
2030 Operation upsample_nearest_op(const Node* n) {
2031  return [](Stack& stack) {
2032  at::Tensor input;
2033  IValue size;
2034  IValue scale_factor_int;
2035  pop(stack, input, size, scale_factor_int);
2036  IValue scale_factor_double =
2037  convert_scale_factor_to_double(scale_factor_int);
2038  at::Tensor res =
2039  interpolate(input, size, scale_factor_double, "nearest", c10::nullopt);
2040  push(stack, res);
2041  return 0;
2042  };
2043 }
2044 
2045 Operation upsample_op(const Node* n) {
2046  return [](Stack& stack) {
2047  at::Tensor input;
2048  IValue size;
2049  IValue scale_factor_int;
2050  std::string mode;
2051  IValue align_corners;
2052  pop(stack, input, size, scale_factor_int, mode, align_corners);
2053  IValue scale_factor_double =
2054  convert_scale_factor_to_double(scale_factor_int);
2055  at::Tensor res = interpolate(
2056  input,
2057  size,
2058  scale_factor_double,
2059  mode,
2060  align_corners.toOptional<bool>());
2061  push(stack, res);
2062  return 0;
2063  };
2064 }
2065 
2066 Operation upsample_bilinear_op(const Node* n) {
2067  return [](Stack& stack) {
2068  at::Tensor input;
2069  IValue size;
2070  IValue scale_factor_int;
2071  pop(stack, input, size, scale_factor_int);
2072  IValue scale_factor_double =
2073  convert_scale_factor_to_double(scale_factor_int);
2074  at::Tensor res =
2075  interpolate(input, size, scale_factor_double, "bilinear", true);
2076  push(stack, res);
2077  return 0;
2078  };
2079 }
2080 
// Registrations for the Python-facing interpolate/upsample wrappers. Each
// op is registered once per accepted size/scale_factor schema type so the
// schema matcher can resolve int vs int[] (and float vs float[]) call
// sites.
RegisterOperators reg3({
    Operator(
        "aten::__interpolate(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        interpolate_op),
    Operator(
        "aten::__interpolate(Tensor input, int[]? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        interpolate_op),
    Operator(
        "aten::__interpolate(Tensor input, int? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        interpolate_op),
    Operator(
        "aten::__interpolate(Tensor input, int[]? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        interpolate_op),

    Operator(
        "aten::__upsample_nearest(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
        upsample_nearest_op),
    Operator(
        "aten::__upsample_nearest(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
        upsample_nearest_op),

    Operator(
        "aten::__upsample(Tensor input, int? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        upsample_op),
    Operator(
        "aten::__upsample(Tensor input, int[]? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        upsample_op),

    Operator(
        "aten::__upsample_bilinear(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
        upsample_bilinear_op),
    Operator(
        "aten::__upsample_bilinear(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
        upsample_bilinear_op),
    Operator(
        "aten::__upsample_bilinear(Tensor input, int? size = None, int[]? scale_factor = None) -> Tensor",
        upsample_bilinear_op),
    Operator(
        "aten::__upsample_bilinear(Tensor input, int[]? size = None, int[]? scale_factor = None) -> Tensor",
        upsample_bilinear_op),

});
2123 
2124 at::Tensor leaky_relu(const at::Tensor& tensor, double scalar) {
2125  return at::leaky_relu(tensor, scalar);
2126 }
2127 at::Tensor cat(const std::vector<at::Tensor>& tensors) {
2128  return at::cat(tensors);
2129 }
2130 
2131 std::string get_first(const std::vector<std::vector<std::string>>& strings) {
2132  return strings[0][0];
2133 }
2134 
2135 static auto reg4 =
2137  .op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor",
2138  &leaky_relu)
2139  .op("_test::cat(Tensor[] inputs) -> Tensor", &cat)
2140  .op("_test::get_first", &get_first);
2141 } // namespace
2142 } // namespace jit
2143 } // namespace torch
RegisterOperators & op(const std::string &name, Implementation &&implementation)
Creates a new operator from a name and implementation function (function pointer or function object/l...
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
bool is_cuda() const noexcept
Return true if the device is of CUDA type.
Definition: Device.h:80
Represents a a compute device on which a tensor is located.
Definition: Device.h:30
Registration class for new operators.
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
bool is_cuda() const
Returns if a Tensor has CUDA backend.
Definition: jit_type.h:17
TensorOptions requires_grad(bool requires_grad=true)
Convenience function that returns a TensorOptions object with the requires_grad set to the given one...
Flush-To-Zero and Denormals-Are-Zero mode.