Caffe2 - C++ API
A deep learning, cross platform ML framework
erase_number_types.cpp
1 #include <torch/csrc/jit/passes/erase_number_types.h>
2 #include <torch/csrc/jit/constants.h>
3 
4 namespace torch {
5 namespace jit {
6 
7 static void EraseNumberTypesOnBlock(Block* block) {
8  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
9  ++it) {
10  for (auto inp : it->inputs()) {
11  if (inp->type()->isSubtypeOf(NumberType::get())) {
12  inp->setType(TensorType::get());
13  }
14  }
15  for (auto sub : it->blocks()) {
16  EraseNumberTypesOnBlock(sub);
17  }
18  switch (it->kind()) {
19  case prim::Constant: {
20  // remove primitive constants, replacing with tensor equivalent
21  // ONNX does not support non-tensor constants
22  if (it->output()->type()->isSubtypeOf(NumberType::get()) ||
23  it->output()->type()->isSubtypeOf(BoolType::get())) {
24  at::Scalar s;
25  if (it->output()->type()->isSubtypeOf(BoolType::get())) {
26  s = static_cast<int64_t>(*constant_as<bool>(it->output()));
27  } else {
28  s = *constant_as<at::Scalar>(it->output());
29  }
30 
31  WithInsertPoint guard(*it);
32  Value* r = block->owningGraph()->insertConstant(
33  scalar_to_tensor(s), nullptr, c10::nullopt, it->scope());
34  it->output()->replaceAllUsesWith(r);
35  }
36  } break;
37  case prim::Bool:
38  case prim::Float:
39  case prim::Int:
40  case prim::ImplicitTensorToNum:
41  case prim::NumToTensor: {
42  it->output()->replaceAllUsesWith(it->inputs()[0]);
43  // Let DCE cleanup
44  } break;
45  default: {
46  for (auto o : it->outputs()) {
47  if (o->type()->isSubtypeOf(NumberType::get())) {
48  o->setType(CompleteTensorType::fromNumberType(o->type()));
49  } else if (o->type()->isSubtypeOf(BoolType::get())) {
50  o->setType(CompleteTensorType::fromBoolType());
51  }
52  }
53  } break;
54  }
55  }
56 }
57 
58 void EraseNumberTypes(const std::shared_ptr<Graph>& graph) {
59  EraseNumberTypesOnBlock(graph->block());
60 }
61 } // namespace jit
62 } // namespace torch
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
Definition: jit_type.h:17