Caffe2 - C++ API
A deep learning, cross platform ML framework
prepare_division_for_onnx.cpp
1 #include <torch/csrc/jit/constants.h>
2 #include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
3 
4 namespace torch {
5 namespace jit {
6 
7 static void PrepareDivisionForONNXOnBlock(Block* block) {
8  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
9  for (auto sub : it->blocks()) {
10  PrepareDivisionForONNXOnBlock(sub);
11  }
12  WithInsertPoint guard(*it);
13  auto* subgraph = it->owningGraph();
14 
15  if (it->matches("aten::div(int a, int b) -> float")) {
16  // Cast to Float before dividing
17  std::vector<Value*> floattensor_inputs =
18  fmap(it->inputs(), [&](Value* input) {
19  auto* longtensor =
20  subgraph->insertNode(subgraph->createNumToTensor(input))
21  ->output();
22  auto* nonblocking = subgraph->insertConstant(0);
23  auto* cast =
24  subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
25  return subgraph->insertNode(cast)->output();
26  });
27 
28  it->replaceInput(0, floattensor_inputs[0]);
29  it->replaceInput(1, floattensor_inputs[1]);
30  it->output()->setType(
31  CompleteTensorType::fromNumberType(FloatType::get()));
32  }
33  }
34 }
35 
36 void PrepareDivisionForONNX(const std::shared_ptr<Graph>& graph) {
37  PrepareDivisionForONNXOnBlock(graph->block());
38 }
39 
40 } // namespace jit
41 } // namespace torch
Definition: jit_type.h:17