1 #include <torch/csrc/jit/passes/shape_analysis.h> 3 #include <c10/util/Exception.h> 4 #include <torch/csrc/jit/argument_spec.h> 5 #include <torch/csrc/jit/constants.h> 6 #include <torch/csrc/jit/ir.h> 7 #include <torch/csrc/jit/operator.h> 8 #include <torch/csrc/jit/passes/alias_analysis.h> 10 #include <torch/csrc/autograd/variable.h> 12 #include <ATen/DeviceGuard.h> 13 #include <ATen/ExpandUtils.h> 25 using namespace ::c10::prim;
// SHAPE_ASSERT: only the `throw propagation_error()` part of the macro is
// visible in this excerpt; the condition test around it is not shown.
//
// isValidArgumentForRunning: predicate on a graph Value. It is handed to
// std::all_of over a node's inputs by canPropagateShapeByRunningIt.
30 #define SHAPE_ASSERT(cond) \ 32 throw propagation_error() 36 bool isValidArgumentForRunning(
Value* v) {
// Accepts the value when the scalar type of `tt` is not an integral type.
// NOTE(review): `tt` is declared on a line not shown here - presumably the
// value's tensor type; confirm against the full file.
41 return !at::isIntegralType(tt->scalarType());
// A later return accepts the value only if its type is a subtype of FloatType.
43 return v->type()->isSubtypeOf(FloatType::get());
// isValidReturnForRunning: true when the value's type is a subtype of
// TensorType or of NumberType. canPropagateShapeByRunningIt applies it to
// every output of a node via std::all_of.
46 bool isValidReturnForRunning(
Value* v) {
47 return v->type()->isSubtypeOf(TensorType::get()) ||
48 v->type()->isSubtypeOf(NumberType::get());
// containsTensorType: inspects the types contained in `t`
// (t->containedTypes()). DoesntRefineOutputs calls it on each output type.
51 bool containsTensorType(
const TypePtr& t) {
52 auto n_contained = t->containedTypes().size();
// Exactly one contained type: the answer is whether that type is a subtype
// of TensorType.
53 if (n_contained == 1) {
54 return t->containedTypes().at(0)->isSubtypeOf(TensorType::get());
55 }
// More than one contained type: the whole contained-type range is handed to
// an algorithm whose name and predicate are on lines not shown here.
else if (n_contained > 1) {
57 t->containedTypes().begin(),
58 t->containedTypes().end(),
64 class ShapePropagator {
// Constructor: builds aliasDb_ from the graph, then walks the graph's
// top-level block with collectResizeSet.
66 explicit ShapePropagator(std::shared_ptr<Graph> graph) : aliasDb_(graph) {
67 collectResizeSet(std::move(graph)->block());
// PropagateShapeOnBlock: runs PropagateShapeOnNode on every node of `block`,
// forwarding `insert_expands` (defaults to true).
70 void PropagateShapeOnBlock(
Block* block,
bool insert_expands =
true) {
71 for (
Node* node : block->nodes()) {
73 PropagateShapeOnNode(node, insert_expands);
// NOTE(review): the condition under which the node is reset to unshaped types
// is on lines not shown here - presumably a handler for propagation_error.
75 setUnshapedType(node);
76 }
// Any other std::exception is rethrown with the node's source location
// attached, when the node has one.
catch (std::exception& e) {
77 if (
auto sl = node->getSourceLocation()) {
78 sl->wrapAndRethrowException(e,
"operation failed shape propagation");
// Values collected by collectResizeSet: inputs of resizing nodes that the
// alias database reports as written by that node.
87 ValueSet resized_alias_set;
// resizesInput: classifies node `n` for collectResizeSet. The results of the
// first two checks are on lines not shown in this excerpt.
90 bool resizesInput(
Node* n) {
// Static set of symbols; its contents are not visible here.
91 static std::unordered_set<Symbol> resize_ops{
// Check 1: the node's kind is a member of resize_ops.
96 if (resize_ops.count(n->kind()))
// Check 2: the node has no schema.
99 if (!n->maybeSchema())
// Check 3: the schema has an argument named "out"; the result is whether that
// argument is keyword-only and tensor-typed.
103 if (
auto out_arg_index = n->schema().argumentIndexWithName(
"out")) {
104 auto arg = n->schema().arguments().at(*out_arg_index);
105 return arg.kwarg_only() && arg.type()->isSubtypeOf(TensorType::get());
// collectResizeSet: scans the nodes of `block` and fills resized_alias_set.
110 void collectResizeSet(
Block* block) {
111 for (
Node* n : block->nodes()) {
// Nested blocks of the node are visited; the loop body is not shown here
// (presumably a recursive call - confirm against the full file).
112 for (
Block* b : n->blocks()) {
// For a node that resizesInput() flags, record every input that the alias
// database says the node writes to.
115 if (resizesInput(n)) {
116 for (
const auto input : n->inputs()) {
117 if (aliasDb_.writesToAlias(n, {input},
false)) {
118 resized_alias_set.insert(input);
// setUnshapedType(Value*): replaces the value's type with
// unshapedType(<its current type>).
125 void setUnshapedType(
Value* o) {
126 o->setType(unshapedType(o->type()));
// setUnshapedType(Node*): iterates over all outputs of `node`. The loop body
// is not shown here (presumably it applies the Value* overload to each output).
129 void setUnshapedType(
Node* node) {
130 for (
auto o : node->outputs()) {
// jitDeviceIndexToDevice: converts an integer device index to an at::Device.
// -1 yields at::kCPU; any other value yields a CUDA device with that index.
144 static at::Device jitDeviceIndexToDevice(
int device) {
145 return device == -1 ? at::kCPU :
at::Device(at::kCUDA, device);
149 TypePtr type_ = v->type();
151 if (
auto iv = toIValue(v)) {
156 type->device().is_cpu() ? at::Backend::CPU : at::Backend::CUDA;
158 auto& attype = at::getNonVariableType(backend, type->scalarType());
160 at::empty_strided(type->sizes(), type->strides(), attype.options())
162 return autograd::make_variable(t,
false);
163 }
else if (type_->isSubtypeOf(FloatType::get())) {
168 std::stringstream ss;
169 ss <<
"unable to create representative value for: " << type_->str()
170 <<
". File a bug report.";
171 throw std::runtime_error(ss.str());
// NOTE(review): the signature line of this template is missing from the
// excerpt. Call sites of the form gatherTensorTypes<T>(node) elsewhere in the
// file suggest this is gatherTensorTypes - confirm against the full file.
// The visible body collects, for each schema argument, the input's type cast
// to T.
177 template <
typename T>
179 std::vector<std::shared_ptr<T>> tensor_types;
181 auto& schema = node->schema();
182 auto& args = schema.arguments();
// Vararg schemas are tested for; the consequence is on a line not shown.
185 if (schema.is_vararg()) {
188 for (
size_t i = 0; i < args.size(); ++i) {
// Tensor-list arguments take a separate branch whose body is not shown.
189 if (args[i].type()->isSubtypeOf(ListType::ofTensors())) {
191 }
// Tensor arguments: keep the input's type if it casts to T.
else if (args[i].type()->isSubtypeOf(TensorType::get())) {
192 if (
auto type = node->input(i)->type()->cast<
T>()) {
193 tensor_types.push_back(type);
210 bool changed =
false;
211 for (
size_t i = 0; i < lhs.
size(); ++i) {
212 auto old_output_type = outputs[i]->type();
213 auto new_type = unifyTypes(lhs[i]->type(), rhs[i]->type());
215 outputs[i]->setType(*new_type);
216 if (*old_output_type != *outputs[i]->type())
// broadcastBinary: for two inputs of `node` (indices idx1 and idx2 into
// `types`), computes the broadcast size with at::infer_size and rewrites any
// input whose sizes differ from it. Several parameters and statements are on
// lines not shown in this excerpt.
222 void broadcastBinary(
224 std::vector<CompleteTensorTypePtr>& types,
228 at::infer_size(types[idx1]->sizes(), types[idx2]->sizes());
// Per-input helper: compares the input's sizes with expected_size.
229 auto broadcast = [&](
size_t input_idx) {
230 CompleteTensorTypePtr input_type = types.at(input_idx);
231 if (input_type->sizes() == expected_size)
// Otherwise a new node (`expand`) is built from the original input plus two
// inserted constants (expected_size and false), placed before `node`, shape-
// propagated, and substituted for the original input.
233 auto graph = node->owningGraph();
238 {node->inputs().at(input_idx),
239 graph->insertConstant(expected_size),
240 graph->insertConstant(
false)})
241 ->insertBefore(node);
242 PropagateShapeOnNode(expand);
243 node->replaceInput(input_idx, expand->output());
// Operator schemas consulted by canPropagateShapeByRunningIt via find(node).
// Only the first two entries are visible in this excerpt.
251 OperatorSet cannot_propagate_shape_by_running_it = {
252 "aten::solve(Tensor self, Tensor A) -> (Tensor, Tensor)",
253 "aten::inverse(Tensor self) -> Tensor",
// Per-node cache of dependsOnMutation results.
259 std::unordered_map<Node*, bool> dependsOnMutationMemo_;
// dependsOnMutation: memoized query used by canPropagateShapeByRunningIt.
260 bool dependsOnMutation(
Node* node) {
// Cached answer, if any.
261 if (dependsOnMutationMemo_.count(node) != 0) {
262 return dependsOnMutationMemo_[node];
// A node for which the alias database reports writers is recorded as true.
265 if (aliasDb_.hasWriters(node)) {
268 dependsOnMutationMemo_[node] =
true;
// Otherwise the result is the OR of the same query on the producer node of
// every input; it is cached before use.
279 auto depends =
false;
280 for (
auto input : node->inputs()) {
281 depends |= dependsOnMutation(input->node());
284 dependsOnMutationMemo_[node] = depends;
// canPropagateShapeByRunningIt: gate for PropagateShapeOnNodeByRunningIt.
// The return statements are on lines not shown; the visible checks are listed
// below.
288 bool canPropagateShapeByRunningIt(
Node* node) {
// Check: the node matches an entry of cannot_propagate_shape_by_running_it.
289 if (cannot_propagate_shape_by_running_it.find(node)) {
// Check: the node depends on a mutation.
293 if (dependsOnMutation(node)) {
// Every input must satisfy isValidArgumentForRunning ...
297 bool valid_args = std::all_of(
298 node->inputs().begin(),
299 node->inputs().end(),
300 isValidArgumentForRunning);
// ... and every output must satisfy isValidReturnForRunning.
304 bool valid_returns = std::all_of(
305 node->outputs().begin(),
306 node->outputs().end(),
307 isValidReturnForRunning);
// DoesntRefineOutputs: examines each output of `node` with
// containsTensorType. The values returned are on lines not shown here.
316 bool DoesntRefineOutputs(
Node* node) {
317 auto outputs = node->outputs();
318 for (
auto& out : outputs) {
319 if (containsTensorType(out->type())) {
// PropagateShapeOnNodeByRunningIt: infers output types by obtaining the
// node's operation and feeding it representative input values. The invocation
// of `op` and the return statements are on lines not shown in this excerpt.
326 bool PropagateShapeOnNodeByRunningIt(
Node* node) {
// Guarded by canPropagateShapeByRunningIt.
327 if (!canPropagateShapeByRunningIt(node))
329 auto op = getOperation(node);
// One representative value per input is pushed onto the stack.
332 for (
auto input : node->inputs()) {
333 stack.push_back(representativeValue(input));
// The stack must hold exactly one entry per node output at this point.
342 AT_ASSERT(stack.size() == node->outputs().size());
343 for (
size_t i = 0; i < stack.size(); ++i) {
// Tensor results determine the type of the corresponding output.
347 if (stack[i].isTensor())
348 node->outputs()[i]->inferTypeFrom(stack[i].toTensor());
// PropagateCatShape: shape propagation for a concatenation node. It is called
// by PropagateShapeOnNode for aten::cat and prim::FusedConcat. Two local
// helpers are tried in order: propagate_complete, then propagate. Parts of
// both helpers are on lines not shown in this excerpt.
353 void PropagateCatShape(
Node* cat_node) {
// Helper 1: uses full size information of the inputs.
354 static const auto propagate_complete =
356 auto input_types = fmap(tensors, [](
Value* v) {
// Predicate testing that a mapped input type is non-null.
362 [](
const CompleteTensorTypePtr& tp) {
return tp !=
nullptr; })) {
// The concatenation dimension must be a constant.
365 if (!node->is_constant(attr::dim))
367 std::vector<int64_t> sizes = input_types[0]->sizes();
368 const int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
369 const int64_t ndim = sizes.size();
// Range check on the wrapped dimension.
371 if (dim < 0 || dim >= ndim)
375 for (
auto& tp : input_types) {
376 auto& tp_sizes = tp->sizes();
// Each input is compared against `sizes` in rank ...
377 if (sizes.size() != tp_sizes.size())
// ... and in every dimension other than `dim`.
379 for (int64_t i = 0; i < ndim; ++i) {
380 if (sizes[i] != tp_sizes[i] && i != dim) {
// The input's extent along `dim` is added to the running total.
384 sizes[dim] += tp_sizes[dim];
// The output takes the first input's type with the computed sizes.
386 node->output()->setType(input_types[0]->withSizes(sizes));
// Helper 2: iterates over the tensors and sets the output to `type`, which is
// computed on lines not shown here.
389 static const auto propagate = [](
Node* node,
391 for (
Value* v : tensors) {
393 node->output()->setType(type);
// Selects the node providing the tensor list; for non-FusedConcat nodes it is
// the producer of the `tensors` input. The FusedConcat alternative is on a
// line not shown.
400 ((cat_node->kind() == prim::FusedConcat)
402 : cat_node->namedInput(attr::tensors)->node());
403 if (list_node->kind() == prim::ListConstruct ||
404 cat_node->kind() == prim::FusedConcat) {
405 auto tensors = list_node->inputs();
406 if (!tensors.empty()) {
407 if (propagate_complete(cat_node, tensors)) {
409 }
else if (propagate(cat_node, tensors)) {
// The node's outputs are reset to unshaped types; the enclosing control flow
// is not fully visible (presumably the fallback path).
414 setUnshapedType(cat_node);
418 bool in_resize =
false;
420 if (aliasDb_.mayAlias(ValueSet{v}, resized_alias_set)) {
// PropagateShapeOnNode: assigns output types for one node. It first handles
// specific node kinds in a switch, then tries a sequence of generic
// strategies, and finally resets the outputs to unshaped types. Several case
// labels and statements are on lines not shown in this excerpt.
428 void PropagateShapeOnNode(
Node* node,
bool insert_expands =
true) {
// Inputs that may alias the resized set: give up and use unshaped types.
432 if (mayAliasResizedSet(node->inputs())) {
433 return setUnshapedType(node);
438 switch (node->kind()) {
// Two-block node (case label not shown; presumably prim::If): both blocks are
// propagated, then their outputs and the node's outputs are passed to a
// helper whose name is on a line not shown.
440 auto then_block = node->blocks().at(0);
441 auto else_block = node->blocks().at(1);
442 PropagateShapeOnBlock(then_block);
443 PropagateShapeOnBlock(else_block);
445 then_block->outputs(), else_block->outputs(), node->outputs());
// Single-body node (case label not shown; presumably prim::Loop).
449 auto body_block = node->blocks().at(0);
// Body input 0 takes the type of node input 0.
451 body_block->inputs().at(0)->setType(node->inputs().at(0)->type());
// Node inputs from index 2 seed the types of body inputs from index 1.
453 auto loop_carried_inputs = node->inputs().slice(2);
454 auto loop_carried_block = body_block->inputs().slice(1);
455 for (
size_t i = 0; i < loop_carried_inputs.size(); ++i) {
456 loop_carried_block[i]->setType(loop_carried_inputs[i]->type());
458 auto loop_carried_outputs = body_block->outputs().slice(1);
// The body is propagated without inserting expands, the carried inputs and
// outputs are passed to a helper (name not shown), and the body is then
// propagated again with expands enabled.
461 PropagateShapeOnBlock(body_block,
false);
465 loop_carried_block, loop_carried_outputs, loop_carried_block));
468 PropagateShapeOnBlock(body_block,
true);
// Node outputs take the types of the carried body inputs.
470 for (
size_t i = 0; i < loop_carried_inputs.size(); ++i) {
471 node->outputs()[i]->setType(loop_carried_block[i]->type());
// The handling of this case is on lines not shown.
475 case prim::ImplicitTensorToNum:
// Number -> tensor: int/bool inputs give a 0-dim kLong CPU tensor type, float
// inputs a 0-dim kDouble CPU tensor type.
480 case prim::NumToTensor: {
481 TypePtr typ = node->input()->type();
482 if (typ->isSubtypeOf(IntType::get()) ||
483 typ->isSubtypeOf(BoolType::get())) {
484 node->output()->setType(
485 DimensionedTensorType::create(at::kLong, at::kCPU, 0));
486 }
else if (node->input()->type()->isSubtypeOf(FloatType::get())) {
487 node->output()->setType(
488 DimensionedTensorType::create(at::kDouble, at::kCPU, 0));
// The output tuple type is rebuilt from the current input types.
492 case prim::TupleConstruct: {
495 node->output()->setType(TupleType::create(
496 fmap(node->inputs(), [](
Value* v) {
return v->type(); })));
// Each output takes the type of the matching tuple element.
499 case prim::TupleUnpack: {
500 auto tuple_type = node->input()->type()->cast<
TupleType>();
503 tuple_type->elements().size() == node->outputs().size());
504 auto elems = tuple_type->elements();
505 for (
size_t i = 0; i < node->outputs().size(); ++i) {
506 node->output(i)->setType(elems[i]);
// Tensor constants: the output type is inferred from the stored tensor.
510 case prim::Constant: {
511 if (node->output()->type()->isSubtypeOf(TensorType::get())) {
512 node->output()->inferTypeFrom(node->t(attr::value));
// All outputs receive `type` (computed on a line not shown); a reset to
// unshaped types is also visible, under a condition that is not shown.
516 case prim::ConstantChunk: {
517 Value* tensor = node->input();
519 for (
Value* output : node->outputs()) {
520 output->setType(type);
523 setUnshapedType(node);
527 case prim::AutogradZero: {
528 setUnshapedType(node);
// Checks whether the input is a constant None; the consequence is not shown.
531 case aten::_unwrap_optional: {
532 auto input_ivalue = toIValue(node->input());
533 if (input_ivalue && input_ivalue->isNone()) {
// Side-effecting nodes are tested for; the consequence is on a line not shown.
541 if (node->hasSideEffects()) {
// Concatenation has a dedicated routine.
545 if (node->matches(
"aten::cat(Tensor[] tensors, int dim) -> Tensor") ||
546 node->kind() == prim::FusedConcat) {
547 return PropagateCatShape(node);
// Generic strategies, in order: complete-shape propagation (when complete
// tensor types can be gathered), dimension-only propagation, the
// DoesntRefineOutputs check, and running the node on representative values.
550 if (
auto maybe_complete_types =
551 gatherTensorTypes<CompleteTensorType>(node)) {
552 if (PropagateCompleteShapeOnNode(
553 node, insert_expands, std::move(*maybe_complete_types))) {
558 if (PropagateTensorShapeOnNode(node, insert_expands)) {
562 if (DoesntRefineOutputs(node)) {
566 if (PropagateShapeOnNodeByRunningIt(node)) {
// Final statement: fall back to unshaped output types.
569 return setUnshapedType(node);
573 AT_ASSERT(list->type()->cast<
ListType>());
574 if (
auto shape = constant_as<std::vector<int64_t>>(list)) {
575 return shape->size();
577 auto input_node = list->node();
578 if (input_node->kind() == prim::ListConstruct) {
579 return input_node->inputs().size();
597 bool PropagateTensorShapeOnNode(
Node* node,
bool insert_expands) {
598 static const auto broadcast =
599 [](std::vector<DimensionedTensorTypePtr>& tensor_types,
600 size_t arg_for_type) -> DimensionedTensorTypePtr {
601 if (tensor_types.size() == 1) {
602 return tensor_types[0];
604 AT_ASSERT(!tensor_types.empty());
605 auto any_type = tensor_types[arg_for_type];
606 auto max_dims = any_type->dim();
607 for (
auto& type : tensor_types) {
608 max_dims = std::max(max_dims, type->dim());
610 return DimensionedTensorType::create(
611 any_type->scalarType(), any_type->device(), max_dims);
614 using type_vec_t = std::vector<DimensionedTensorTypePtr>;
618 using formula_t = std::function<type_vec_t(Node*)>;
619 static std::mutex shape_formulas_mutex;
620 static std::vector<std::pair<OperatorSet, formula_t>> shape_formulas;
621 struct register_formula_for {
622 register_formula_for(
OperatorSet operators, formula_t formula) {
623 std::unique_lock<std::mutex> lock{shape_formulas_mutex};
624 shape_formulas.emplace_back(std::move(operators), std::move(formula));
636 static const register_formula_for simple_unary_ops{
638 "aten::abs(Tensor self) -> Tensor",
639 "aten::acos(Tensor self) -> Tensor",
640 "aten::neg(Tensor self) -> Tensor",
641 "aten::t(Tensor self) -> Tensor",
642 "aten::sigmoid(Tensor self) -> Tensor",
643 "aten::tanh(Tensor self) -> Tensor",
644 "aten::relu(Tensor self) -> Tensor",
645 "aten::asin(Tensor self) -> Tensor",
646 "aten::atan(Tensor self) -> Tensor",
647 "aten::ceil(Tensor self) -> Tensor",
648 "aten::clone(Tensor self) -> Tensor",
649 "aten::contiguous(Tensor self) -> Tensor",
650 "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
651 "aten::celu(Tensor self, Scalar alpha) -> Tensor",
652 "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
653 "aten::clamp_max(Tensor self, Scalar max) -> Tensor",
654 "aten::clamp_min(Tensor self, Scalar min) -> Tensor",
655 "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
656 "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
657 "aten::cos(Tensor self) -> Tensor",
658 "aten::cosh(Tensor self) -> Tensor",
659 "aten::digamma(Tensor self) -> Tensor",
660 "aten::dropout(Tensor input, float p, bool train) -> Tensor",
661 "aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor",
662 "aten::erf(Tensor self) -> Tensor",
663 "aten::erfc(Tensor self) -> Tensor",
664 "aten::erfinv(Tensor self) -> Tensor",
665 "aten::exp(Tensor self) -> Tensor",
666 "aten::expm1(Tensor self) -> Tensor",
667 "aten::log(Tensor self) -> Tensor",
668 "aten::log10(Tensor self) -> Tensor",
669 "aten::log1p(Tensor self) -> Tensor",
670 "aten::log2(Tensor self) -> Tensor",
671 "aten::log_sigmoid(Tensor self) -> Tensor",
672 "aten::log_softmax(Tensor self, int dim) -> Tensor",
673 "aten::floor(Tensor self) -> Tensor",
674 "aten::frac(Tensor self) -> Tensor",
675 "aten::flip(Tensor self, int[] dims) -> Tensor",
676 "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
677 "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
678 "aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
679 "aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
680 "aten::glu(Tensor self, int dim) -> Tensor",
681 "aten::inverse(Tensor self) -> Tensor",
682 "aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
683 "aten::lgamma(Tensor self) -> Tensor",
684 "aten::mvlgamma(Tensor self, int p) -> Tensor",
685 "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
686 "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
687 "aten::permute(Tensor self, int[] dims) -> Tensor",
688 "aten::pin_memory(Tensor self) -> Tensor",
689 "aten::pinverse(Tensor self, float rcond) -> Tensor",
690 "aten::reciprocal(Tensor self) -> Tensor",
691 "aten::relu(Tensor self) -> Tensor",
692 "aten::round(Tensor self) -> Tensor",
693 "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
694 "aten::rsqrt(Tensor self) -> Tensor",
695 "aten::selu(Tensor self) -> Tensor",
696 "aten::sigmoid(Tensor self) -> Tensor",
697 "aten::sign(Tensor self) -> Tensor",
698 "aten::sin(Tensor self) -> Tensor",
699 "aten::sinh(Tensor self) -> Tensor",
700 "aten::softmax(Tensor self, int dim) -> Tensor",
701 "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor",
702 "aten::softshrink(Tensor self, Scalar lambd) -> Tensor",
703 "aten::sqrt(Tensor self) -> Tensor",
704 "aten::tan(Tensor self) -> Tensor",
705 "aten::tanh(Tensor self) -> Tensor",
706 "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
707 "aten::transpose(Tensor self, int dim0, int dim1) -> Tensor",
708 "aten::tril(Tensor self, int diagonal) -> Tensor",
709 "aten::triu(Tensor self, int diagonal) -> Tensor",
710 "aten::trunc(Tensor self) -> Tensor",
711 "aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
712 "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
713 "aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor",
714 "aten::alias(Tensor self) -> Tensor",
715 "aten::detach(Tensor self) -> Tensor",
716 "aten::cumprod(Tensor self, int dim) -> Tensor",
717 "aten::cumsum(Tensor self, int dim) -> Tensor",
719 "aten::empty_like(Tensor self) -> Tensor",
720 "aten::full_like(Tensor self, Scalar fill_value) -> Tensor",
721 "aten::ones_like(Tensor self) -> Tensor",
722 "aten::rand_like(Tensor self) -> Tensor",
723 "aten::randint_like(Tensor self, int high) -> Tensor",
724 "aten::randint_like(Tensor self, int low, int high) -> Tensor",
725 "aten::randn_like(Tensor self) -> Tensor",
726 "aten::zeros_like(Tensor self) -> Tensor",
728 [](
Node* node) -> type_vec_t {
731 return input_type ? type_vec_t{input_type} : type_vec_t{};
740 static const register_formula_for broadcasting_ops{
743 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
744 "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
745 "aten::mul(Tensor self, Tensor other) -> Tensor",
746 "aten::div(Tensor self, Tensor other) -> Tensor",
747 "aten::pow(Tensor self, Tensor exponent) -> Tensor",
748 "aten::fmod(Tensor self, Tensor other) -> Tensor",
749 "aten::remainder(Tensor self, Tensor other) -> Tensor",
750 "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
751 "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
752 "aten::max(Tensor self, Tensor other) -> Tensor",
753 "aten::min(Tensor self, Tensor other) -> Tensor",
754 "aten::__and__(Tensor self, Tensor other) -> Tensor",
755 "aten::__or__(Tensor self, Tensor other) -> Tensor",
756 "aten::__xor__(Tensor self, Tensor other) -> Tensor",
757 "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
758 "aten::__rshift__(Tensor self, Tensor other) -> Tensor",
759 "aten::__iand__(Tensor self, Tensor other) -> Tensor",
760 "aten::__ior__(Tensor self, Tensor other) -> Tensor",
761 "aten::__ixor__(Tensor self, Tensor other) -> Tensor",
762 "aten::__ilshift__(Tensor self, Tensor other) -> Tensor",
763 "aten::__irshift__(Tensor self, Tensor other) -> Tensor",
766 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
767 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
768 "aten::mul(Tensor self, Scalar other) -> Tensor",
769 "aten::div(Tensor self, Scalar other) -> Tensor",
770 "aten::pow(Tensor self, Scalar exponent) -> Tensor",
771 "aten::fmod(Tensor self, Scalar other) -> Tensor",
772 "aten::remainder(Tensor self, Scalar other) -> Tensor",
773 "aten::pow(Scalar self, Tensor exponent) -> Tensor",
774 "aten::__and__(Tensor self, Scalar other) -> Tensor",
775 "aten::__or__(Tensor self, Scalar other) -> Tensor",
776 "aten::__xor__(Tensor self, Scalar other) -> Tensor",
777 "aten::__lshift__(Tensor self, Scalar other) -> Tensor",
778 "aten::__rshift__(Tensor self, Scalar other) -> Tensor",
779 "aten::__iand__(Tensor self, Scalar other) -> Tensor",
780 "aten::__ior__(Tensor self, Scalar other) -> Tensor",
781 "aten::__ixor__(Tensor self, Scalar other) -> Tensor",
782 "aten::__ilshift__(Tensor self, Scalar other) -> Tensor",
783 "aten::__irshift__(Tensor self, Scalar other) -> Tensor",
786 "aten::atan2(Tensor self, Tensor other) -> Tensor",
789 "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
790 "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
792 [
this](
Node* node) -> type_vec_t {
793 if (
auto maybe_tensor_types =
794 gatherTensorTypes<DimensionedTensorType>(node)) {
795 return {broadcast(*maybe_tensor_types, 0)};
802 static const register_formula_for where_op{
804 "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
806 [
this](
Node* node) -> type_vec_t {
807 if (
auto maybe_tensor_types =
808 gatherTensorTypes<DimensionedTensorType>(node)) {
809 return {broadcast(*maybe_tensor_types, 1)};
814 static const auto any_tensor_type =
815 [](
Node* node) -> DimensionedTensorTypePtr {
816 for (
Value* input : node->inputs()) {
830 static const register_formula_for binary_ops_strict_match{
832 "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
833 "aten::mm(Tensor self, Tensor mat2) -> Tensor",
834 "aten::bmm(Tensor self, Tensor mat2) -> Tensor",
836 [](
Node* node) -> type_vec_t {
837 if (
auto type = any_tensor_type(node)) {
849 static const register_formula_for comparison_ops{
851 "aten::lt(Tensor self, Tensor other) -> Tensor",
852 "aten::le(Tensor self, Tensor other) -> Tensor",
853 "aten::gt(Tensor self, Tensor other) -> Tensor",
854 "aten::ge(Tensor self, Tensor other) -> Tensor",
855 "aten::eq(Tensor self, Tensor other) -> Tensor",
856 "aten::ne(Tensor self, Tensor other) -> Tensor",
857 "aten::lt(Tensor self, Scalar other) -> Tensor",
858 "aten::le(Tensor self, Scalar other) -> Tensor",
859 "aten::gt(Tensor self, Scalar other) -> Tensor",
860 "aten::ge(Tensor self, Scalar other) -> Tensor",
861 "aten::eq(Tensor self, Scalar other) -> Tensor",
862 "aten::ne(Tensor self, Scalar other) -> Tensor",
864 [
this](
Node* node) -> type_vec_t {
865 if (
auto maybe_tensor_types =
866 gatherTensorTypes<DimensionedTensorType>(node)) {
867 return {broadcast(*maybe_tensor_types, 0)->toScalarType(at::kByte)};
881 static const register_formula_for nn_ops_first_input_preserving{
883 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
884 "aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
885 "aten::conv2d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
886 "aten::conv3d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
887 "aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad) -> Tensor",
888 "aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
889 "aten::conv_transpose2d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
890 "aten::conv_transpose3d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
891 "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
892 "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor",
893 "aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor",
894 "aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor",
895 "aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor",
896 "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
897 "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
898 "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
899 "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
900 "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
901 "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
902 "aten::max_unpool2d(Tensor self, Tensor indices, int[] output_size) -> Tensor",
903 "aten::max_unpool3d(Tensor self, Tensor indices, int[] output_size, int[] stride, int[] padding) -> Tensor",
904 "aten::reflection_pad1d(Tensor self, int[] padding) -> Tensor",
905 "aten::reflection_pad2d(Tensor self, int[] padding) -> Tensor",
906 "aten::replication_pad1d(Tensor self, int[] padding) -> Tensor",
907 "aten::replication_pad2d(Tensor self, int[] padding) -> Tensor",
908 "aten::replication_pad3d(Tensor self, int[] padding) -> Tensor",
909 "aten::upsample_bilinear2d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
910 "aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
911 "aten::upsample_nearest1d(Tensor self, int[] output_size) -> Tensor",
912 "aten::upsample_nearest2d(Tensor self, int[] output_size) -> Tensor",
913 "aten::upsample_nearest3d(Tensor self, int[] output_size) -> Tensor",
914 "aten::upsample_trilinear3d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
915 "aten::prelu(Tensor self, Tensor weight) -> Tensor",
917 [](
Node* node) -> type_vec_t {
933 static const register_formula_for all_reduce_ops{
935 "aten::det(Tensor self) -> Tensor",
936 "aten::logdet(Tensor self) -> Tensor",
937 "aten::max(Tensor self) -> Tensor",
938 "aten::min(Tensor self) -> Tensor",
939 "aten::mean(Tensor self) -> Tensor",
940 "aten::median(Tensor self) -> Tensor",
941 "aten::norm(Tensor self, Scalar p) -> Tensor",
942 "aten::std(Tensor self, bool unbiased) -> Tensor",
943 "aten::sum(Tensor self) -> Tensor",
944 "aten::trace(Tensor self) -> Tensor",
945 "aten::var(Tensor self, bool unbiased) -> Tensor",
946 "aten::all(Tensor self) -> Tensor",
947 "aten::any(Tensor self) -> Tensor",
949 [](
Node* node) -> type_vec_t {
952 return {type->withDim(0)};
965 static const register_formula_for all_reduce_ops_with_integer_upcast{
967 "aten::sum(Tensor self) -> Tensor",
968 "aten::prod(Tensor self) -> Tensor",
970 [](
Node* node) -> type_vec_t {
973 return {at::isFloatingType(type->scalarType())
975 : type->withDim(0)->toScalarType(at::kLong)};
980 static const auto multidim_reduce_with_postprocess =
982 int64_t num_reduced_dim,
983 bool upcast_integer) -> type_vec_t {
984 auto maybe_keepdim = node->get<
bool>(attr::keepdim);
988 if (upcast_integer && !at::isFloatingType(type->scalarType())) {
989 type = type->toScalarType(at::kLong);
991 if (*maybe_keepdim) {
993 }
else if (type->dim() > num_reduced_dim) {
994 return {type->withDim(type->dim() - num_reduced_dim)};
1009 static const register_formula_for argminmax{
1011 "aten::argmax(Tensor self, int? dim, bool keepdim) -> Tensor",
1012 "aten::argmin(Tensor self, int? dim, bool keepdim) -> Tensor",
1014 [](
Node* node) -> type_vec_t {
1017 if (node->input(1)->type()->kind() == c10::TypeKind::NoneType) {
1018 return {type->withDim(0)};
1020 return multidim_reduce_with_postprocess(
1035 static const register_formula_for dim_reduce_ops{
1037 "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
1038 "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
1041 "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
1042 "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1043 "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1044 "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1045 "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1047 [](
Node* node) -> type_vec_t {
1051 auto output_types = multidim_reduce_with_postprocess(
1053 if (!output_types.empty() && node->outputs().size() == 2) {
1054 output_types.push_back(
1055 output_types.back()->toScalarType(at::kLong));
1057 return output_types;
1069 static const register_formula_for dim_reduce_ops_with_integer_upcast{
1071 "aten::prod(Tensor self, int dim, bool keepdim) -> Tensor",
1073 [](
Node* node) -> type_vec_t {
1074 return multidim_reduce_with_postprocess(
1085 static const register_formula_for multidim_reduce_ops{
1087 "aten::logsumexp(Tensor self, int[] dim, bool keepdim) -> Tensor",
1088 "aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
1089 "aten::norm(Tensor self, Scalar? p, int[] dim, bool keepdim) -> Tensor",
1090 "aten::std(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
1091 "aten::var(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
1092 "aten::max_values(Tensor self, int[] dim, bool keepdim) -> Tensor",
1093 "aten::min_values(Tensor self, int[] dim, bool keepdim) -> Tensor",
1095 [](
Node* node) -> type_vec_t {
1096 if (
auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
1097 return multidim_reduce_with_postprocess(
1098 node, dim->size(),
false);
1111 static const register_formula_for multidim_reduce_ops_with_integer_upcast{
1113 "aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
1115 [](
Node* node) -> type_vec_t {
1116 if (
auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
1118 return multidim_reduce_with_postprocess(
1119 node, dim->size(),
true);
1124 static const auto factory_with_ndim = [](
Node* node,
1125 int dim) -> type_vec_t {
1127 if (!maybe_layout_option)
1130 (maybe_layout_option->isNone() ? at::kStrided
1131 : maybe_layout_option->toLayout());
1134 if (!maybe_device_option)
1137 (maybe_device_option->isNone() ? at::kCPU
1138 : maybe_device_option->toDevice());
1141 if (!maybe_dtype_option)
1144 (maybe_dtype_option->isNone() ? at::kFloat
1145 : maybe_dtype_option->toScalarType());
1147 return {DimensionedTensorType::create(dtype, device, dim)};
1158 static const register_formula_for like_factories_with_options{
1160 "aten::empty_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
1161 "aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device) -> Tensor",
1162 "aten::ones_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
1163 "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
1164 "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
1165 "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
1166 "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
1167 "aten::zeros_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
1169 [](
Node* node) -> type_vec_t {
1170 if (
auto type = node->namedInput(attr::self)
1173 return factory_with_ndim(node, type->dim());
1187 static const register_formula_for size_factories_with_options{
1189 "aten::empty(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1190 "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device) -> Tensor",
1191 "aten::ones(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1192 "aten::rand(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1193 "aten::randn(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1194 "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1195 "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1196 "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
1198 [](
Node* node) -> type_vec_t {
1199 if (
auto maybe_size = node->get<std::vector<int64_t>>(attr::size)) {
1200 return factory_with_ndim(node, maybe_size->size());
1205 static const auto get_cast_scalar_type = [](
Node* node) -> at::ScalarType {
1206 switch (node->kind()) {
1207 case aten::_cast_Byte:
1209 case aten::_cast_Char:
1211 case aten::_cast_Double:
1213 case aten::_cast_Float:
1215 case aten::_cast_Half:
1217 case aten::_cast_Int:
1219 case aten::_cast_Long:
1221 case aten::_cast_Short:
1226 "unknown node kind in get_cast_scalar_type: ",
1227 node->kind().toQualString());
1230 static const register_formula_for cast_ops{
1232 "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
1233 "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
1234 "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
1235 "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
1236 "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
1237 "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
1238 "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
1239 "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
1241 [](
Node* node) -> type_vec_t {
1242 if (
auto type = node->namedInput(attr::self)
1245 return {type->toScalarType(get_cast_scalar_type(node))};
1252 for (
auto& entry : shape_formulas) {
1253 if (entry.first.find(node)) {
1254 auto types = entry.second(node);
1255 if (types.empty()) {
1258 auto outputs = node->outputs();
1259 AT_ASSERT(types.size() == outputs.
size());
1260 for (
size_t i = 0; i < types.size(); ++i) {
1261 AT_ASSERT(outputs[i]->type()->isSubtypeOf(TensorType::get()));
1262 outputs[i]->setType(types[i]);
1271 const auto input_type = [node](
size_t index) {
1275 "aten::masked_select(Tensor self, Tensor mask) -> Tensor")) {
1276 auto type = input_type(0);
1277 auto mask_type = input_type(1);
1278 if (type && mask_type) {
1279 if (type->dim() == 0 && mask_type->dim() == 0) {
1280 node->output()->setType(type->withDim(0));
1282 node->output()->setType(type->withDim(1));
1286 if (
auto type = input_type(0)) {
1287 node->output()->setType(type->withDim(1));
1290 }
else if (node->matches(
1291 "aten::dot(Tensor self, Tensor tensor) -> Tensor")) {
1292 if (
auto type = any_tensor_type(node)) {
1293 node->output()->setType(type->withDim(0));
1297 node->matches(
"aten::mv(Tensor self, Tensor vec) -> Tensor") ||
1299 "aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha) -> Tensor")) {
1300 if (
auto type = any_tensor_type(node)) {
1301 node->output()->setType(type->withDim(1));
1306 "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor") ||
1308 "aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor") ||
1310 "aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha) -> Tensor")) {
1311 if (
auto type = any_tensor_type(node)) {
1312 node->output()->setType(type->withDim(2));
1317 "aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor")) {
1318 if (
auto type = any_tensor_type(node)) {
1319 node->output()->setType(type->withDim(3));
1324 "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor")) {
1325 auto type = input_type(0);
1326 auto index_type = input_type(1);
1330 if (type && index_type) {
1331 if (type->dim() == 0) {
1332 node->output()->setType(type->withDim(index_type->dim()));
1334 node->output()->setType(type);
1340 "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor")) {
1341 auto type = input_type(0);
1342 auto index_type = input_type(1);
1346 if (type && index_type) {
1347 if (index_type->dim() == 0) {
1348 node->output()->setType(type->withDim(0));
1350 node->output()->setType(type);
1356 "aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) {
1357 auto weight_type = input_type(0);
1358 auto indices_type = input_type(1);
1359 if (weight_type && indices_type) {
1360 node->output()->setType(weight_type->withDim(indices_type->dim() + 1));
1365 "aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor")) {
1366 if (
auto type = input_type(0)) {
1367 node->output()->setType(type);
1370 if (
auto type = input_type(1)) {
1371 node->output()->setType(type);
1376 "aten::dist(Tensor self, Tensor other, Scalar p) -> Tensor")) {
1377 if (
auto type = any_tensor_type(node)) {
1378 node->output()->setType(type->withDim(0));
1385 std::vector<DimensionedTensorTypePtr> tensor_types;
1386 static const auto reshape_prop =
1389 const std::vector<DimensionedTensorTypePtr>& tensor_types)
1390 -> DimensionedTensorTypePtr {
1391 if (
auto list_size = determineListSize(node->namedInput(shape_input))) {
1392 return tensor_types.at(0)->withDim(*list_size);
1396 const auto getSingleOutputType = [&]() -> TypePtr {
1397 if (node->matches(
"aten::type_as(Tensor self, Tensor other) -> Tensor")) {
1398 return tensor_types.at(0)->toScalarType(
1399 tensor_types.at(1)->scalarType());
1401 node->matches(
"aten::view_as(Tensor self, Tensor other) -> Tensor") ||
1403 "aten::expand_as(Tensor self, Tensor other) -> Tensor") ||
1405 "aten::reshape_as(Tensor self, Tensor other) -> Tensor")) {
1406 return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
1408 node->matches(
"aten::view(Tensor self, int[] size) -> Tensor") ||
1410 "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
1412 "aten::as_strided(Tensor self, int[] size, int[] stride, int? storage_offset) -> Tensor")) {
1413 return reshape_prop(node, attr::size, tensor_types);
1414 }
else if (node->matches(
1415 "aten::reshape(Tensor self, int[] shape) -> Tensor")) {
1416 return reshape_prop(node, attr::shape, tensor_types);
1417 }
else if (node->matches(
1418 "aten::repeat(Tensor self, int[] repeats) -> Tensor")) {
1419 return reshape_prop(node, attr::repeats, tensor_types);
1420 }
else if (node->matches(
1421 "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
1422 auto& t = tensor_types.at(0);
1423 return t->withDim(t->dim() + 1);
1426 "aten::select(Tensor self, int dim, int index) -> Tensor") ||
1428 "aten::diagonal(Tensor self, int offset, int dim1, int dim2) -> Tensor")) {
1429 auto& t = tensor_types.at(0);
1430 return t->dim() > 0 ? t->withDim(t->dim() - 1) :
nullptr;
1431 }
else if (node->matches(
1432 "aten::matmul(Tensor self, Tensor other) -> Tensor")) {
1433 int dim1 = tensor_types.at(0)->dim();
1434 int dim2 = tensor_types.at(1)->dim();
1435 if (dim1 == 1 && dim2 == 1) {
1437 return tensor_types.at(0)->withDim(0);
1438 }
else if (dim1 == 2 && dim2 == 2) {
1440 return tensor_types.at(0);
1441 }
else if (dim1 == 1 && dim2 == 2) {
1443 return tensor_types.at(0);
1444 }
else if (dim1 == 2 && dim2 == 1) {
1446 return tensor_types.at(1);
1450 auto type = broadcast(tensor_types, 0);
1451 if (tensor_types.at(0)->dim() == 1 ||
1452 tensor_types.at(1)->dim() == 1) {
1453 type = type->withDim(type->dim() - 1);
1457 }
else if (node->matches(
"aten::nonzero(Tensor self) -> Tensor")) {
1458 return tensor_types.at(0)->toScalarType(at::kLong);
1459 }
else if (node->matches(
1460 "aten::take(Tensor self, Tensor index) -> Tensor")) {
1461 return tensor_types.at(1)->toScalarType(
1462 tensor_types.at(0)->scalarType());
1463 }
else if (node->matches(
1464 "aten::diagflat(Tensor self, int offset) -> Tensor")) {
1465 return tensor_types.at(0)->withDim(2);
1466 }
else if (node->matches(
1467 "aten::diag(Tensor self, int diagonal) -> Tensor")) {
1468 auto& t = tensor_types.at(0);
1469 if (t->dim() == 1) {
1470 return t->withDim(2);
1471 }
else if (t->dim() == 2) {
1472 return t->withDim(1);
1478 "aten::unfold(Tensor self, int dimension, int size, int step) -> Tensor")) {
1479 auto& t = tensor_types.at(0);
1480 return t->dim() == 0 ? t : t->withDim(t->dim() + 1);
1481 }
else if (node->matches(
1482 "aten::polygamma(int n, Tensor self) -> Tensor")) {
1483 return tensor_types.at(0);
1487 if (
auto maybe_tensor_types =
1488 gatherTensorTypes<DimensionedTensorType>(node)) {
1489 tensor_types = std::move(*maybe_tensor_types);
1493 if (node->outputs().size() == 1) {
1494 if (
auto type = getSingleOutputType()) {
1495 node->output()->setType(type);
// Assigns output types for `node` from complete (sizes and strides known)
// input tensor types, with one branch per operator schema.
//
//   insert_expands - not referenced in any line shown here.
//   tensor_types   - complete types of the node's tensor inputs, indexed
//                    positionally (presumably in input order; confirm
//                    against the caller).
//
// The only return statements shown forward the result of
// PropagateShapeOnNodeByRunningIt(node); the last statement shown falls back
// to setUnshapedType(node).
//
// NOTE(review): most lines carry a leading listing number, and gaps in that
// numbering mean some lines of this function are absent from this text. The
// comments below describe only what is visible.
1502 bool PropagateCompleteShapeOnNode(
1504 bool insert_expands,
1505 std::vector<CompleteTensorTypePtr> tensor_types) {
// Tensor-tensor add/sub/mul: the output type comes from running the op.
1510 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
1512 "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
1513 node->matches(
"aten::mul(Tensor self, Tensor other) -> Tensor")) {
1518 return PropagateShapeOnNodeByRunningIt(node);
// Tensor-scalar add/sub/mul/pow: the output takes the type of tensor input 0.
1521 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
1523 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
1524 node->matches(
"aten::mul(Tensor self, Scalar other) -> Tensor") ||
1525 node->matches(
"aten::pow(Tensor self, Scalar exponent) -> Tensor")) {
1526 node->output()->setType(tensor_types.at(0));
// Two-tensor pow/min/max and comparisons: broadcastBinary is applied to
// inputs 0 and 1, then the op is run to obtain the output type.
1530 (node->matches(
"aten::pow(Tensor self, Tensor exponent) -> Tensor") ||
1531 node->matches(
"aten::min(Tensor self, Tensor other) -> Tensor") ||
1532 node->matches(
"aten::max(Tensor self, Tensor other) -> Tensor") ||
1533 node->matches(
"aten::lt(Tensor self, Tensor other) -> Tensor") ||
1534 node->matches(
"aten::le(Tensor self, Tensor other) -> Tensor") ||
1535 node->matches(
"aten::gt(Tensor self, Tensor other) -> Tensor") ||
1536 node->matches(
"aten::ge(Tensor self, Tensor other) -> Tensor") ||
1537 node->matches(
"aten::eq(Tensor self, Tensor other) -> Tensor") ||
1538 node->matches(
"aten::ne(Tensor self, Tensor other) -> Tensor"))) {
1543 broadcastBinary(node, tensor_types, 0, 1);
1544 return PropagateShapeOnNodeByRunningIt(node);
// neg/sigmoid/tanh: the output is the contiguous() form of the input type.
1546 node->matches(
"aten::neg(Tensor self) -> Tensor") ||
1547 node->matches(
"aten::sigmoid(Tensor self) -> Tensor") ||
1548 node->matches(
"aten::tanh(Tensor self) -> Tensor")) {
1549 node->output()->setType(tensor_types.at(0)->contiguous());
1551 }
else if (node->matches(
"aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
// mm: both operands are tested for rank 2 (the call wrapping that condition
// is not shown). The output type is built with the left operand's scalar
// type; the remaining create() arguments are not shown.
1552 auto lhs_type = tensor_types.at(0);
1553 auto rhs_type = tensor_types.at(1);
1555 lhs_type->sizes().size() == 2 && rhs_type->sizes().size() == 2);
1556 node->output()->setType(CompleteTensorType::create(
1557 lhs_type->scalarType(),
1561 }
else if (node->matches(
"aten::t(Tensor self) -> Tensor")) {
// t: rank-2 input only; the two sizes and the two strides are swapped.
1562 auto tp = tensor_types.at(0);
1563 auto sizes = tp->sizes();
1564 auto strides = tp->strides();
1565 SHAPE_ASSERT(sizes.
size() == 2);
1566 std::swap(sizes.
at(0), sizes.
at(1));
1567 std::swap(strides.at(0), strides.at(1));
1568 node->output()->setType(tp->withSizesStrides(sizes, strides));
// narrow (matched with dim and length listed in the attribute set):
// sizes[dim] becomes `length`, strides are kept. `dim` is not wrapped, so a
// negative dim fails the SHAPE_ASSERT.
1572 "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
1573 {attr::dim, attr::length})) {
1574 auto tp = tensor_types.at(0);
1575 auto sizes = tp->sizes();
1576 int64_t dim = node->get<int64_t>(attr::dim).value();
1577 int64_t length = node->get<int64_t>(attr::length).value();
1578 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.
size());
1579 sizes.
at(dim) = length;
1580 node->output()->setType(tp->withSizesStrides(sizes, tp->strides()));
1582 }
else if (node->matches(
"aten::sum(Tensor self) -> Tensor")) {
// sum over everything: the output keeps the input type but with empty sizes.
1583 node->output()->setType(tensor_types.at(0)->withSizes({}));
1585 }
else if (node->matches(
1586 "aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
1587 {attr::dim, attr::keepdim})) {
// sum over listed dims: dims are visited in reverse order and each is
// range-checked before sizes is edited. The lines between the check and the
// erase (presumably the keepdim handling) are not shown.
// NOTE(review): erasing while walking the reversed list only indexes the
// intended dimensions if `dims` was in ascending order; confirm callers
// guarantee that. Negative dims are rejected rather than wrapped.
1588 auto& tp = tensor_types.at(0);
1589 auto sizes = tp->sizes();
1590 auto dims = node->get<std::vector<int64_t>>(attr::dim).value();
1591 bool keepdim = node->get<
bool>(attr::keepdim).value();
1592 std::reverse(dims.begin(), dims.end());
1593 for (int64_t dim : dims) {
1594 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.
size());
1598 sizes.erase(sizes.begin() + dim);
1601 node->output()->setType(tp->withSizes(sizes));
1603 }
else if (node->matches(
1604 "aten::squeeze(Tensor self, int dim) -> Tensor",
// squeeze(dim): dim goes through wrapDim; the dimension (size and stride) is
// removed only when its size is 1, otherwise the type is unchanged.
1606 auto& tp = tensor_types.at(0);
1607 auto sizes = tp->sizes();
1608 auto strides = tp->strides();
1609 int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
1610 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.
size());
1611 if (sizes.
at(dim) == 1) {
1612 sizes.erase(sizes.begin() + dim);
1613 strides.erase(strides.begin() + dim);
1615 node->output()->setType(tp->withSizesStrides(sizes, strides));
1617 }
else if (node->matches(
1618 "aten::unsqueeze(Tensor self, int dim) -> Tensor",
// unsqueeze(dim): inserts a size-1 dimension at `dim`; dim == rank is
// accepted (note the <= in the assert). When dim < rank the new stride is
// sizes[dim] * strides[dim]; the value used when dim == rank is on a line
// not shown.
// NOTE(review): wrapDim receives `sizes` (rank n) while the assert allows
// dim == n. Presumably negative dims for unsqueeze should wrap against
// n + 1; confirm wrapDim's behaviour here.
1620 auto& tp = tensor_types.at(0);
1621 auto sizes = tp->sizes();
1622 auto strides = tp->strides();
1623 int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
1624 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.
size());
1625 int64_t new_stride = dim >=
static_cast<int64_t
>(sizes.
size())
1627 : sizes.
at(dim) * strides.at(dim);
1628 sizes.insert(sizes.begin() + dim, 1);
1629 strides.insert(strides.begin() + dim, new_stride);
1630 node->output()->setType(tp->withSizesStrides(sizes, strides));
1632 }
else if (node->matches(
1633 "aten::view(Tensor self, int[] size) -> Tensor",
// view: scans the requested sizes, multiplying entries into size_product and
// special-casing -1 (the body of that case is not shown). A -1 entry is later
// replaced by numel / size_product; `numel` is presumably accumulated from
// the input sizes in the loop below, whose body is not shown.
// size_product must be non-zero because it is used as a divisor.
1635 auto sizes = node->get<std::vector<int64_t>>(attr::size).value();
1636 bool inferred =
false;
// Only meaningful once a -1 entry has been seen; left uninitialized otherwise.
1637 size_t inferred_idx;
1638 int64_t size_product = 1;
1639 for (
size_t i = 0; i < sizes.
size(); ++i) {
1640 if (sizes[i] == -1) {
1646 size_product *= sizes[i];
1651 SHAPE_ASSERT(size_product != 0);
1653 for (int64_t s : tensor_types.at(0)->sizes())
1655 int64_t inferred_size = numel / size_product;
1656 sizes[inferred_idx] = inferred_size;
1658 node->output()->setType(tensor_types.at(0)->withSizes(sizes));
1660 }
else if (node->matches(
1661 "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
// type_as: if both inputs already share a scalar type the output reuses
// self's type as-is. The second setType (its guarding else is not shown)
// uses other's type with self's sizes.
1662 if (tensor_types.at(0)->scalarType() ==
1663 tensor_types.at(1)->scalarType()) {
1664 node->output()->setType(node->namedInput(attr::self)->type());
1667 node->output()->setType(
1668 tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes()));
// expand: sizes and strides come from at::inferExpandGeometry, called with
// the requested size as its last argument (earlier arguments are not shown).
1673 "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
1675 auto tp = tensor_types.at(0);
1676 std::vector<int64_t> sizes, strides;
1677 std::tie(sizes, strides) = at::inferExpandGeometry(
1680 node->get<std::vector<int64_t>>(attr::size).value());
1681 node->output()->setType(tp->withSizesStrides(sizes, strides));
// index_select: the index must be rank 1 and dim must lie in [0, rank);
// sizes[dim] becomes the index length.
1685 "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
1687 auto ten = tensor_types.at(0);
1688 auto index = tensor_types.at(1);
1689 int64_t dim = node->get<int64_t>(attr::dim).value();
1690 SHAPE_ASSERT(index->sizes().size() == 1);
1691 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < ten->sizes().size());
1692 std::vector<int64_t> sizes = ten->sizes();
1693 sizes[dim] = index->sizes()[0];
1694 node->output()->setType(ten->withSizes(sizes));
1696 }
else if (node->matches(
1697 "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
1698 {attr::chunks, attr::dim})) {
// chunk: every output gets the input type with sizes[dim] divided by
// `chunks` (integer division) and the input strides. If the input size is
// not divisible by `chunks`, the last output's size along dim is overwritten
// with the remainder.
// NOTE(review): `dim` indexes `sizes` with operator[] and no wrapping or
// range check, unlike the narrow and index_select branches, so a negative or
// out-of-range dim reads out of bounds. chunks == 0 would divide by zero.
// NOTE(review): torch.chunk is documented as producing chunks of
// ceil(size / chunks) with a smaller last chunk. Floor division plus a
// remainder-sized last chunk looks inconsistent with that (e.g. size 5,
// chunks 2); confirm against the operator.
1699 auto input_type = tensor_types.at(0);
1700 auto sizes = input_type->sizes();
1701 const auto& strides = input_type->strides();
1702 int64_t dim = node->get<int64_t>(attr::dim).value();
1703 int64_t chunks = node->get<int64_t>(attr::chunks).value();
1704 sizes[dim] /= chunks;
1705 for (
Value* output : node->outputs()) {
1706 output->setType(input_type->withSizesStrides(sizes, strides));
1708 if (input_type->sizes().at(dim) % chunks != 0) {
1709 sizes[dim] = input_type->sizes().
at(dim) % chunks;
1710 node->outputs().back()->setType(
1711 input_type->withSizesStrides(sizes, strides));
1714 }
else if (node->kind() == ::c10::onnx::Shape) {
// onnx::Shape: exactly one input and one output. The output is typed as a
// kLong CPU tensor; dim_vec holds the input's rank as its single entry.
// `dims` is defined on a line not shown.
1715 SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1);
1716 std::vector<int64_t> dim_vec = {
1717 (int64_t)tensor_types.at(0)->sizes().size()};
1719 node->output()->setType(
1720 CompleteTensorType::create(at::kLong, at::kCPU, dims));
1722 }
else if (node->kind() == ::c10::onnx::Reshape) {
// onnx::Reshape: no shape is derived; outputs are given unshaped types.
1723 setUnshapedType(node);
// Fallback after the branch chain: outputs are given unshaped types.
1726 setUnshapedType(node);
1732 void PropagateInputShapes(
const std::shared_ptr<Graph>& graph) {
1733 ShapePropagator(graph).PropagateShapeOnBlock(graph->block());
1739 for (
Value* v : vals) {
1740 v->setType(unshapedType(v->type()));
1744 void EraseShapeInformation(
Block* b) {
1745 EraseShapeInformation(b->inputs());
1746 EraseShapeInformation(b->outputs());
1747 for (
Node* n : b->nodes()) {
1748 EraseShapeInformation(n->outputs());
1749 for (
Block* sb : n->blocks()) {
1750 EraseShapeInformation(sb);
1752 if (n->hasAttribute(attr::Subgraph)) {
1753 EraseShapeInformation(n->g(attr::Subgraph));
1759 void EraseShapeInformation(
const std::shared_ptr<Graph>& graph) {
1760 EraseShapeInformation(graph->block());
Represents a compute device on which a tensor is located.
constexpr size_t size() const
size - Get the array size.
A utility class for setting temporary insertion points.
RAII guard that sets a certain default device in its constructor, and changes it back to the device t...
AT_CPP14_CONSTEXPR const T & at(size_t Index) const
Vector compatibility.