1 #include <torch/csrc/jit/constants.h> 2 #include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h> 7 static void PrepareDivisionForONNXOnBlock(Block* block) {
8 for (
auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
9 for (
auto sub : it->blocks()) {
10 PrepareDivisionForONNXOnBlock(sub);
12 WithInsertPoint guard(*it);
13 auto* subgraph = it->owningGraph();
15 if (it->matches(
"aten::div(int a, int b) -> float")) {
17 std::vector<Value*> floattensor_inputs =
18 fmap(it->inputs(), [&](Value* input) {
20 subgraph->insertNode(subgraph->createNumToTensor(input))
22 auto* nonblocking = subgraph->insertConstant(0);
24 subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
25 return subgraph->insertNode(cast)->output();
28 it->replaceInput(0, floattensor_inputs[0]);
29 it->replaceInput(1, floattensor_inputs[1]);
30 it->output()->setType(
31 CompleteTensorType::fromNumberType(FloatType::get()));
36 void PrepareDivisionForONNX(
const std::shared_ptr<Graph>& graph) {
37 PrepareDivisionForONNXOnBlock(graph->block());