1 #include <torch/csrc/jit/passes/erase_number_types.h> 2 #include <torch/csrc/jit/constants.h> 7 static void EraseNumberTypesOnBlock(Block* block) {
// Walks every node of `block`, recursing into each node's sub-blocks, and
// rewrites values typed as numbers (and, for outputs/constants, bools) so
// that they are represented as tensors instead.
//
// NOTE(review): this copy of the file looks damaged. The leading numbers on
// many lines appear to be stray original line numbers. Several lines also
// seem to be missing: the loop increment after `it != end;`, the `switch`
// header for the `case` labels below, the declaration of `s`, and the
// closing braces. Restore from upstream before compiling.
8 for (
auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
// Any node input whose type is a subtype of NumberType is retyped as a
// plain TensorType.
10 for (
auto inp : it->inputs()) {
11 if (inp->type()->isSubtypeOf(NumberType::get())) {
12 inp->setType(TensorType::get());
// Apply the same rewrite recursively to every block nested in this node.
15 for (
auto sub : it->blocks()) {
16 EraseNumberTypesOnBlock(sub);
// Number- or bool-valued constants are replaced by an equivalent tensor
// constant. The tensor constant is inserted at this node's position and
// scope, and all uses of the old output are redirected to it.
19 case prim::Constant: {
22 if (it->output()->type()->isSubtypeOf(NumberType::get()) ||
23 it->output()->type()->isSubtypeOf(BoolType::get())) {
// A bool constant is widened to an int64 value (true -> 1, false -> 0)
// before conversion. Any other number is read directly as an at::Scalar.
25 if (it->output()->type()->isSubtypeOf(BoolType::get())) {
26 s =
static_cast<int64_t
>(*constant_as<bool>(it->output()));
28 s = *constant_as<at::Scalar>(it->output());
// Build the tensor form of `s` and insert it as a new constant at this
// node, carrying over the node's scope.
// NOTE(review): the visible code does not delete the original constant
// node. Check whether a removal step was among the lost lines.
31 WithInsertPoint guard(*it);
32 Value* r = block->owningGraph()->insertConstant(
33 scalar_to_tensor(s),
nullptr, c10::nullopt, it->scope());
34 it->output()->replaceAllUsesWith(r);
// Number<->tensor conversion nodes are treated as pass-throughs: every use
// of the output is redirected to the node's first input. The node itself
// is not removed here; presumably a later cleanup pass removes it (confirm).
40 case prim::ImplicitTensorToNum:
41 case prim::NumToTensor: {
42 it->output()->replaceAllUsesWith(it->inputs()[0]);
// All remaining nodes (presumably the `default:` branch, whose label is not
// visible in this copy) have their number outputs retyped with
// CompleteTensorType::fromNumberType and their bool outputs retyped with
// CompleteTensorType::fromBoolType.
46 for (
auto o : it->outputs()) {
47 if (o->type()->isSubtypeOf(NumberType::get())) {
48 o->setType(CompleteTensorType::fromNumberType(o->type()));
49 }
else if (o->type()->isSubtypeOf(BoolType::get())) {
50 o->setType(CompleteTensorType::fromBoolType());
// Entry point of the pass. It applies the number-type erasure to the graph's
// top-level block; nested blocks are handled by the recursion inside
// EraseNumberTypesOnBlock.
58 void EraseNumberTypes(
const std::shared_ptr<Graph>& graph) {
59 EraseNumberTypesOnBlock(graph->block());
// NOTE(review): the closing brace of this function is not present in this
// copy of the file.
// `at::Scalar` represents a 0-dimensional tensor that contains a single element.