#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/script/compiler.h>

namespace torch {
namespace jit {

std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
    ToBatch::batch_operator_table;
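
// Look up the batch operator graph registered under `name`. With the default
// num_inputs of -1 the first registered graph is returned; otherwise the
// overload whose graph takes exactly `num_inputs` inputs is selected.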
std::shared_ptr<Graph> ToBatch::getBatchOperator(
    const std::string& name,
    int64_t num_inputs) {
  if (batch_operator_table.find(name) == batch_operator_table.end()) {
    throw std::runtime_error(
        "function " + name + " is not supported in batched tensor yet");
  }
  auto ops = batch_operator_table.at(name);
  if (num_inputs == -1) // return the default overload
    return ops[0];
  for (auto op : ops) {
    if (size_t(num_inputs) == op->inputs().size())
      return op;
  }
  throw std::runtime_error(
      "function " + name + " with " + std::to_string(num_inputs) +
      " inputs is not supported in batched tensor yet");
}
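
// Inline `callee` into graph `g` at the current insert point, unpacking any
// tuple outputs into individual values.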
std::vector<Value*> inlineUnpackedCallTo(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs) {
  return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
}
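
// Replace an aten operator node with a call to its registered batch operator
// graph. Each batched tensor is carried around in expanded form as
// EXP_BTENSOR_SIZE consecutive Values (with the default expansion this is the
// (data, mask, dims) triple of a BatchTensor), so batch_map maps one original
// Value to a vector of Values.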
void ToBatch::visitAten(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  auto func_name = std::string(n->kind().toUnqualString());
  std::vector<Value*> new_inputs;
  for (Value* input : n->inputs()) {
    if (rn_env.find(input) == rn_env.end()) { // batched tensor input
      auto new_input = batch_map.at(input);
      new_inputs.insert(new_inputs.end(), new_input.begin(), new_input.end());
    } else { // scalar input
      new_inputs.push_back(rn_env.at(input));
    }
  }

  // transform scalars to tensors before passing them to the batch operator
  for (auto& input : new_inputs) {
    if (input->type() == IntType::get() || input->type() == FloatType::get() ||
        input->type() == BoolType::get()) {
      auto to_tensor_node = res_graph->createNumToTensor(input);
      res_graph->insertNode(to_tensor_node);
      input = to_tensor_node->output();
    }
  }

  auto batch_graph = getBatchOperator(func_name, new_inputs.size());
  auto outputs = inlineUnpackedCallTo(
      *res_block->owningGraph(), *batch_graph, new_inputs);

  if (outputs.size() == 1) {
    // if the original output was a scalar, cast the batched result back to
    // the original scalar type
    TypePtr orig_type = n->outputs()[0]->type();
    if (!orig_type->isSubtypeOf(outputs[0]->type())) {
      Symbol op;
      if (orig_type == IntType::get()) {
        op = prim::Int;
      } else if (orig_type == FloatType::get()) {
        op = prim::Float;
      } else if (orig_type == BoolType::get()) {
        op = prim::Bool;
      } else {
        throw std::runtime_error(
            "NYI: scalar types other than int, float, and bool are not supported yet");
      }
      rn_env[n->outputs()[0]] = res_graph->insert(op, {outputs[0]});
    } else {
      rn_env[n->outputs()[0]] = outputs[0];
    }
  } else {
    // each output is expanded into EXP_BTENSOR_SIZE consecutive values
    for (size_t i = 0; i < n->outputs().size(); i++) {
      auto output = n->outputs()[i];
      batch_map[output] = std::vector<Value*>(
          outputs.begin() + i * EXP_BTENSOR_SIZE,
          outputs.begin() + i * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE);
    }
  }
}
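
// clone a prim::Constant into the new graph; constants need no batching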
void ToBatch::visitConstant(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  auto* r_node = res_graph->createClone(n, rn_fn);
  res_block->appendNode(r_node);
  rn_env[n->output()] = r_node->output();
}
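
// clone prim::NumToTensor, then expand its scalar-tensor output into a
// batched tensor via the batch_from_scalar_tensor operator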
void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  auto* r_node = res_graph->createClone(n, rn_fn);
  res_block->appendNode(r_node);
  auto outputs = inlineUnpackedCallTo(
      *res_block->owningGraph(),
      *getBatchOperator("batch_from_scalar_tensor"),
      r_node->outputs());
  batch_map[n->output()] = outputs;
}
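
// clone prim::TensorToNum into the new graph, keeping the batched form of
// the input available under the output as well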
void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  if (rn_env.find(n->input()) == rn_env.end()) {
    rn_env[n->input()] = batch_map.at(n->input())[0];
  }
  auto* r_node = res_graph->createClone(n, rn_fn);
  res_block->appendNode(r_node);
  rn_env[n->output()] = r_node->output();
  batch_map[n->output()] = batch_map.at(n->input());
}
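
// clone prim::ListConstruct into the new graph: a TensorList is expanded
// directly into its batched components, while a scalar list is constructed
// as-is and then converted to a tensor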
void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  if (n->inputs()[0]->type() ==
      DynamicType::get()) { // TensorList: expand directly
    std::vector<Value*> inputs;
    for (Value* input : n->inputs()) {
      auto res = batch_map.at(input);
      inputs.insert(inputs.end(), res.begin(), res.end());
    }
    batch_map[n->output()] = inputs;
  } else { // ScalarList: construct it as-is, then transform it to a tensor
    for (Value* input : n->inputs()) {
      if (rn_env.find(input) == rn_env.end()) {
        rn_env[input] = batch_map.at(input)[0];
      }
    }
    auto* r_node = res_graph->createClone(n, rn_fn);
    res_block->appendNode(r_node);
    auto to_tensor_node =
        res_graph->create(Symbol::fromQualString("aten::_list_to_tensor"));
    to_tensor_node->addInput(r_node->output());
    res_block->appendNode(to_tensor_node);
    rn_env[n->output()] = to_tensor_node->output();
  }
}
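
// batch a prim::If: lower both branches into res_block unconditionally, then
// merge each pair of branch outputs with the batched where operator, using
// the (possibly batched) condition as the selector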
void ToBatch::visitIf(Node* n, Block* block, Block* res_block) {
  toBatch(n->blocks()[0], res_block);
  toBatch(n->blocks()[1], res_block);

  // combine results from the two if paths
  for (size_t i = 0; i < n->outputs().size(); i++) {
    std::vector<Value*> inputs;
    if (batch_map.find(n->input()) == batch_map.end()) { // cond is scalar
      inputs.push_back(rn_env.at(n->input()));
    } else { // cond is a batched tensor
      auto cond = batch_map.at(n->input());
      inputs.insert(inputs.end(), cond.begin(), cond.end());
    }
    auto if_output = batch_map.at(n->blocks()[0]->outputs()[i]);
    inputs.insert(inputs.end(), if_output.begin(), if_output.end());
    auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]);
    inputs.insert(inputs.end(), else_output.begin(), else_output.end());
    auto outputs = inlineUnpackedCallTo(
        *res_block->owningGraph(),
        *getBatchOperator("where", inputs.size()),
        inputs);
    batch_map[n->outputs()[i]] = outputs;
  }
}
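
// batch a prim::Loop: clone the loop node without its body, expand the
// condition and every loop-carried value into their batched components, then
// rebuild the body so that each iteration masks its updates with where/update
// and reduces the batched condition with any() to decide whether to continue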
void ToBatch::visitLoop(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();

  // cond_is_tensor distinguishes a data-dependent while loop (batched tensor
  // condition) from a plain trip-counted for loop (scalar condition)
  bool cond_is_tensor = (batch_map.find(n->inputs()[1]) != batch_map.end());

  // the loop count in the new graph must be an int
  if (rn_env.at(n->inputs()[0])->type() != IntType::get()) {
    rn_env[n->inputs()[0]] =
        res_graph->insert(prim::Int, {rn_env.at(n->inputs()[0])});
  }
  if (cond_is_tensor) {
    auto cond = batch_map.at(n->inputs()[1]);
    auto cond_any = inlineUnpackedCallTo(
        *res_block->owningGraph(), *getBatchOperator("any"), cond);
    rn_env[n->inputs()[1]] = res_graph->insert(prim::Bool, {cond_any[0]});
  }
  for (size_t i = 2; i < n->inputs().size(); i++) {
    auto input = n->inputs()[i];
    rn_env[input] = batch_map.at(input)[0];
  }
  auto* r_node = res_graph->createClone(n, rn_fn, /*copy_blocks=*/false);

  // insert the remaining components of the condition and of each
  // loop-carried tensor as extra inputs of the cloned prim::Loop
  if (cond_is_tensor) {
    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
      auto cond = batch_map.at(n->inputs()[1]);
      r_node->insertInput(i + 2, cond[i]);
    }
  }
  for (size_t i = 2; i < n->inputs().size(); i++) {
    for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
      r_node->insertInput(
          (i - 2) * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE * cond_is_tensor + 2 +
              j,
          batch_map.at(n->inputs()[i])[j]);
    }
  }
  res_block->appendNode(r_node);

  // create the body block of the new loop node and add its inputs:
  // the iteration counter first, then the expanded condition (if it is a
  // tensor), then the expanded loop-carried values
  auto loop_block = r_node->addBlock();

  loop_block->addInput("loop_num");
  loop_block->inputs()[0]->setType(IntType::get());
  rn_env[n->blocks()[0]->inputs()[0]] = loop_block->inputs()[0];
  if (cond_is_tensor) {
    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
      loop_block->addInput("cond_" + EXP_BTENSOR_NAME[i]);
    }
  }
  for (size_t i = 1; i < n->blocks()[0]->inputs().size(); i++) {
    auto input = n->blocks()[0]->inputs()[i];
    auto name = input->uniqueName();
    for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
      loop_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
    }
    batch_map[input] =
        std::vector<Value*>(loop_block->inputs()
                                .slice(
                                    (i - 1) * EXP_BTENSOR_SIZE + 1 +
                                        EXP_BTENSOR_SIZE * cond_is_tensor,
                                    EXP_BTENSOR_SIZE)
                                .vec());
  }

  toBatch(n->blocks()[0], loop_block);

  WithInsertPoint guard(loop_block);

  // combine the pre-iteration and post-iteration value of each loop-carried
  // variable: where() masked by the condition when it is a tensor, update()
  // when it is a scalar; register the results as block outputs
  for (size_t i = 0; i < n->outputs().size(); i++) {
    std::vector<Value*> inputs, outputs;
    if (cond_is_tensor) {
      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
        inputs.push_back(loop_block->inputs()[j + 1]);
      }
      auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
      inputs.insert(inputs.end(), data.begin(), data.end());
      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
        inputs.push_back(
            loop_block
                ->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
      }
      outputs = inlineUnpackedCallTo(
          *res_block->owningGraph(), *getBatchOperator("where"), inputs);
    } else {
      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
        inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + 1]);
      }
      auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
      inputs.insert(inputs.end(), data.begin(), data.end());
      outputs = inlineUnpackedCallTo(
          *res_block->owningGraph(), *getBatchOperator("update"), inputs);
    }
    batch_map[n->outputs()[i]] = outputs;
    for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
      loop_block->registerOutput(outputs[j]);
    }
  }

  // update the loop condition: reduce a tensor condition with any() to a
  // bool that controls the next iteration, and carry its components along
  if (cond_is_tensor) {
    auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
    auto cond_any = inlineUnpackedCallTo(
        *res_block->owningGraph(), *getBatchOperator("any"), cond);
    auto to_bool_output = res_graph->insert(prim::Bool, {cond_any[0]});
    loop_block->insertOutput(0, to_bool_output);
    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
      loop_block->insertOutput(i + 1, cond[i]);
    }
  } else {
    auto cond = rn_env.at(n->blocks()[0]->outputs()[0]);
    loop_block->insertOutput(0, cond);
  }

  // expand the outputs of the cloned prim::Loop to match the batched form
  auto size = r_node->outputs().size();
  for (size_t i = 0; i < size; i++) {
    for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
      r_node->insertOutput(i * EXP_BTENSOR_SIZE + j);
    }
    batch_map[n->outputs()[i]] =
        r_node->outputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec();
  }
  // add the condition components to the outputs of the loop node
  if (cond_is_tensor) {
    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
      r_node->insertOutput(i);
    }
  }
}
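
// batch an entire block: expand block inputs, dispatch every node to the
// visitor that matches its kind, then expand block outputs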
void ToBatch::toBatch(Block* block, Block* res_block) {
  WithInsertPoint guard(res_block);

  // expand each block input tensor into its batched components,
  // eg: a -> a_data, a_mask, a_dims; blocks owned by prim::Loop register
  // their inputs in visitLoop instead, to deal with the condition
  if (!block->owningNode() || block->owningNode()->kind() != prim::Loop) {
    auto size = block->inputs().size();
    for (size_t i = 0; i < size; i++) {
      auto input = block->inputs()[i];
      auto name = input->uniqueName();
      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
        res_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
      }
      batch_map[input] =
          std::vector<Value*>(res_block->inputs()
                                  .slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE)
                                  .vec());
    }
  }

  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
    auto n = *it;
    if (n->kind().is_aten()) {
      visitAten(n, block, res_block);
    } else if (n->kind().is_prim()) {
      switch (n->kind()) {
        case prim::Constant:
          visitConstant(n, block, res_block);
          break;
        case prim::NumToTensor:
          visitNumToTensor(n, block, res_block);
          break;
        case prim::TensorToNum:
          visitTensorToNum(n, block, res_block);
          break;
        case prim::ListConstruct:
          visitListConstruct(n, block, res_block);
          break;
        case prim::If:
          visitIf(n, block, res_block);
          break;
        case prim::Loop:
          visitLoop(n, block, res_block);
          break;
        default:
          throw std::runtime_error(
              "NYI: node of prim kind other than [Constant, NumToTensor, TensorToNum, If, Loop] is not supported yet");
      }
    } else {
      throw std::runtime_error(
          "NYI: node that is not aten or prim kind is not supported yet");
    }
  }

  // expand each block output tensor into its batched components; blocks
  // owned by prim::Loop or prim::If register their outputs in visitLoop and
  // visitIf instead
  if (!block->owningNode() ||
      (block->owningNode()->kind() != prim::Loop &&
       block->owningNode()->kind() != prim::If)) {
    for (Value* output : block->outputs()) {
      auto r_output = batch_map.at(output);
      for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
        res_block->registerOutput(r_output[i]);
      }
    }
  }
}
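
// entry point of the pass: produce a new graph in which every tensor of the
// input graph is replaced by its batched expansion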
std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph) {
  // lower a tuple return before running the pass
  if (graph->outputs().at(0)->type()->kind() == TupleType::Kind) {
    graph = graph->copy();
    auto outs = createTupleUnpack(graph->outputs().at(0));
    graph->eraseOutput(0);
    for (auto o : outs)
      graph->registerOutput(o);
    EliminateDeadCode(graph->block());
  }
  std::shared_ptr<Graph> res_graph = std::make_shared<Graph>();
  ToBatch to_batch;
  to_batch.toBatch(graph->block(), res_graph->block());

  // graphs should only have a single output, so pack the expanded outputs
  // back into a tuple
  auto tup =
      res_graph->insertNode(res_graph->createTuple(res_graph->outputs()));
  while (res_graph->outputs().size() > 0)
    res_graph->eraseOutput(res_graph->outputs().size() - 1);
  res_graph->registerOutput(tup->output());
  EliminateDeadCode(res_graph->block());
  return res_graph;
}
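
// Expose the pass to Python. Illustrative usage, assuming the bindings are
// installed on a module `m` (the exact module depends on where
// initRegisterBatchOpsBindings is invoked):
//   m.register_batch_operator("where", where_graph)  # may be repeated to
//                                                    # add arity overloads
//   batched = m.to_batch_graph(graph)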
void initRegisterBatchOpsBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  m.def("to_batch_graph", to_batch_graph);
  m.def(
      "register_batch_operator",
      [](std::string name, std::shared_ptr<Graph> graph) {
        ToBatch::batch_operator_table[name].push_back(graph);
      });
}

} // namespace jit
} // namespace torch