Caffe2 - C++ API
A deep learning, cross-platform ML framework
shape_analysis.cpp
1 #include <torch/csrc/jit/passes/shape_analysis.h>
2 
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/jit/argument_spec.h>
5 #include <torch/csrc/jit/constants.h>
6 #include <torch/csrc/jit/ir.h>
7 #include <torch/csrc/jit/operator.h>
8 #include <torch/csrc/jit/passes/alias_analysis.h>
9 
10 #include <torch/csrc/autograd/variable.h>
11 
12 #include <ATen/DeviceGuard.h>
13 #include <ATen/ExpandUtils.h>
14 
15 #include <exception>
16 #include <iostream>
17 #include <memory>
18 #include <utility>
19 #include <vector>
20 
21 namespace torch {
22 namespace jit {
23 
24 namespace prim {
25 using namespace ::c10::prim;
26 }
27 
28 struct propagation_error : std::exception {};
29 
30 #define SHAPE_ASSERT(cond) \
31  if (!(cond)) \
32  throw propagation_error()
33 
34 namespace {
35 
36 bool isValidArgumentForRunning(Value* v) {
37  // allow constants
38  if (toIValue(v))
39  return true;
40  if (CompleteTensorTypePtr tt = v->type()->cast<CompleteTensorType>()) {
41  return !at::isIntegralType(tt->scalarType());
42  }
43  return v->type()->isSubtypeOf(FloatType::get());
44 }
45 
46 bool isValidReturnForRunning(Value* v) {
47  return v->type()->isSubtypeOf(TensorType::get()) ||
48  v->type()->isSubtypeOf(NumberType::get());
49 }
50 
51 bool containsTensorType(const TypePtr& t) {
52  auto n_contained = t->containedTypes().size();
53  if (n_contained == 1) {
54  return t->containedTypes().at(0)->isSubtypeOf(TensorType::get());
55  } else if (n_contained > 1) {
56  return std::any_of(
57  t->containedTypes().begin(),
58  t->containedTypes().end(),
59  containsTensorType);
60  }
61  return false;
62 }
63 
64 class ShapePropagator {
65  public:
66  explicit ShapePropagator(std::shared_ptr<Graph> graph) : aliasDb_(graph) {
67  collectResizeSet(std::move(graph)->block());
68  }
69 
70  void PropagateShapeOnBlock(Block* block, bool insert_expands = true) {
71  for (Node* node : block->nodes()) {
72  try {
73  PropagateShapeOnNode(node, insert_expands);
74  } catch (propagation_error& e) {
75  setUnshapedType(node);
76  } catch (std::exception& e) {
77  if (auto sl = node->getSourceLocation()) {
78  sl->wrapAndRethrowException(e, "operation failed shape propagation");
79  } else {
80  throw;
81  }
82  }
83  }
84  }
85 
86  private:
87  ValueSet resized_alias_set;
88  const AliasDb aliasDb_;
89 
90  bool resizesInput(Node* n) {
91  static std::unordered_set<Symbol> resize_ops{
92  aten::resize_,
93  aten::resize_as_,
94  };
95 
96  if (resize_ops.count(n->kind()))
97  return true;
98 
99  if (!n->maybeSchema())
100  return false;
101 
102  // ops which take the result and write to input "out"
103  if (auto out_arg_index = n->schema().argumentIndexWithName("out")) {
104  auto arg = n->schema().arguments().at(*out_arg_index);
105  return arg.kwarg_only() && arg.type()->isSubtypeOf(TensorType::get());
106  }
107  return false;
108  }
109 
110  void collectResizeSet(Block* block) {
111  for (Node* n : block->nodes()) {
112  for (Block* b : n->blocks()) {
113  collectResizeSet(b);
114  }
115  if (resizesInput(n)) {
116  for (const auto input : n->inputs()) {
117  if (aliasDb_.writesToAlias(n, {input}, /*recurseBlocks*/ false)) {
118  resized_alias_set.insert(input);
119  }
120  }
121  }
122  }
123  }
124 
125  void setUnshapedType(Value* o) {
126  o->setType(unshapedType(o->type()));
127  }
128 
129  void setUnshapedType(Node* node) {
130  for (auto o : node->outputs()) {
131  setUnshapedType(o);
132  }
133  }
134 
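  // Illustrative note: negative dims index from the end, so with sizes
  // {2, 3, 4} a dim of -1 wraps to 2 and a dim of -3 wraps to 0.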
135  int64_t wrapDim(int64_t dim, at::IntArrayRef sizes) {
136  if (dim < 0) {
137  dim += sizes.size();
138  }
139  return dim;
140  }
141 
142  // TODO: Would be better to make JIT not assume that CUDA devices
143  // are the only thing that exist.
144  static at::Device jitDeviceIndexToDevice(int device) {
145  return device == -1 ? at::kCPU : at::Device(at::kCUDA, device);
146  }
147 
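  // Builds a throwaway value matching v's declared type so an op can be run
  // purely for shape inference: complete tensor types become a zero-filled
  // tensor with the recorded sizes/strides, floats become 0.f. The numeric
  // contents are never inspected by callers.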
148  IValue representativeValue(Value* v) {
149  TypePtr type_ = v->type();
150  // if the value is actually constant, just use it!
151  if (auto iv = toIValue(v)) {
152  return *iv;
153  }
154  if (CompleteTensorTypePtr type = type_->cast<CompleteTensorType>()) {
155  auto backend =
156  type->device().is_cpu() ? at::Backend::CPU : at::Backend::CUDA;
157  at::DeviceGuard device_guard(type->device());
158  auto& attype = at::getNonVariableType(backend, type->scalarType());
159  auto t =
160  at::empty_strided(type->sizes(), type->strides(), attype.options())
161  .zero_();
162  return autograd::make_variable(t, /*requires_grad=*/false);
163  } else if (type_->isSubtypeOf(FloatType::get())) {
164  return 0.f;
165  }
166  // we should not get here because isValidArgumentForRunning should have
167  // prevented it
168  std::stringstream ss;
169  ss << "unable to create representative value for: " << type_->str()
170  << ". File a bug report.";
171  throw std::runtime_error(ss.str());
172  }
173 
174  // For each input with type Tensor in the node's schema, extract its T type.
175  // Returns c10::nullopt if any such Tensor does not have a known shape.
176  // Non-tensor inputs are ignored.
177  template <typename T>
178  c10::optional<std::vector<std::shared_ptr<T>>> gatherTensorTypes(Node* node) {
179  std::vector<std::shared_ptr<T>> tensor_types;
180 
181  auto& schema = node->schema();
182  auto& args = schema.arguments();
183  // can't handle varargs primitives because we don't know what should be a
184  // Tensor
185  if (schema.is_vararg()) {
186  return c10::nullopt;
187  }
188  for (size_t i = 0; i < args.size(); ++i) {
189  if (args[i].type()->isSubtypeOf(ListType::ofTensors())) {
190  return c10::nullopt;
191  } else if (args[i].type()->isSubtypeOf(TensorType::get())) {
192  if (auto type = node->input(i)->type()->cast<T>()) {
193  tensor_types.push_back(type);
194  } else {
195  return c10::nullopt;
196  }
197  } else /* non-tensor type */ {
198  continue;
199  }
200  }
201 
202  return tensor_types;
203  }
204 
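  // Unifies the types of two matching value lists (e.g. the outputs of the
  // then/else branches of prim::If, or successive prim::Loop iterations) and
  // reports whether any output type changed, so callers can iterate the
  // propagation to a fixed point.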
205  bool mergeTypes(
206  ArrayRef<Value*> lhs,
207  ArrayRef<Value*> rhs,
208  ArrayRef<Value*> outputs) {
209  AT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size());
210  bool changed = false;
211  for (size_t i = 0; i < lhs.size(); ++i) {
212  auto old_output_type = outputs[i]->type();
213  auto new_type = unifyTypes(lhs[i]->type(), rhs[i]->type());
214  AT_ASSERT(new_type);
215  outputs[i]->setType(*new_type);
216  if (*old_output_type != *outputs[i]->type())
217  changed = true;
218  }
219  return changed;
220  }
221 
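  // Makes implicit broadcasting explicit by inserting aten::expand nodes.
  // A hypothetical example: for complete input sizes {3, 1} and {1, 4},
  // at::infer_size gives {3, 4}, and each input whose size differs from that
  // is replaced by an expanded copy before the binary op runs.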
222  void broadcastBinary(
223  Node* node,
224  std::vector<CompleteTensorTypePtr>& types,
225  size_t idx1,
226  size_t idx2) {
227  auto expected_size =
228  at::infer_size(types[idx1]->sizes(), types[idx2]->sizes());
229  auto broadcast = [&](size_t input_idx) {
230  CompleteTensorTypePtr input_type = types.at(input_idx);
231  if (input_type->sizes() == expected_size)
232  return;
233  auto graph = node->owningGraph();
234  WithInsertPoint point_guard{node};
235  Node* expand = graph
236  ->create(
237  aten::expand,
238  {node->inputs().at(input_idx),
239  graph->insertConstant(expected_size),
240  graph->insertConstant(false)})
241  ->insertBefore(node);
242  PropagateShapeOnNode(expand);
243  node->replaceInput(input_idx, expand->output());
244  };
245  broadcast(idx1);
246  broadcast(idx2);
247  types[0] = node->inputs().at(idx1)->type()->expect<CompleteTensorType>();
248  types[1] = node->inputs().at(idx2)->type()->expect<CompleteTensorType>();
249  }
250 
251  OperatorSet cannot_propagate_shape_by_running_it = {
252  "aten::solve(Tensor self, Tensor A) -> (Tensor, Tensor)",
253  "aten::inverse(Tensor self) -> Tensor",
254  };
255 
256  // Check if this node depends on a value that has been mutated previously. If
257  // it has, then it's not safe to run this node in isolation, since we don't
258  // know whether the dependency has been executed.
259  std::unordered_map<Node*, bool> dependsOnMutationMemo_;
260  bool dependsOnMutation(Node* node) {
261  if (dependsOnMutationMemo_.count(node) != 0) {
262  return dependsOnMutationMemo_[node];
263  }
264 
265  if (aliasDb_.hasWriters(node)) {
266  // If something could have written to a value used by this node, we can't
267  // guarantee the result is the same when running it in isolation.
268  dependsOnMutationMemo_[node] = true;
269  return true;
270  }
271 
272  // recursively check the producers of its inputs. We need to do this if the
273  // mutable value has been laundered through a pure function:
274  // a += 1
275  // c = a + b
276  // d = c + 1
277  // In this case, `d` cares whether `a` has been mutated even though it's not
278  // a direct input.
279  auto depends = false;
280  for (auto input : node->inputs()) {
281  depends |= dependsOnMutation(input->node());
282  }
283 
284  dependsOnMutationMemo_[node] = depends;
285  return depends;
286  }
287 
288  bool canPropagateShapeByRunningIt(Node* node) {
289  if (cannot_propagate_shape_by_running_it.find(node)) {
290  return false;
291  }
292 
293  if (dependsOnMutation(node)) {
294  return false;
295  }
296 
297  bool valid_args = std::all_of(
298  node->inputs().begin(),
299  node->inputs().end(),
300  isValidArgumentForRunning);
301  if (!valid_args)
302  return false;
303 
304  bool valid_returns = std::all_of(
305  node->outputs().begin(),
306  node->outputs().end(),
307  isValidReturnForRunning);
308  if (!valid_returns)
309  return false;
310 
311  return true;
312  }
313 
314  // If there's no Tensor in the outputs, e.g. float / float,
315  // we don't need to propagate shape.
316  bool DoesntRefineOutputs(Node* node) {
317  auto outputs = node->outputs();
318  for (auto& out : outputs) {
319  if (containsTensorType(out->type())) {
320  return false;
321  }
322  }
323  return true;
324  }
325 
326  bool PropagateShapeOnNodeByRunningIt(Node* node) {
327  if (!canPropagateShapeByRunningIt(node))
328  return false;
329  auto op = getOperation(node);
330  Stack stack;
331 
332  for (auto input : node->inputs()) {
333  stack.push_back(representativeValue(input));
334  }
335 
336  // XXX: we're not catching any exceptions from the op for now. This
337  // is to uncover any mistakes we could make when editing this code,
338  // and eventually it shouldn't matter, because this phase should be
339  // preceded by schema checking.
340  op(stack);
341 
342  AT_ASSERT(stack.size() == node->outputs().size());
343  for (size_t i = 0; i < stack.size(); ++i) {
344  // some ops may have mixed tensor/primitive outputs
345  // for primitives, we don't need to change the type because it is already
346  // its most constrained form.
347  if (stack[i].isTensor())
348  node->outputs()[i]->inferTypeFrom(stack[i].toTensor());
349  }
350  return true;
351  }
352 
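  // Shape rule for concatenation, roughly: with complete input types, all
  // non-cat dimensions must match and the cat dimension is summed, e.g.
  // concatenating sizes {2, 3} and {2, 5} along dim 1 yields size {2, 8}.
  // With only dimensioned types, the output simply copies the first known
  // input type.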
353  void PropagateCatShape(Node* cat_node) {
354  static const auto propagate_complete =
355  [this](Node* node, at::ArrayRef<Value*> tensors) -> bool {
356  auto input_types = fmap(tensors, [](Value* v) {
357  return v->type()->cast<CompleteTensorType>();
358  });
359  if (!std::all_of(
360  input_types.begin(),
361  input_types.end(),
362  [](const CompleteTensorTypePtr& tp) { return tp != nullptr; })) {
363  return false;
364  }
365  if (!node->is_constant(attr::dim))
366  return false;
367  std::vector<int64_t> sizes = input_types[0]->sizes();
368  const int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
369  const int64_t ndim = sizes.size();
370 
371  if (dim < 0 || dim >= ndim)
372  return false;
373 
374  sizes[dim] = 0;
375  for (auto& tp : input_types) {
376  auto& tp_sizes = tp->sizes();
377  if (sizes.size() != tp_sizes.size())
378  return false;
379  for (int64_t i = 0; i < ndim; ++i) {
380  if (sizes[i] != tp_sizes[i] && i != dim) {
381  return false;
382  }
383  }
384  sizes[dim] += tp_sizes[dim];
385  }
386  node->output()->setType(input_types[0]->withSizes(sizes));
387  return true;
388  };
389  static const auto propagate = [](Node* node,
390  at::ArrayRef<Value*> tensors) -> bool {
391  for (Value* v : tensors) {
392  if (auto type = v->type()->cast<DimensionedTensorType>()) {
393  node->output()->setType(type);
394  return true;
395  }
396  }
397  return false;
398  };
399  auto list_node =
400  ((cat_node->kind() == prim::FusedConcat)
401  ? cat_node
402  : cat_node->namedInput(attr::tensors)->node());
403  if (list_node->kind() == prim::ListConstruct ||
404  cat_node->kind() == prim::FusedConcat) {
405  auto tensors = list_node->inputs();
406  if (!tensors.empty()) {
407  if (propagate_complete(cat_node, tensors)) {
408  return;
409  } else if (propagate(cat_node, tensors)) {
410  return;
411  }
412  }
413  }
414  setUnshapedType(cat_node);
415  }
416 
417  bool mayAliasResizedSet(at::ArrayRef<Value*> vs) {
418  bool in_resize = false;
419  for (auto v : vs) {
420  if (aliasDb_.mayAlias(ValueSet{v}, resized_alias_set)) {
421  setUnshapedType(v);
422  in_resize = true;
423  }
424  }
425  return in_resize;
426  }
427 
428  void PropagateShapeOnNode(Node* node, bool insert_expands = true) {
429  // Certain ops like resize_ change the input tensor's size. Because our
430  // analysis is flow invariant, we set any Tensor that can alias a resized
431  // Tensor to the base Tensor Type without size information.
432  if (mayAliasResizedSet(node->inputs())) {
433  return setUnshapedType(node);
434  }
435 
436  // These don't require the types, and have complicated schemas. Return early
437  // after we process them.
438  switch (node->kind()) {
439  case prim::If: {
440  auto then_block = node->blocks().at(0);
441  auto else_block = node->blocks().at(1);
442  PropagateShapeOnBlock(then_block);
443  PropagateShapeOnBlock(else_block);
444  mergeTypes(
445  then_block->outputs(), else_block->outputs(), node->outputs());
446  return;
447  }
448  case prim::Loop: {
449  auto body_block = node->blocks().at(0);
450  // propagate counter type
451  body_block->inputs().at(0)->setType(node->inputs().at(0)->type());
452  // propagate loop-carried input types to block inputs
453  auto loop_carried_inputs = node->inputs().slice(2); // skip max, cond
454  auto loop_carried_block = body_block->inputs().slice(1); // skip trip
455  for (size_t i = 0; i < loop_carried_inputs.size(); ++i) {
456  loop_carried_block[i]->setType(loop_carried_inputs[i]->type());
457  }
458  auto loop_carried_outputs = body_block->outputs().slice(1); // skip cond
459 
460  do {
461  PropagateShapeOnBlock(body_block, /*insert_expands=*/false);
462  // note: inserting expands is unsafe at this point, we don't know
463  // if the types are stable yet, so the arguments to expand may change
464  } while (mergeTypes(
465  loop_carried_block, loop_carried_outputs, loop_carried_block));
466 
467  // now that the types are stable, we can insert the expands
468  PropagateShapeOnBlock(body_block, /*insert_expands=*/true);
469 
470  for (size_t i = 0; i < loop_carried_inputs.size(); ++i) {
471  node->outputs()[i]->setType(loop_carried_block[i]->type());
472  }
473  return;
474  }
475  case prim::ImplicitTensorToNum:
476  case prim::Bool:
477  case prim::Int:
478  case prim::Float:
479  return; // correct num type is already set
480  case prim::NumToTensor: {
481  TypePtr typ = node->input()->type();
482  if (typ->isSubtypeOf(IntType::get()) ||
483  typ->isSubtypeOf(BoolType::get())) {
484  node->output()->setType(
485  DimensionedTensorType::create(at::kLong, at::kCPU, 0));
486  } else if (node->input()->type()->isSubtypeOf(FloatType::get())) {
487  node->output()->setType(
488  DimensionedTensorType::create(at::kDouble, at::kCPU, 0));
489  }
490  return;
491  }
492  case prim::TupleConstruct: {
493  // We refresh the tuple type, because the input types could have been
494  // refined.
495  node->output()->setType(TupleType::create(
496  fmap(node->inputs(), [](Value* v) { return v->type(); })));
497  return;
498  }
499  case prim::TupleUnpack: {
500  auto tuple_type = node->input()->type()->cast<TupleType>();
501  AT_ASSERT(
502  tuple_type &&
503  tuple_type->elements().size() == node->outputs().size());
504  auto elems = tuple_type->elements();
505  for (size_t i = 0; i < node->outputs().size(); ++i) {
506  node->output(i)->setType(elems[i]);
507  }
508  return;
509  }
510  case prim::Constant: {
511  if (node->output()->type()->isSubtypeOf(TensorType::get())) {
512  node->output()->inferTypeFrom(node->t(attr::value));
513  }
514  return;
515  }
516  case prim::ConstantChunk: {
517  Value* tensor = node->input();
518  if (auto type = tensor->type()->cast<DimensionedTensorType>()) {
519  for (Value* output : node->outputs()) {
520  output->setType(type);
521  }
522  } else {
523  setUnshapedType(node);
524  }
525  return;
526  }
527  case prim::AutogradZero: {
528  setUnshapedType(node);
529  return;
530  }
531  case aten::_unwrap_optional: {
532  auto input_ivalue = toIValue(node->input());
533  if (input_ivalue && input_ivalue->isNone()) {
534  return;
535  }
536  }
537  default:
538  break; // fall-through
539  }
540 
541  if (node->hasSideEffects()) {
542  return;
543  }
544 
545  if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") ||
546  node->kind() == prim::FusedConcat) {
547  return PropagateCatShape(node);
548  }
549 
550  if (auto maybe_complete_types =
551  gatherTensorTypes<CompleteTensorType>(node)) {
552  if (PropagateCompleteShapeOnNode(
553  node, insert_expands, std::move(*maybe_complete_types))) {
554  return;
555  }
556  }
557 
558  if (PropagateTensorShapeOnNode(node, insert_expands)) {
559  return;
560  }
561 
562  if (DoesntRefineOutputs(node)) {
563  return;
564  }
565 
566  if (PropagateShapeOnNodeByRunningIt(node)) {
567  return;
568  }
569  return setUnshapedType(node);
570  }
571 
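  // For example, a constant int list [1, 2, 3] and a prim::ListConstruct with
  // three inputs both report a size of 3; any other producer yields nullopt.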
572  static c10::optional<size_t> determineListSize(Value* list) {
573  AT_ASSERT(list->type()->cast<ListType>());
574  if (auto shape = constant_as<std::vector<int64_t>>(list)) {
575  return shape->size();
576  }
577  auto input_node = list->node();
578  if (input_node->kind() == prim::ListConstruct) {
579  return input_node->inputs().size();
580  }
581  return c10::nullopt;
582  }
583 
584  // is it ok to try to run the op
585  // If an input is a constant, then we assume that the input is valid
586  // and we can try to run it.
587  // Otherwise:
588  // Integral typed _inputs_ are often an indicator that we're indexing into
589  // a tensor, so we should special-case these ops in the shape propagation.
590  // Additionally, passing in a zero representative tensor into an integer
591  // division op causes divide-by-zero errors.
592  // _Outputs_ must be tensors or primitives.
593  // We will call inferTypeFrom on the tensors, and ignore the primitives.
594  // However, we allow primitive returns because we want to support mixed
595  // primitive/tensor outputs.
596 
597  bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
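  // Helper defined below: given the DimensionedTensorTypes of all tensor
  // inputs, the broadcast result keeps the scalar type and device of the
  // argument at arg_for_type and takes the maximum dimensionality, e.g. a
  // 4-dim and a 1-dim input produce a 4-dim output type.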
598  static const auto broadcast =
599  [](std::vector<DimensionedTensorTypePtr>& tensor_types,
600  size_t arg_for_type) -> DimensionedTensorTypePtr {
601  if (tensor_types.size() == 1) {
602  return tensor_types[0];
603  }
604  AT_ASSERT(!tensor_types.empty());
605  auto any_type = tensor_types[arg_for_type];
606  auto max_dims = any_type->dim();
607  for (auto& type : tensor_types) {
608  max_dims = std::max(max_dims, type->dim());
609  }
610  return DimensionedTensorType::create(
611  any_type->scalarType(), any_type->device(), max_dims);
612  };
613 
614  using type_vec_t = std::vector<DimensionedTensorTypePtr>;
615  // Formula is expected to return a vector of length equal to the number of
616  // tensor outputs of the node, or an empty vector which implies that it
617  // failed to propagate.
618  using formula_t = std::function<type_vec_t(Node*)>;
619  static std::mutex shape_formulas_mutex;
620  static std::vector<std::pair<OperatorSet, formula_t>> shape_formulas;
621  struct register_formula_for {
622  register_formula_for(OperatorSet operators, formula_t formula) {
623  std::unique_lock<std::mutex> lock{shape_formulas_mutex};
624  shape_formulas.emplace_back(std::move(operators), std::move(formula));
625  }
626  };
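  // Each static registration below appends an (OperatorSet, formula) pair to
  // shape_formulas; when this node is matched against the table further down,
  // the first matching set's formula is consulted for the output tensor types.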
627 
628  // Requirements:
629  // dims : preserved
630  // scalar type : preserved
631  // device : preserved
632  // tensor inputs : 1
633  // tensor outputs : 1
634  // Additionally:
635  // - First input should be the only tensor input
636  static const register_formula_for simple_unary_ops{
637  {
638  "aten::abs(Tensor self) -> Tensor",
639  "aten::acos(Tensor self) -> Tensor",
640  "aten::neg(Tensor self) -> Tensor",
641  "aten::t(Tensor self) -> Tensor",
642  "aten::sigmoid(Tensor self) -> Tensor",
643  "aten::tanh(Tensor self) -> Tensor",
644  "aten::relu(Tensor self) -> Tensor",
645  "aten::asin(Tensor self) -> Tensor",
646  "aten::atan(Tensor self) -> Tensor",
647  "aten::ceil(Tensor self) -> Tensor",
648  "aten::clone(Tensor self) -> Tensor",
649  "aten::contiguous(Tensor self) -> Tensor",
650  "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
651  "aten::celu(Tensor self, Scalar alpha) -> Tensor",
652  "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
653  "aten::clamp_max(Tensor self, Scalar max) -> Tensor",
654  "aten::clamp_min(Tensor self, Scalar min) -> Tensor",
655  "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
656  "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
657  "aten::cos(Tensor self) -> Tensor",
658  "aten::cosh(Tensor self) -> Tensor",
659  "aten::digamma(Tensor self) -> Tensor",
660  "aten::dropout(Tensor input, float p, bool train) -> Tensor",
661  "aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor",
662  "aten::erf(Tensor self) -> Tensor",
663  "aten::erfc(Tensor self) -> Tensor",
664  "aten::erfinv(Tensor self) -> Tensor",
665  "aten::exp(Tensor self) -> Tensor",
666  "aten::expm1(Tensor self) -> Tensor",
667  "aten::log(Tensor self) -> Tensor",
668  "aten::log10(Tensor self) -> Tensor",
669  "aten::log1p(Tensor self) -> Tensor",
670  "aten::log2(Tensor self) -> Tensor",
671  "aten::log_sigmoid(Tensor self) -> Tensor",
672  "aten::log_softmax(Tensor self, int dim) -> Tensor",
673  "aten::floor(Tensor self) -> Tensor",
674  "aten::frac(Tensor self) -> Tensor",
675  "aten::flip(Tensor self, int[] dims) -> Tensor",
676  "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
677  "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
678  "aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
679  "aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
680  "aten::glu(Tensor self, int dim) -> Tensor",
681  "aten::inverse(Tensor self) -> Tensor",
682  "aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
683  "aten::lgamma(Tensor self) -> Tensor",
684  "aten::mvlgamma(Tensor self, int p) -> Tensor",
685  "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
686  "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
687  "aten::permute(Tensor self, int[] dims) -> Tensor",
688  "aten::pin_memory(Tensor self) -> Tensor",
689  "aten::pinverse(Tensor self, float rcond) -> Tensor",
690  "aten::reciprocal(Tensor self) -> Tensor",
691  "aten::relu(Tensor self) -> Tensor",
692  "aten::round(Tensor self) -> Tensor",
693  "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
694  "aten::rsqrt(Tensor self) -> Tensor",
695  "aten::selu(Tensor self) -> Tensor",
696  "aten::sigmoid(Tensor self) -> Tensor",
697  "aten::sign(Tensor self) -> Tensor",
698  "aten::sin(Tensor self) -> Tensor",
699  "aten::sinh(Tensor self) -> Tensor",
700  "aten::softmax(Tensor self, int dim) -> Tensor",
701  "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor",
702  "aten::softshrink(Tensor self, Scalar lambd) -> Tensor",
703  "aten::sqrt(Tensor self) -> Tensor",
704  "aten::tan(Tensor self) -> Tensor",
705  "aten::tanh(Tensor self) -> Tensor",
706  "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
707  "aten::transpose(Tensor self, int dim0, int dim1) -> Tensor",
708  "aten::tril(Tensor self, int diagonal) -> Tensor",
709  "aten::triu(Tensor self, int diagonal) -> Tensor",
710  "aten::trunc(Tensor self) -> Tensor",
711  "aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
712  "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
713  "aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor",
714  "aten::alias(Tensor self) -> Tensor",
715  "aten::detach(Tensor self) -> Tensor",
716  "aten::cumprod(Tensor self, int dim) -> Tensor",
717  "aten::cumsum(Tensor self, int dim) -> Tensor",
718 
719  "aten::empty_like(Tensor self) -> Tensor",
720  "aten::full_like(Tensor self, Scalar fill_value) -> Tensor",
721  "aten::ones_like(Tensor self) -> Tensor",
722  "aten::rand_like(Tensor self) -> Tensor",
723  "aten::randint_like(Tensor self, int high) -> Tensor",
724  "aten::randint_like(Tensor self, int low, int high) -> Tensor",
725  "aten::randn_like(Tensor self) -> Tensor",
726  "aten::zeros_like(Tensor self) -> Tensor",
727  },
728  [](Node* node) -> type_vec_t {
729  auto input_type =
730  node->input(0)->type()->cast<DimensionedTensorType>();
731  return input_type ? type_vec_t{input_type} : type_vec_t{};
732  }};
733 
734  // Requirements:
735  // dims : broadcast all tensor args
736  // scalar type : always matching and preserved
737  // device : always matching and preserved
738  // tensor inputs : *
739  // tensor outputs : 1
740  static const register_formula_for broadcasting_ops{
741  {
742  // Tensor-Tensor operators
743  "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
744  "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
745  "aten::mul(Tensor self, Tensor other) -> Tensor",
746  "aten::div(Tensor self, Tensor other) -> Tensor",
747  "aten::pow(Tensor self, Tensor exponent) -> Tensor",
748  "aten::fmod(Tensor self, Tensor other) -> Tensor",
749  "aten::remainder(Tensor self, Tensor other) -> Tensor",
750  "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
751  "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
752  "aten::max(Tensor self, Tensor other) -> Tensor",
753  "aten::min(Tensor self, Tensor other) -> Tensor",
754  "aten::__and__(Tensor self, Tensor other) -> Tensor",
755  "aten::__or__(Tensor self, Tensor other) -> Tensor",
756  "aten::__xor__(Tensor self, Tensor other) -> Tensor",
757  "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
758  "aten::__rshift__(Tensor self, Tensor other) -> Tensor",
759  "aten::__iand__(Tensor self, Tensor other) -> Tensor",
760  "aten::__ior__(Tensor self, Tensor other) -> Tensor",
761  "aten::__ixor__(Tensor self, Tensor other) -> Tensor",
762  "aten::__ilshift__(Tensor self, Tensor other) -> Tensor",
763  "aten::__irshift__(Tensor self, Tensor other) -> Tensor",
764 
765  // Tensor-Scalar operators
766  "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
767  "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
768  "aten::mul(Tensor self, Scalar other) -> Tensor",
769  "aten::div(Tensor self, Scalar other) -> Tensor",
770  "aten::pow(Tensor self, Scalar exponent) -> Tensor",
771  "aten::fmod(Tensor self, Scalar other) -> Tensor",
772  "aten::remainder(Tensor self, Scalar other) -> Tensor",
773  "aten::pow(Scalar self, Tensor exponent) -> Tensor",
774  "aten::__and__(Tensor self, Scalar other) -> Tensor",
775  "aten::__or__(Tensor self, Scalar other) -> Tensor",
776  "aten::__xor__(Tensor self, Scalar other) -> Tensor",
777  "aten::__lshift__(Tensor self, Scalar other) -> Tensor",
778  "aten::__rshift__(Tensor self, Scalar other) -> Tensor",
779  "aten::__iand__(Tensor self, Scalar other) -> Tensor",
780  "aten::__ior__(Tensor self, Scalar other) -> Tensor",
781  "aten::__ixor__(Tensor self, Scalar other) -> Tensor",
782  "aten::__ilshift__(Tensor self, Scalar other) -> Tensor",
783  "aten::__irshift__(Tensor self, Scalar other) -> Tensor",
784 
785  // Ops with Tensor-Tensor overloads only
786  "aten::atan2(Tensor self, Tensor other) -> Tensor",
787 
788  // Non-binary ops
789  "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
790  "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
791  },
792  [this](Node* node) -> type_vec_t {
793  if (auto maybe_tensor_types =
794  gatherTensorTypes<DimensionedTensorType>(node)) {
795  return {broadcast(*maybe_tensor_types, 0)};
796  }
797  return {};
798  }};
799 
800  // aten::where is special in that its return type is the second argument's
801  // (self) type rather than that of the condition
802  static const register_formula_for where_op{
803  {
804  "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
805  },
806  [this](Node* node) -> type_vec_t {
807  if (auto maybe_tensor_types =
808  gatherTensorTypes<DimensionedTensorType>(node)) {
809  return {broadcast(*maybe_tensor_types, 1)};
810  }
811  return {};
812  }};
813 
814  static const auto any_tensor_type =
815  [](Node* node) -> DimensionedTensorTypePtr {
816  for (Value* input : node->inputs()) {
817  if (auto type = input->type()->cast<DimensionedTensorType>()) {
818  return type;
819  }
820  }
821  return nullptr;
822  };
823 
824  // Requirements:
825  // dims : always matching and preserved
826  // scalar type : always matching and preserved
827  // device : always matching and preserved
828  // tensor inputs : 2
829  // tensor outputs : 1
830  static const register_formula_for binary_ops_strict_match{
831  {
832  "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
833  "aten::mm(Tensor self, Tensor mat2) -> Tensor",
834  "aten::bmm(Tensor self, Tensor mat2) -> Tensor",
835  },
836  [](Node* node) -> type_vec_t {
837  if (auto type = any_tensor_type(node)) {
838  return {type};
839  }
840  return {};
841  }};
842 
843  // Requirements:
844  // dims : all tensor args are broadcast
845  // scalar type : byte/uint8
846  // device : always matching and preserved
847  // tensor inputs : *
848  // tensor outputs : 1
849  static const register_formula_for comparison_ops{
850  {
851  "aten::lt(Tensor self, Tensor other) -> Tensor",
852  "aten::le(Tensor self, Tensor other) -> Tensor",
853  "aten::gt(Tensor self, Tensor other) -> Tensor",
854  "aten::ge(Tensor self, Tensor other) -> Tensor",
855  "aten::eq(Tensor self, Tensor other) -> Tensor",
856  "aten::ne(Tensor self, Tensor other) -> Tensor",
857  "aten::lt(Tensor self, Scalar other) -> Tensor",
858  "aten::le(Tensor self, Scalar other) -> Tensor",
859  "aten::gt(Tensor self, Scalar other) -> Tensor",
860  "aten::ge(Tensor self, Scalar other) -> Tensor",
861  "aten::eq(Tensor self, Scalar other) -> Tensor",
862  "aten::ne(Tensor self, Scalar other) -> Tensor",
863  },
864  [this](Node* node) -> type_vec_t {
865  if (auto maybe_tensor_types =
866  gatherTensorTypes<DimensionedTensorType>(node)) {
867  return {broadcast(*maybe_tensor_types, 0)->toScalarType(at::kByte)};
868  }
869  return {};
870  }};
871 
872  // Requirements:
873  // dims : preserved from the first argument
874  // scalar type : preserved from the first argument (doesn't have to match other arguments)
875  // device : always matching and preserved
876  // tensor inputs : *
877  // tensor outputs : 1
878  // NB: those ops (with slight adjustments) are good candidates for restarts.
879  // Knowing the type and device of weights or biases is usually enough to
880  // infer the output type.
881  static const register_formula_for nn_ops_first_input_preserving{
882  {
883  "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
884  "aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
885  "aten::conv2d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
886  "aten::conv3d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
887  "aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad) -> Tensor",
888  "aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
889  "aten::conv_transpose2d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
890  "aten::conv_transpose3d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
891  "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
892  "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor",
893  "aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor",
894  "aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor",
895  "aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor",
896  "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
897  "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
898  "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
899  "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
900  "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
901  "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
902  "aten::max_unpool2d(Tensor self, Tensor indices, int[] output_size) -> Tensor",
903  "aten::max_unpool3d(Tensor self, Tensor indices, int[] output_size, int[] stride, int[] padding) -> Tensor",
904  "aten::reflection_pad1d(Tensor self, int[] padding) -> Tensor",
905  "aten::reflection_pad2d(Tensor self, int[] padding) -> Tensor",
906  "aten::replication_pad1d(Tensor self, int[] padding) -> Tensor",
907  "aten::replication_pad2d(Tensor self, int[] padding) -> Tensor",
908  "aten::replication_pad3d(Tensor self, int[] padding) -> Tensor",
909  "aten::upsample_bilinear2d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
910  "aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
911  "aten::upsample_nearest1d(Tensor self, int[] output_size) -> Tensor",
912  "aten::upsample_nearest2d(Tensor self, int[] output_size) -> Tensor",
913  "aten::upsample_nearest3d(Tensor self, int[] output_size) -> Tensor",
914  "aten::upsample_trilinear3d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
915  "aten::prelu(Tensor self, Tensor weight) -> Tensor",
916  },
917  [](Node* node) -> type_vec_t {
918  if (auto type =
919  node->input(0)->type()->cast<DimensionedTensorType>()) {
920  return {type};
921  }
922  return {};
923  }};
924 
925  // Requirements:
926  // dims : 0
927  // scalar type : preserved
928  // device : preserved
929  // tensor inputs : 1
930  // tensor outputs : 1
931  // Additionally:
932  // - First input should be the only tensor input
933  static const register_formula_for all_reduce_ops{
934  {
935  "aten::det(Tensor self) -> Tensor",
936  "aten::logdet(Tensor self) -> Tensor",
937  "aten::max(Tensor self) -> Tensor",
938  "aten::min(Tensor self) -> Tensor",
939  "aten::mean(Tensor self) -> Tensor",
940  "aten::median(Tensor self) -> Tensor",
941  "aten::norm(Tensor self, Scalar p) -> Tensor",
942  "aten::std(Tensor self, bool unbiased) -> Tensor",
943  "aten::sum(Tensor self) -> Tensor",
944  "aten::trace(Tensor self) -> Tensor",
945  "aten::var(Tensor self, bool unbiased) -> Tensor",
946  "aten::all(Tensor self) -> Tensor",
947  "aten::any(Tensor self) -> Tensor",
948  },
949  [](Node* node) -> type_vec_t {
950  if (auto type =
951  node->input(0)->type()->cast<DimensionedTensorType>()) {
952  return {type->withDim(0)};
953  }
954  return {};
955  }};
956 
957  // Requirements:
958  // dims : 0
959  // scalar type : preserved if floating point, otherwise long/int64
960  // device : preserved
961  // tensor inputs : 1
962  // tensor outputs : 1
963  // Additionally:
964  // - First input should be the only tensor input
965  static const register_formula_for all_reduce_ops_with_integer_upcast{
966  {
967  "aten::sum(Tensor self) -> Tensor",
968  "aten::prod(Tensor self) -> Tensor",
969  },
970  [](Node* node) -> type_vec_t {
971  if (auto type =
972  node->input(0)->type()->cast<DimensionedTensorType>()) {
973  return {at::isFloatingType(type->scalarType())
974  ? type->withDim(0)
975  : type->withDim(0)->toScalarType(at::kLong)};
976  }
977  return {};
978  }};
979 
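  // Shared helper for the dim-reduction formulas below. A sketch of its
  // behaviour: reducing a 4-dim tensor over 2 dims keeps 4 dims when keepdim
  // is true and yields 2 dims otherwise; integral inputs are upcast to long
  // when upcast_integer is set.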
980  static const auto multidim_reduce_with_postprocess =
981  [](Node* node,
982  int64_t num_reduced_dim,
983  bool upcast_integer) -> type_vec_t {
984  auto maybe_keepdim = node->get<bool>(attr::keepdim);
985  if (!maybe_keepdim)
986  return {};
987  if (auto type = node->input(0)->type()->cast<DimensionedTensorType>()) {
988  if (upcast_integer && !at::isFloatingType(type->scalarType())) {
989  type = type->toScalarType(at::kLong);
990  }
991  if (*maybe_keepdim) {
992  return {type};
993  } else if (type->dim() > num_reduced_dim) {
994  return {type->withDim(type->dim() - num_reduced_dim)};
995  }
996  }
997  return {};
998  };
999 
1000  // Requirements:
1001  // dims : 0 if dim is None; otherwise preserved if keepdim == true, 1 smaller otherwise
1002  // scalar type : preserved
1003  // device : preserved
1004  // tensor inputs : 1
1005  // tensor outputs : 1
1006  // Additionally:
1007  // - First input should be the only tensor input
1008  // - Has a bool keepdim argument
1009  static const register_formula_for argminmax{
1010  {
1011  "aten::argmax(Tensor self, int? dim, bool keepdim) -> Tensor",
1012  "aten::argmin(Tensor self, int? dim, bool keepdim) -> Tensor",
1013  },
1014  [](Node* node) -> type_vec_t {
1015  if (auto type =
1016  node->input(0)->type()->cast<DimensionedTensorType>()) {
1017  if (node->input(1)->type()->kind() == c10::TypeKind::NoneType) {
1018  return {type->withDim(0)};
1019  } else {
1020  return multidim_reduce_with_postprocess(
1021  node, /*num_reduced_dim=*/1, /*upcast_integer=*/false);
1022  }
1023  }
1024  return {};
1025  }};
1026 
1027  // Requirements:
1028  // dims : preserved if keepdim == true, 1 smaller otherwise
1029  // scalar type : preserved for the first output, long/int64 for the second output if it exists
1030  // device : preserved
1031  // tensor inputs : 1, tensor outputs : 1 or 2
1032  // Additionally:
1033  // - First input should be the only tensor input
1034  // - Has a bool keepdim argument
1035  static const register_formula_for dim_reduce_ops{
1036  {
1037  "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
1038  "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
1039 
1040  // Ops returning indices as second output
1041  "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
1042  "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1043  "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1044  "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1045  "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1046  },
1047  [](Node* node) -> type_vec_t {
1048  // NB: Note that while this function is generally meant to be used
1049  // with ops that have a single output, we will fix up its return right
1050  // below.
1051  auto output_types = multidim_reduce_with_postprocess(
1052  node, /*num_reduced_dim=*/1, /*upcast_integer=*/false);
1053  if (!output_types.empty() && node->outputs().size() == 2) {
1054  output_types.push_back(
1055  output_types.back()->toScalarType(at::kLong));
1056  }
1057  return output_types;
1058  }};
1059 
1060  // Requirements:
1061  // dims : preserved if keepdim == true, 1 smaller otherwise
1062  // scalar type : preserved if floating point, otherwise long/int64
1063  // device : preserved
1064  // tensor inputs : 1
1065  // tensor outputs : 1
1066  // Additionally:
1067  // - First input should be the only tensor input
1068  // - has a bool keepdim argument
1069  static const register_formula_for dim_reduce_ops_with_integer_upcast{
1070  {
1071  "aten::prod(Tensor self, int dim, bool keepdim) -> Tensor",
1072  },
1073  [](Node* node) -> type_vec_t {
1074  return multidim_reduce_with_postprocess(
1075  node, /*num_reduced_dim=*/1, /*upcast_integer=*/true);
1076  }};
1077 
1078  // Requirements:
1079  // dims : preserved if keepdim == true, dim->size() smaller otherwise
1080  // scalar type : preserved, device : preserved
1081  // tensor inputs : 1, tensor outputs : 1
1082  // Additionally:
1083  // - First input should be the only tensor input
1084  // - has a bool keepdim argument
1085  static const register_formula_for multidim_reduce_ops{
1086  {
1087  "aten::logsumexp(Tensor self, int[] dim, bool keepdim) -> Tensor",
1088  "aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
1089  "aten::norm(Tensor self, Scalar? p, int[] dim, bool keepdim) -> Tensor",
1090  "aten::std(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
1091  "aten::var(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
1092  "aten::max_values(Tensor self, int[] dim, bool keepdim) -> Tensor",
1093  "aten::min_values(Tensor self, int[] dim, bool keepdim) -> Tensor",
1094  },
1095  [](Node* node) -> type_vec_t {
1096  if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
1097  return multidim_reduce_with_postprocess(
1098  node, /*num_reduced_dim=*/dim->size(), /*upcast_integer=*/false);
1099  }
1100  return {};
1101  }};
1102 
1103  // Requirements:
1104  // dims : preserved if keepdim == true, 1 smaller otherwise
1105  // scalar type : preserved if floating point, otherwise long/int64
1106  // device : preserved
1107  // tensor inputs : 1
1108  // tensor outputs : 1
1109  // Additionally:
1110  // - has bool keepdim and int[] dim arguments
1111  static const register_formula_for multidim_reduce_ops_with_integer_upcast{
1112  {
1113  "aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
1114  },
1115  [](Node* node) -> type_vec_t {
1116  if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
1117  // TODO: can dim contain duplicates?
1118  return multidim_reduce_with_postprocess(
1119  node, /*num_reduced_dim=*/dim->size(), /*upcast_integer=*/true);
1120  }
1121  return {};
1122  }};
1123 
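  // Helper for the factory formulas below: dtype, layout and device arguments
  // that are None fall back to float, strided and CPU respectively, so e.g. an
  // aten::ones call without an explicit dtype is typed as a kFloat tensor.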
1124  static const auto factory_with_ndim = [](Node* node,
1125  int dim) -> type_vec_t {
1126  at::optional<IValue> maybe_layout_option = node->get(attr::layout);
1127  if (!maybe_layout_option)
1128  return {};
1129  auto layout =
1130  (maybe_layout_option->isNone() ? at::kStrided
1131  : maybe_layout_option->toLayout());
1132 
1133  at::optional<IValue> maybe_device_option = node->get(attr::device);
1134  if (!maybe_device_option)
1135  return {};
1136  auto device =
1137  (maybe_device_option->isNone() ? at::kCPU
1138  : maybe_device_option->toDevice());
1139 
1140  at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
1141  if (!maybe_dtype_option)
1142  return {};
1143  auto dtype =
1144  (maybe_dtype_option->isNone() ? at::kFloat
1145  : maybe_dtype_option->toScalarType());
1146 
1147  return {DimensionedTensorType::create(dtype, device, dim)};
1148  };
1149 
1150  // Requirements:
1151  // dims : preserved
1152  // scalar type : equal to value of dtype
1153  // device : equal to value of device
1154  // tensor inputs : 1
1155  // tensor outputs : 1
1156  // Additionally:
1157  // - has ScalarType dtype, Layout layout and Device device arguments
1158  static const register_formula_for like_factories_with_options{
1159  {
1160  "aten::empty_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
1161  "aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device) -> Tensor",
1162  "aten::ones_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
1163  "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
1164  "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
1165  "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
1166  "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
1167  "aten::zeros_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
1168  },
1169  [](Node* node) -> type_vec_t {
1170  if (auto type = node->namedInput(attr::self)
1171  ->type()
1172  ->cast<DimensionedTensorType>()) {
1173  return factory_with_ndim(node, type->dim());
1174  }
1175  return {};
1176  }};
1177 
1178  // Requirements:
1179  // dims : equal to number of elements in size
1180  // scalar type : equal to value of dtype
1181  // device : equal to value of device
1182  // tensor inputs : 1
1183  // tensor outputs : 1
1184  // Additionally:
1185  // - has int[] size, ScalarType dtype, Layout layout and Device device
1186  // arguments
1187  static const register_formula_for size_factories_with_options{
1188  {
1189  "aten::empty(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1190  "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device) -> Tensor",
1191  "aten::ones(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1192  "aten::rand(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1193  "aten::randn(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1194  "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1195  "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1196  "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1197  },
1198  [](Node* node) -> type_vec_t {
1199  if (auto maybe_size = node->get<std::vector<int64_t>>(attr::size)) {
1200  return factory_with_ndim(node, maybe_size->size());
1201  }
1202  return {};
1203  }};
1204 
1205  static const auto get_cast_scalar_type = [](Node* node) -> at::ScalarType {
1206  switch (node->kind()) {
1207  case aten::_cast_Byte:
1208  return at::kByte;
1209  case aten::_cast_Char:
1210  return at::kChar;
1211  case aten::_cast_Double:
1212  return at::kDouble;
1213  case aten::_cast_Float:
1214  return at::kFloat;
1215  case aten::_cast_Half:
1216  return at::kHalf;
1217  case aten::_cast_Int:
1218  return at::kInt;
1219  case aten::_cast_Long:
1220  return at::kLong;
1221  case aten::_cast_Short:
1222  return at::kShort;
1223  default:
1224  AT_ASSERTM(
1225  false,
1226  "unknown node kind in get_cast_scalar_type: ",
1227  node->kind().toQualString());
1228  }
1229  };
1230  static const register_formula_for cast_ops{
1231  {
1232  "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
1233  "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
1234  "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
1235  "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
1236  "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
1237  "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
1238  "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
1239  "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
1240  },
1241  [](Node* node) -> type_vec_t {
1242  if (auto type = node->namedInput(attr::self)
1243  ->type()
1244  ->cast<DimensionedTensorType>()) {
1245  return {type->toScalarType(get_cast_scalar_type(node))};
1246  }
1247  return {};
1248  }};
1249 
1250  // First, try to match one of the registered formulas to their operator
1251  // sets.
1252  for (auto& entry : shape_formulas) {
1253  if (entry.first.find(node)) {
1254  auto types = entry.second(node);
1255  if (types.empty()) {
1256  return false;
1257  } else {
1258  auto outputs = node->outputs();
1259  AT_ASSERT(types.size() == outputs.size());
1260  for (size_t i = 0; i < types.size(); ++i) {
1261  AT_ASSERT(outputs[i]->type()->isSubtypeOf(TensorType::get()));
1262  outputs[i]->setType(types[i]);
1263  }
1264  return true;
1265  }
1266  }
1267  }
1268 
1269  // This section implements shape prop for an assorted set of nodes that only
1270  // need partial information about their input types.
1271  const auto input_type = [node](size_t index) {
1272  return node->input(index)->type()->cast<DimensionedTensorType>();
1273  };
1274  if (node->matches(
1275  "aten::masked_select(Tensor self, Tensor mask) -> Tensor")) {
1276  auto type = input_type(0);
1277  auto mask_type = input_type(1);
1278  if (type && mask_type) {
1279  if (type->dim() == 0 && mask_type->dim() == 0) {
1280  node->output()->setType(type->withDim(0));
1281  } else {
1282  node->output()->setType(type->withDim(1));
1283  }
1284  return true;
1285  }
1286  if (auto type = input_type(0)) {
1287  node->output()->setType(type->withDim(1));
1288  return true;
1289  }
1290  } else if (node->matches(
1291  "aten::dot(Tensor self, Tensor tensor) -> Tensor")) {
1292  if (auto type = any_tensor_type(node)) {
1293  node->output()->setType(type->withDim(0));
1294  return true;
1295  }
1296  } else if (
1297  node->matches("aten::mv(Tensor self, Tensor vec) -> Tensor") ||
1298  node->matches(
1299  "aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha) -> Tensor")) {
1300  if (auto type = any_tensor_type(node)) {
1301  node->output()->setType(type->withDim(1));
1302  return true;
1303  }
1304  } else if (
1305  node->matches(
1306  "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor") ||
1307  node->matches(
1308  "aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor") ||
1309  node->matches(
1310  "aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha) -> Tensor")) {
1311  if (auto type = any_tensor_type(node)) {
1312  node->output()->setType(type->withDim(2));
1313  return true;
1314  }
1315  } else if (
1316  node->matches(
1317  "aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor")) {
1318  if (auto type = any_tensor_type(node)) {
1319  node->output()->setType(type->withDim(3));
1320  return true;
1321  }
1322  } else if (
1323  node->matches(
1324  "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor")) {
1325  auto type = input_type(0);
1326  auto index_type = input_type(1);
1327  // index_select behaves very weirdly when self.dim() == 0. It allows both
1328  // 0D and 1D indices, and returns a value that has as many dimensions as
1329  // index.
1330  if (type && index_type) {
1331  if (type->dim() == 0) {
1332  node->output()->setType(type->withDim(index_type->dim()));
1333  } else {
1334  node->output()->setType(type);
1335  }
1336  return true;
1337  }
1338  } else if (
1339  node->matches(
1340  "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor")) {
1341  auto type = input_type(0);
1342  auto index_type = input_type(1);
1343  // Gather has this annoying edge case where index always needs to match
1344  // the number of dims of self, **except** when self is 1D and index is 0D
1345  // in which case we return a 0D output.
1346  if (type && index_type) {
1347  if (index_type->dim() == 0) {
1348  node->output()->setType(type->withDim(0));
1349  } else {
1350  node->output()->setType(type);
1351  }
1352  return true;
1353  }
1354  } else if (
1355  node->matches(
1356  "aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) {
1357  auto weight_type = input_type(0);
1358  auto indices_type = input_type(1);
1359  if (weight_type && indices_type) {
1360  node->output()->setType(weight_type->withDim(indices_type->dim() + 1));
1361  return true;
1362  }
1363  } else if (
1364  node->matches(
1365  "aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor")) {
1366  if (auto type = input_type(0)) {
1367  node->output()->setType(type);
1368  return true;
1369  }
1370  if (auto type = input_type(1)) {
1371  node->output()->setType(type);
1372  return true;
1373  }
1374  } else if (
1375  node->matches(
1376  "aten::dist(Tensor self, Tensor other, Scalar p) -> Tensor")) {
1377  if (auto type = any_tensor_type(node)) {
1378  node->output()->setType(type->withDim(0));
1379  return true;
1380  }
1381  }
1382 
1383  // The code below implements formulas that need type information for all
1384  // their tensor inputs, and have exactly one output.
1385  std::vector<DimensionedTensorTypePtr> tensor_types;
1386  static const auto reshape_prop =
1387  [](Node* node,
1388  Symbol shape_input,
1389  const std::vector<DimensionedTensorTypePtr>& tensor_types)
1390  -> DimensionedTensorTypePtr {
1391  if (auto list_size = determineListSize(node->namedInput(shape_input))) {
1392  return tensor_types.at(0)->withDim(*list_size);
1393  }
1394  return nullptr;
1395  };
1396  const auto getSingleOutputType = [&]() -> TypePtr {
1397  if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
1398  return tensor_types.at(0)->toScalarType(
1399  tensor_types.at(1)->scalarType());
1400  } else if (
1401  node->matches("aten::view_as(Tensor self, Tensor other) -> Tensor") ||
1402  node->matches(
1403  "aten::expand_as(Tensor self, Tensor other) -> Tensor") ||
1404  node->matches(
1405  "aten::reshape_as(Tensor self, Tensor other) -> Tensor")) {
1406  return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
1407  } else if (
1408  node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
1409  node->matches(
1410  "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
1411  node->matches(
1412  "aten::as_strided(Tensor self, int[] size, int[] stride, int? storage_offset) -> Tensor")) {
1413  return reshape_prop(node, attr::size, tensor_types);
1414  } else if (node->matches(
1415  "aten::reshape(Tensor self, int[] shape) -> Tensor")) {
1416  return reshape_prop(node, attr::shape, tensor_types);
1417  } else if (node->matches(
1418  "aten::repeat(Tensor self, int[] repeats) -> Tensor")) {
1419  return reshape_prop(node, attr::repeats, tensor_types);
1420  } else if (node->matches(
1421  "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
1422  auto& t = tensor_types.at(0);
1423  return t->withDim(t->dim() + 1);
1424  } else if (
1425  node->matches(
1426  "aten::select(Tensor self, int dim, int index) -> Tensor") ||
1427  node->matches(
1428  "aten::diagonal(Tensor self, int offset, int dim1, int dim2) -> Tensor")) {
1429  auto& t = tensor_types.at(0);
1430  return t->dim() > 0 ? t->withDim(t->dim() - 1) : nullptr;
1431  } else if (node->matches(
1432  "aten::matmul(Tensor self, Tensor other) -> Tensor")) {
1433  int dim1 = tensor_types.at(0)->dim();
1434  int dim2 = tensor_types.at(1)->dim();
1435  if (dim1 == 1 && dim2 == 1) {
1436  // Dot product
1437  return tensor_types.at(0)->withDim(0);
1438  } else if (dim1 == 2 && dim2 == 2) {
1439  // Matrix multiply
1440  return tensor_types.at(0);
1441  } else if (dim1 == 1 && dim2 == 2) {
1442  // Unsqueeze + matrix multiply + squeeze
1443  return tensor_types.at(0);
1444  } else if (dim1 == 2 && dim2 == 1) {
1445  // Matrix vector multiply
1446  return tensor_types.at(1);
1447  } else {
1448  // Batched matrix multiply (possibly with squeeze + unsqueeze if one
1449  // argument is 1D)
1450  auto type = broadcast(tensor_types, 0);
1451  if (tensor_types.at(0)->dim() == 1 ||
1452  tensor_types.at(1)->dim() == 1) {
1453  type = type->withDim(type->dim() - 1);
1454  }
1455  return type;
1456  }
1457  } else if (node->matches("aten::nonzero(Tensor self) -> Tensor")) {
1458  return tensor_types.at(0)->toScalarType(at::kLong);
1459  } else if (node->matches(
1460  "aten::take(Tensor self, Tensor index) -> Tensor")) {
1461  return tensor_types.at(1)->toScalarType(
1462  tensor_types.at(0)->scalarType());
1463  } else if (node->matches(
1464  "aten::diagflat(Tensor self, int offset) -> Tensor")) {
1465  return tensor_types.at(0)->withDim(2);
1466  } else if (node->matches(
1467  "aten::diag(Tensor self, int diagonal) -> Tensor")) {
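  // diag builds a matrix from a 1-D input and extracts the diagonal of a
  // 2-D input, so the rank flips between 1 and 2; for any other rank no
  // type is propagated.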
1468  auto& t = tensor_types.at(0);
1469  if (t->dim() == 1) {
1470  return t->withDim(2);
1471  } else if (t->dim() == 2) {
1472  return t->withDim(1);
1473  } else {
1474  return nullptr;
1475  }
1476  } else if (
1477  node->matches(
1478  "aten::unfold(Tensor self, int dimension, int size, int step) -> Tensor")) {
1479  auto& t = tensor_types.at(0);
1480  return t->dim() == 0 ? t : t->withDim(t->dim() + 1);
1481  } else if (node->matches(
1482  "aten::polygamma(int n, Tensor self) -> Tensor")) {
1483  return tensor_types.at(0);
1484  }
1485  return nullptr;
1486  };
1487  if (auto maybe_tensor_types =
1488  gatherTensorTypes<DimensionedTensorType>(node)) {
1489  tensor_types = std::move(*maybe_tensor_types);
1490  } else {
1491  return false;
1492  }
1493  if (node->outputs().size() == 1) {
1494  if (auto type = getSingleOutputType()) {
1495  node->output()->setType(type);
1496  return true;
1497  }
1498  }
1499  return false;
1500  }
1501 
1502  bool PropagateCompleteShapeOnNode(
1503  Node* node,
1504  bool insert_expands,
1505  std::vector<CompleteTensorTypePtr> tensor_types) {
1506  // For expensive ops we can directly encode their shape propagation
1507  // here; otherwise we fall back to running a fake version of the op
1508  // to get a quick and dirty propagation.
1509  if (node->matches(
1510  "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
1511  node->matches(
1512  "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
1513  node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
1514  // These nodes and "div" handle tensors of different shapes internally,
1515  // so there's no need to insert explicit expand nodes. Note that "div" is
1516  // handled by the fallthrough because it's not always safe to run it due
1517  // to integer divide-by-zero.
1518  return PropagateShapeOnNodeByRunningIt(node);
1519  } else if (
1520  node->matches(
1521  "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
1522  node->matches(
1523  "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
1524  node->matches("aten::mul(Tensor self, Scalar other) -> Tensor") ||
1525  node->matches("aten::pow(Tensor self, Scalar exponent) -> Tensor")) {
1526  node->output()->setType(tensor_types.at(0));
1527  return true;
1528  } else if (
1529  insert_expands &&
1530  (node->matches("aten::pow(Tensor self, Tensor exponent) -> Tensor") ||
1531  node->matches("aten::min(Tensor self, Tensor other) -> Tensor") ||
1532  node->matches("aten::max(Tensor self, Tensor other) -> Tensor") ||
1533  node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") ||
1534  node->matches("aten::le(Tensor self, Tensor other) -> Tensor") ||
1535  node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") ||
1536  node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") ||
1537  node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") ||
1538  node->matches("aten::ne(Tensor self, Tensor other) -> Tensor"))) {
1539  // Binary broadcasting ops
1540  // NB: we don't handle the nodes in any other way (note the lack of
1541  // return!), because the type casting logic in scalar cases is
1542  // non-trivial. It's better to just run them.
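  // For example, for `aten::lt` on inputs of sizes [3, 1] and [4],
  // broadcastBinary inserts explicit expand nodes so that both operands are
  // brought to [3, 4], and the node is then run on stand-in inputs, as
  // described above, to recover the output's scalar type and shape.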
1543  broadcastBinary(node, tensor_types, 0, 1);
1544  return PropagateShapeOnNodeByRunningIt(node);
1545  } else if (
1546  node->matches("aten::neg(Tensor self) -> Tensor") ||
1547  node->matches("aten::sigmoid(Tensor self) -> Tensor") ||
1548  node->matches("aten::tanh(Tensor self) -> Tensor")) {
1549  node->output()->setType(tensor_types.at(0)->contiguous());
1550  return true;
1551  } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
1552  auto lhs_type = tensor_types.at(0);
1553  auto rhs_type = tensor_types.at(1);
1554  SHAPE_ASSERT(
1555  lhs_type->sizes().size() == 2 && rhs_type->sizes().size() == 2);
1556  node->output()->setType(CompleteTensorType::create(
1557  lhs_type->scalarType(),
1558  lhs_type->device(),
1559  at::IntArrayRef{lhs_type->sizes().at(0), rhs_type->sizes().at(1)}));
1560  return true;
1561  } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
1562  auto tp = tensor_types.at(0);
1563  auto sizes = tp->sizes();
1564  auto strides = tp->strides();
1565  SHAPE_ASSERT(sizes.size() == 2);
1566  std::swap(sizes.at(0), sizes.at(1));
1567  std::swap(strides.at(0), strides.at(1));
1568  node->output()->setType(tp->withSizesStrides(sizes, strides));
1569  return true;
1570  } else if (
1571  node->matches(
1572  "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
1573  /*const_inputs=*/{attr::dim, attr::length})) {
1574  auto tp = tensor_types.at(0);
1575  auto sizes = tp->sizes();
1576  int64_t dim = node->get<int64_t>(attr::dim).value();
1577  int64_t length = node->get<int64_t>(attr::length).value();
1578  SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
1579  sizes.at(dim) = length;
1580  node->output()->setType(tp->withSizesStrides(sizes, tp->strides()));
1581  return true;
1582  } else if (node->matches("aten::sum(Tensor self) -> Tensor")) {
1583  node->output()->setType(tensor_types.at(0)->withSizes({}));
1584  return true;
1585  } else if (node->matches(
1586  "aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
1587  /*const_inputs=*/{attr::dim, attr::keepdim})) {
1588  auto& tp = tensor_types.at(0);
1589  auto sizes = tp->sizes();
1590  auto dims = node->get<std::vector<int64_t>>(attr::dim).value();
1591  bool keepdim = node->get<bool>(attr::keepdim).value();
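  // Erase the reduced dims from the back so that removing one entry does not
  // shift the indices of entries still to be processed (this relies on `dim`
  // arriving in ascending order).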
1592  std::reverse(dims.begin(), dims.end());
1593  for (int64_t dim : dims) {
1594  SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
1595  if (keepdim) {
1596  sizes.at(dim) = 1;
1597  } else {
1598  sizes.erase(sizes.begin() + dim);
1599  }
1600  }
1601  node->output()->setType(tp->withSizes(sizes));
1602  return true;
1603  } else if (node->matches(
1604  "aten::squeeze(Tensor self, int dim) -> Tensor",
1605  /*const_inputs=*/attr::dim)) {
1606  auto& tp = tensor_types.at(0);
1607  auto sizes = tp->sizes();
1608  auto strides = tp->strides();
1609  int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
1610  SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
1611  if (sizes.at(dim) == 1) {
1612  sizes.erase(sizes.begin() + dim);
1613  strides.erase(strides.begin() + dim);
1614  }
1615  node->output()->setType(tp->withSizesStrides(sizes, strides));
1616  return true;
1617  } else if (node->matches(
1618  "aten::unsqueeze(Tensor self, int dim) -> Tensor",
1619  /*const_inputs=*/attr::dim)) {
1620  auto& tp = tensor_types.at(0);
1621  auto sizes = tp->sizes();
1622  auto strides = tp->strides();
1623  int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
1624  SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.size());
1625  int64_t new_stride = dim >= static_cast<int64_t>(sizes.size())
1626  ? 1
1627  : sizes.at(dim) * strides.at(dim);
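  // E.g. for a contiguous input with sizes [2, 3] and strides [3, 1],
  // unsqueezing at dim 1 gives new_stride = 3 * 1 = 3, so the output is
  // annotated with sizes [2, 1, 3] and strides [3, 3, 1].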
1628  sizes.insert(sizes.begin() + dim, 1);
1629  strides.insert(strides.begin() + dim, new_stride);
1630  node->output()->setType(tp->withSizesStrides(sizes, strides));
1631  return true;
1632  } else if (node->matches(
1633  "aten::view(Tensor self, int[] size) -> Tensor",
1634  /*const_inputs=*/attr::size)) {
1635  auto sizes = node->get<std::vector<int64_t>>(attr::size).value();
1636  bool inferred = false;
1637  size_t inferred_idx;
1638  int64_t size_product = 1;
1639  for (size_t i = 0; i < sizes.size(); ++i) {
1640  if (sizes[i] == -1) {
1641  if (inferred)
1642  throw propagation_error();
1643  inferred = true;
1644  inferred_idx = i;
1645  } else {
1646  size_product *= sizes[i];
1647  }
1648  }
1649 
1650  if (inferred) {
1651  SHAPE_ASSERT(size_product != 0);
1652  size_t numel = 1;
1653  for (int64_t s : tensor_types.at(0)->sizes())
1654  numel *= s;
1655  int64_t inferred_size = numel / size_product;
1656  sizes[inferred_idx] = inferred_size;
1657  }
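  // E.g. viewing a tensor of sizes [2, 3, 4] (numel 24) as [-1, 6]:
  // size_product is 6, so the inferred entry becomes 24 / 6 = 4 and the
  // output is annotated with sizes [4, 6].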
1658  node->output()->setType(tensor_types.at(0)->withSizes(sizes));
1659  return true;
1660  } else if (node->matches(
1661  "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
1662  if (tensor_types.at(0)->scalarType() ==
1663  tensor_types.at(1)->scalarType()) {
1664  node->output()->setType(node->namedInput(attr::self)->type());
1665  } else {
1666  // This will be a copy, so the result will be contiguous
1667  node->output()->setType(
1668  tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes()));
1669  }
1670  return true;
1671  } else if (
1672  node->matches(
1673  "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
1674  /*const_inputs=*/attr::size)) {
1675  auto tp = tensor_types.at(0);
1676  std::vector<int64_t> sizes, strides;
1677  std::tie(sizes, strides) = at::inferExpandGeometry(
1678  tp->sizes(),
1679  tp->strides(),
1680  node->get<std::vector<int64_t>>(attr::size).value());
1681  node->output()->setType(tp->withSizesStrides(sizes, strides));
1682  return true;
1683  } else if (
1684  node->matches(
1685  "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
1686  /*const_inputs=*/attr::dim)) {
1687  auto ten = tensor_types.at(0);
1688  auto index = tensor_types.at(1);
1689  int64_t dim = node->get<int64_t>(attr::dim).value();
1690  SHAPE_ASSERT(index->sizes().size() == 1);
1691  SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < ten->sizes().size());
1692  std::vector<int64_t> sizes = ten->sizes();
1693  sizes[dim] = index->sizes()[0];
1694  node->output()->setType(ten->withSizes(sizes));
1695  return true;
1696  } else if (node->matches(
1697  "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
1698  /*const_inputs=*/{attr::chunks, attr::dim})) {
1699  auto input_type = tensor_types.at(0);
1700  auto sizes = input_type->sizes();
1701  const auto& strides = input_type->strides();
1702  int64_t dim = node->get<int64_t>(attr::dim).value();
1703  int64_t chunks = node->get<int64_t>(attr::chunks).value();
1704  sizes[dim] /= chunks;
1705  for (Value* output : node->outputs()) {
1706  output->setType(input_type->withSizesStrides(sizes, strides));
1707  }
1708  if (input_type->sizes().at(dim) % chunks != 0) {
1709  sizes[dim] = input_type->sizes().at(dim) % chunks;
1710  node->outputs().back()->setType(
1711  input_type->withSizesStrides(sizes, strides));
1712  }
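  // E.g. chunking a dim of size 12 into 3 chunks annotates every output with
  // size 4 along `dim`; when the size is not evenly divisible, the last
  // output is re-annotated with the remainder instead.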
1713  return true;
1714  } else if (node->kind() == ::c10::onnx::Shape) {
1715  SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1);
1716  std::vector<int64_t> dim_vec = {
1717  (int64_t)tensor_types.at(0)->sizes().size()};
1718  at::IntArrayRef dims(dim_vec);
1719  node->output()->setType(
1720  CompleteTensorType::create(at::kLong, at::kCPU, dims));
1721  return true;
1722  } else if (node->kind() == ::c10::onnx::Reshape) {
1723  setUnshapedType(node);
1724  return true;
1725  }
1726  setUnshapedType(node);
1727  return false;
1728  }
1729 };
1730 } // anonymous namespace
1731 
1732 void PropagateInputShapes(const std::shared_ptr<Graph>& graph) {
1733  ShapePropagator(graph).PropagateShapeOnBlock(graph->block());
1734 }
1735 
1736 namespace {
1737 
1738 void EraseShapeInformation(at::ArrayRef<Value*> vals) {
1739  for (Value* v : vals) {
1740  v->setType(unshapedType(v->type()));
1741  }
1742 }
1743 
1744 void EraseShapeInformation(Block* b) {
1745  EraseShapeInformation(b->inputs());
1746  EraseShapeInformation(b->outputs());
1747  for (Node* n : b->nodes()) {
1748  EraseShapeInformation(n->outputs());
1749  for (Block* sb : n->blocks()) {
1750  EraseShapeInformation(sb);
1751  }
1752  if (n->hasAttribute(attr::Subgraph)) {
1753  EraseShapeInformation(n->g(attr::Subgraph));
1754  }
1755  }
1756 }
1757 } // anonymous namespace
1758 
1759 void EraseShapeInformation(const std::shared_ptr<Graph>& graph) {
1760  EraseShapeInformation(graph->block());
1761 }
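// Minimal usage sketch (assuming a std::shared_ptr<Graph> whose inputs have
// already been annotated with concrete tensor types elsewhere):
//
//   PropagateInputShapes(graph);   // annotate values with inferred shapes
//   ...run optimizations that consume the shape information...
//   EraseShapeInformation(graph);  // strip the annotations again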
1762 } // namespace jit
1763 } // namespace torch