Caffe2 - C++ API
A deep learning, cross-platform ML framework
to_batch.cpp
#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/script/compiler.h>

namespace torch {
namespace jit {

std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
    ToBatch::batch_operator_table;

std::shared_ptr<Graph> ToBatch::getBatchOperator(
    const std::string& name,
    int64_t num_inputs) {
  if (batch_operator_table.find(name) == batch_operator_table.end()) {
    throw std::runtime_error(
        "function " + name + " is not supported in batched tensor yet");
  }
  auto ops = batch_operator_table.at(name);
  if (num_inputs == -1) // default function
    return ops[0];
  for (auto op : ops) {
    if (size_t(num_inputs) == op->inputs().size())
      return op;
  }
  throw std::runtime_error(
      "function " + name + " with " + std::to_string(num_inputs) +
      " inputs is not supported in batched tensor yet");
}

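// Illustrative sketch (not part of the pass itself): batch_operator_table maps
// an operator name to every batched implementation registered for it (see
// register_batch_operator at the bottom of this file), and getBatchOperator
// resolves an overload by input count. Assuming two hypothetical graphs
// compiled from batched "add" scripts:
//
//   ToBatch::batch_operator_table["add"].push_back(add_default);
//   ToBatch::batch_operator_table["add"].push_back(add_with_alpha);
//
// getBatchOperator("add", -1) would return add_default, while
// getBatchOperator("add", add_with_alpha->inputs().size()) would return the
// overload whose graph takes exactly that many inputs.
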
std::vector<Value*> inlineUnpackedCallTo(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs) {
  return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
}

// replace an aten operator node with the registered BatchTensor operator graph
void ToBatch::visitAten(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  auto func_name = std::string(n->kind().toUnqualString());
  std::vector<Value*> new_inputs;
  for (Value* input : n->inputs()) {
    if (rn_env.find(input) == rn_env.end()) { // batched tensor input
      auto new_input = batch_map.at(input);
      new_inputs.insert(new_inputs.end(), new_input.begin(), new_input.end());
    } else { // non-tensor input
      new_inputs.push_back(rn_env.at(input));
    }
  }

  // transform scalars to tensors before passing them to the batch operator
  // script
  for (auto& input : new_inputs) {
    if (input->type() == IntType::get() || input->type() == FloatType::get() ||
        input->type() == BoolType::get()) {
      auto to_tensor_node = res_graph->createNumToTensor(input);
      res_graph->insertNode(to_tensor_node);
      input = to_tensor_node->output();
    }
  }

  auto batch_graph = getBatchOperator(func_name, new_inputs.size());
  auto outputs =
      inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);

  // Assume all outputs from the inlined operator implementation are either
  // batched tensors in the expanded triple form or a single non-tensor.
  if (outputs.size() == 1) {
    // if the original output was a scalar, transform the new output back to a
    // scalar from a dynamic tensor
    TypePtr orig_type = n->outputs()[0]->type();
    if (!orig_type->isSubtypeOf(outputs[0]->type())) {
      Symbol op;
      if (orig_type == IntType::get()) {
        op = prim::Int;
      } else if (orig_type == FloatType::get()) {
        op = prim::Float;
      } else if (orig_type == BoolType::get()) {
        op = prim::Bool;
      } else {
        throw std::runtime_error(
            "NYI: scalar types other than int, float, and bool are not supported yet");
      }
      rn_env[n->outputs()[0]] = res_graph->insert(op, {outputs[0]});
    } else {
      rn_env[n->outputs()[0]] = outputs[0];
    }
  } else {
    for (size_t i = 0; i < n->outputs().size(); i++) {
      auto output = n->outputs()[i];
      batch_map[output] = std::vector<Value*>(
          outputs.begin() + i * EXP_BTENSOR_SIZE,
          outputs.begin() + i * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE);
    }
  }
}

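// Illustrative sketch of what visitAten produces (names are hypothetical).
// For an original node
//
//   %c : Dynamic = aten::mul(%a, %b)
//
// where %a and %b are batched tensors, new_inputs becomes
// {%a_data, %a_mask, %a_dims, %b_data, %b_mask, %b_dims}, the registered
// "mul" batch operator graph is inlined over those six values, and
// batch_map[%c] is set to the resulting triple {%c_data, %c_mask, %c_dims}
// (EXP_BTENSOR_SIZE values per tensor output, i.e. three under the
// {data, mask, dims} expansion).
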
// clone prim::Constant into the new graph.
// The batching transformation is applied to the output of prim::NumToTensor,
// so if a prim::NumToTensor follows a prim::Constant, the constant will
// eventually be transformed into a BatchTensor.
void ToBatch::visitConstant(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  auto* r_node = res_graph->createClone(n, rn_fn);
  res_block->appendNode(r_node);
  rn_env[n->output()] = r_node->output();
}

// change the returned tensor to an expanded batched tensor,
// eg: {data, mask, dims}
void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  auto* r_node = res_graph->createClone(n, rn_fn);
  res_block->appendNode(r_node);
  auto outputs = inlineUnpackedCallTo(
      *res_block->owningGraph(),
      *getBatchOperator("batch_from_scalar_tensor"),
      r_node->outputs());
  batch_map[n->output()] = outputs;
}

// clone prim::TensorToNum to the new graph
void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  if (rn_env.find(n->input()) == rn_env.end()) {
    rn_env[n->input()] = batch_map.at(n->input())[0];
  }
  auto* r_node = res_graph->createClone(n, rn_fn);
  res_block->appendNode(r_node);
  rn_env[n->output()] = r_node->output();
  batch_map[n->output()] = batch_map.at(n->input());
}

// clone prim::ListConstruct to the new graph
void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  if (n->inputs()[0]->type() ==
      TensorType::get()) { // TensorList: expand directly
    std::vector<Value*> inputs;
    for (Value* input : n->inputs()) {
      auto res = batch_map.at(input);
      inputs.insert(inputs.end(), res.begin(), res.end());
    }
    batch_map[n->output()] = inputs;
  } else { // ScalarList: transform to tensor, then transform back
    for (Value* input : n->inputs()) {
      if (rn_env.find(input) == rn_env.end()) {
        rn_env[input] = batch_map.at(input)[0];
      }
    }
    auto* r_node = res_graph->createClone(n, rn_fn);
    res_block->appendNode(r_node);
    // transform int[] to tensor
    auto to_tensor_node =
        res_graph->create(Symbol::fromQualString("aten::_list_to_tensor"));
    to_tensor_node->addInput(r_node->output());
    res_block->appendNode(to_tensor_node);
    rn_env[n->output()] = to_tensor_node->output();
  }
}

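// Illustrative sketch (hypothetical value names): a scalar list such as
//
//   %sizes : int[] = prim::ListConstruct(%1, %2)
//
// is cloned and then funneled through aten::_list_to_tensor, so downstream
// batch operator scripts receive it as a single tensor input; a TensorList is
// instead expanded in place into the concatenated {data, mask, dims} values
// of its elements.
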
// clang-format off
// prim::If transformation:
// elif is not supported
//
// transformation example:
// @torch.jit.batch(batch_size=4)
// def batch_if(a, b):
//     if a > b:
//         a += b
//     else:
//         a -= b
//     return a
//
// original graph:
// graph(%a.1 : Dynamic
//       %b : Dynamic) {
//   %2 : Dynamic = aten::gt(%a.1, %b)
//   %a : Dynamic = prim::If(%2)
//     block0() {
//       %a.2 : Dynamic = aten::add[alpha={1}](%a.1, %b)
//       -> (%a.2)
//     }
//     block1() {
//       %a.3 : Dynamic = aten::sub[alpha={1}](%a.1, %b)
//       -> (%a.3)
//     }
//   return (%a);
// }
//
// transformed graph:
// graph(%a.1_data : Dynamic
//       %a.1_mask : Dynamic
//       %a.1_dims : Dynamic
//       %b_data : Dynamic
//       %b_mask : Dynamic
//       %b_dims : Dynamic) {
//   %6 : Dynamic = aten::gt(%a.1_data, %b_data)   // calculate condition
//   %7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
//   %8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
//   %9 : int = prim::TensorToNum(%6)
//   %10 : Long() = prim::Constant[value={1}]()    // if_block
//   %alpha.1 : float = prim::TensorToNum(%10)
//   %data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
//   %mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
//   %dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
//   %15 : Long() = prim::Constant[value={1}]()    // else_block
//   %alpha : float = prim::TensorToNum(%15)
//   %data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
//   %mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
//   %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
//   %20 : Dynamic = aten::type_as(%7, %6)         // combine two outputs (batch_where)
//   %cond_mask.1 : Dynamic = aten::mul(%6, %20)
//   %22 : int = aten::dim(%cond_mask.1)
//   %23 : int = prim::Constant[value=1]()
//   %24 : int = aten::eq(%22, %23)
//   %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%24)
//     block0() {
//       %28 : int = aten::dim(%data.1)
//       %29 : int = prim::Constant[value=1]()
//       %30 : int = aten::sub(%28, %29)
//       %31 : int = prim::Constant[value=1]()
//       %data.3 : Dynamic = prim::Loop(%30, %31, %cond_mask.1)
//         block0(%_ : int, %34 : Dynamic) {
//           %35 : int = prim::Constant[value=1]()
//           %36 : int = aten::neg(%35)
//           %data.2 : Dynamic = aten::unsqueeze(%34, %36)
//           %38 : int = prim::Constant[value=1]()
//           -> (%38, %data.2)
//         }
//       %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
//       %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
//       -> (%cond_data.1, %cond_mask.2, %data.3)
//     }
//     block1() {
//       -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
//     }
//   %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4)
//   %res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask)
//   %res_dims : Dynamic = aten::__or__(%dims.1, %dims)
//   return (%res_data, %res_mask, %res_dims);
// }
// clang-format on
void ToBatch::visitIf(Node* n, Block* block, Block* res_block) {
  toBatch(n->blocks()[0], res_block);
  toBatch(n->blocks()[1], res_block);

  // combine results from the two if paths
  for (size_t i = 0; i < n->outputs().size(); i++) {
    std::vector<Value*> inputs;
    if (batch_map.find(n->input()) == batch_map.end()) { // cond is scalar
      inputs.push_back(rn_env.at(n->input()));
    } else { // cond is tensor
      auto cond = batch_map.at(n->input());
      inputs.insert(inputs.end(), cond.begin(), cond.end());
    }
    auto if_output = batch_map.at(n->blocks()[0]->outputs()[i]);
    inputs.insert(inputs.end(), if_output.begin(), if_output.end());
    auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]);
    inputs.insert(inputs.end(), else_output.begin(), else_output.end());
    auto outputs = inlineUnpackedCallTo(
        *res_block->owningGraph(),
        *getBatchOperator("where", inputs.size()),
        inputs);
    batch_map[n->outputs()[i]] = outputs;
  }
}

// clang-format off
// prim::Loop transformation:
//
// transformation example:
// @torch.jit.batch(batch_size=4)
// def batch_while(a, b):
//     while a > b:
//         a -= b
//     return a
//
// original graph:
// graph(%a.1 : Dynamic
//       %b : Dynamic) {
//   %2 : int = prim::Constant[value={2147483647}]()
//   %3 : Dynamic = aten::gt(%a.1, %b)
//   %a : Dynamic = prim::Loop(%2, %3, %a.1)
//     block0(%4 : Dynamic, %5 : Dynamic) {
//       %a.2 : Dynamic = aten::sub[alpha={1}](%5, %b)
//       %9 : Dynamic = aten::gt(%a.2, %b)
//       -> (%9, %a.2)
//     }
//   return (%a);
// }
//
// transformed graph:
// graph(%a.1_data : Dynamic
//       %a.1_mask : Dynamic
//       %a.1_dims : Dynamic
//       %b_data : Dynamic
//       %b_mask : Dynamic
//       %b_dims : Dynamic) {
//   %6 : int = prim::Constant[value=2147483647]()
//   %7 : Dynamic = aten::gt(%a.1_data, %b_data)
//   %8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
//   %9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
//   %10 : int = prim::TensorToNum(%7)
//   %11 : Dynamic = aten::mul(%7, %8)
//   %12 : Dynamic = aten::sum(%11)
//   %13 : Dynamic = aten::gt[other={0}](%12)      // cond_any
//   %14 : int = prim::TensorToNum(%13)
//   %62 : Dynamic, %63 : Dynamic, %64 : Dynamic, %a : Dynamic, %60 : Dynamic, %61 : Dynamic = prim::Loop(%6, %14, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
//     block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
//       %23 : Long() = prim::Constant[value={1}]()
//       %alpha : float = prim::TensorToNum(%23)
//       %data.1 : Dynamic = aten::sub(%6_data, %b_data, %alpha)
//       %mask : Dynamic = aten::mul(%6_mask, %b_mask)
//       %dims : Dynamic = aten::__or__(%6_dims, %b_dims)
//       %28 : Dynamic = aten::gt(%data.1, %b_data)
//       %29 : Dynamic = aten::mul(%mask, %b_mask)
//       %30 : Dynamic = aten::__or__(%dims, %b_dims)
//       %31 : int = prim::TensorToNum(%28)
//       %32 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2)  // update outputs (batch_where)
//       %cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %32)
//       %34 : int = aten::dim(%cond_mask.1)
//       %35 : int = prim::Constant[value=1]()
//       %36 : int = aten::eq(%34, %35)
//       %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%36)
//         block0() {
//           %40 : int = aten::dim(%data.1)
//           %41 : int = prim::Constant[value=1]()
//           %42 : int = aten::sub(%40, %41)
//           %43 : int = prim::Constant[value=1]()
//           %data.3 : Dynamic = prim::Loop(%42, %43, %cond_mask.1)
//             block0(%_ : int, %46 : Dynamic) {
//               %47 : int = prim::Constant[value=1]()
//               %48 : int = aten::neg(%47)
//               %data.2 : Dynamic = aten::unsqueeze(%46, %48)
//               %50 : int = prim::Constant[value=1]()
//               -> (%50, %data.2)
//             }
//           %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
//           %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
//           -> (%cond_data.1, %cond_mask.2, %data.3)
//         }
//         block1() {
//           -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
//         }
//       %res_data : Dynamic = aten::where(%cond_data, %data.1, %6_data)
//       %res_mask : Dynamic = aten::where(%cond_mask, %mask, %6_mask)
//       %res_dims : Dynamic = aten::__or__(%dims, %6_dims)
//       %56 : Dynamic = aten::mul(%28, %29)
//       %57 : Dynamic = aten::sum(%56)
//       %58 : Dynamic = aten::gt[other={0}](%57)
//       %59 : int = prim::TensorToNum(%58)
//       -> (%59, %28, %29, %30, %res_data, %res_mask, %res_dims)
//     }
//   return (%a, %60, %61);
// }
// clang-format on
void ToBatch::visitLoop(Node* n, Block* block, Block* res_block) {
  auto res_graph = res_block->owningGraph();
  // cond_is_tensor indicates whether cond is a batched tensor:
  //   cond_is_tensor = false, eg: a for loop, where n->inputs()[1] = byte()
  //   cond_is_tensor = true, eg: a while loop whose cond is a batched tensor;
  //     we need to add the expanded cond to the inputs of the loop node and
  //     block, and compute cond_any as the cond for the while loop
  bool cond_is_tensor = (batch_map.find(n->inputs()[1]) != batch_map.end());

  // create prim::Loop node for res_block

  // type of cond in loop should be int type
  if (rn_env.at(n->inputs()[0])->type() != IntType::get()) {
    rn_env[n->inputs()[0]] =
        res_graph->insert(prim::Int, {rn_env.at(n->inputs()[0])});
  }
  if (cond_is_tensor) {
    auto cond = batch_map.at(n->inputs()[1]);
    auto cond_any = inlineUnpackedCallTo(
        *res_block->owningGraph(), *getBatchOperator("any"), cond);
    rn_env[n->inputs()[1]] = res_graph->insert(prim::Bool, {cond_any[0]});
  }
  for (size_t i = 2; i < n->inputs().size(); i++) {
    auto input = n->inputs()[i];
    rn_env[input] = batch_map.at(input)[0];
  }
  auto* r_node = res_graph->createClone(n, rn_fn, /*copy_blocks=*/false);

  // change inputs of prim::Loop
  if (cond_is_tensor) {
    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
      auto cond = batch_map.at(n->inputs()[1]);
      r_node->insertInput(i + 2, cond[i]);
    }
  }
  for (size_t i = 2; i < n->inputs().size(); i++) {
    for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
      r_node->insertInput(
          (i - 2) * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE * cond_is_tensor + 2 +
              j,
          batch_map.at(n->inputs()[i])[j]);
    }
  }
  res_block->appendNode(r_node);

  // create block for Loop node in res_block
  // if cond is tensor: first 4 inputs of block: cond_any, cond_data,
  //                    cond_mask, cond_dims
  // if cond is not tensor: first 1 input of block: cond
  auto loop_block = r_node->addBlock();

  // add inputs
  loop_block->addInput("loop_num");
  loop_block->inputs()[0]->setType(IntType::get());
  rn_env[n->blocks()[0]->inputs()[0]] = loop_block->inputs()[0];
  if (cond_is_tensor) {
    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
      loop_block->addInput("cond_" + EXP_BTENSOR_NAME[i]);
    }
  }
  for (size_t i = 1; i < n->blocks()[0]->inputs().size(); i++) {
    auto input = n->blocks()[0]->inputs()[i];
    auto name = input->uniqueName();
    for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
      loop_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
    }
    batch_map[input] =
        std::vector<Value*>(loop_block->inputs()
                                .slice(
                                    (i - 1) * EXP_BTENSOR_SIZE + 1 +
                                        EXP_BTENSOR_SIZE * cond_is_tensor,
                                    EXP_BTENSOR_SIZE)
                                .vec());
  }

  toBatch(n->blocks()[0], loop_block);

  WithInsertPoint guard(loop_block);

  // use the where operator to update variables and add them to the outputs
  for (size_t i = 0; i < n->outputs().size(); i++) {
    std::vector<Value*> inputs, outputs;
    if (cond_is_tensor) {
      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
        inputs.push_back(loop_block->inputs()[j + 1]);
      }
      auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
      inputs.insert(inputs.end(), data.begin(), data.end());
      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
        inputs.push_back(
            loop_block
                ->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
      }
      outputs = inlineUnpackedCallTo(
          *res_block->owningGraph(), *getBatchOperator("where"), inputs);
    } else {
      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
        inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + 1]);
      }
      auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
      inputs.insert(inputs.end(), data.begin(), data.end());
      outputs = inlineUnpackedCallTo(
          *res_block->owningGraph(), *getBatchOperator("update"), inputs);
    }
    batch_map[n->outputs()[i]] = outputs;
    for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
      loop_block->registerOutput(outputs[j]);
    }
  }

  // update loop conditions
  if (cond_is_tensor) {
    auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
    auto cond_any = inlineUnpackedCallTo(
        *res_block->owningGraph(), *getBatchOperator("any"), cond);
    auto to_bool_output = res_graph->insert(prim::Bool, {cond_any[0]});
    loop_block->insertOutput(0, to_bool_output);
    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
      loop_block->insertOutput(i + 1, cond[i]);
    }
  } else {
    auto cond = rn_env.at(n->blocks()[0]->outputs()[0]);
    loop_block->insertOutput(0, cond);
  }

  // change outputs of prim::Loop
  auto size = r_node->outputs().size();
  for (size_t i = 0; i < size; i++) {
    for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
      r_node->insertOutput(i * EXP_BTENSOR_SIZE + j);
    }
    batch_map[n->outputs()[i]] =
        r_node->outputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec();
  }
  // add cond to outputs of loop node
  if (cond_is_tensor) {
    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
      r_node->insertOutput(i);
    }
  }
}

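// Illustrative note on the loop-carried value layout built above (a summary of
// the transformed-graph example, not additional behavior): when cond is a
// batched tensor, the carried inputs of the emitted prim::Loop are
//
//   (max_trip_count, cond_any,
//    cond_data, cond_mask, cond_dims,
//    x_data, x_mask, x_dims, ...)   // one {data, mask, dims} triple per
//                                   // original loop-carried value x
//
// and each iteration recomputes cond_any by reducing the batched condition
// with the registered "any" operator.
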
void ToBatch::toBatch(Block* block, Block* res_block) {
  WithInsertPoint guard(res_block);

  // change inputs of the block - expand each tensor to a batchtensor,
  // eg: a -> a_data, a_mask, a_dims;
  // for a block in prim::Loop, register inputs separately to deal with cond
  if (!block->owningNode() || block->owningNode()->kind() != prim::Loop) {
    auto size = block->inputs().size();
    for (size_t i = 0; i < size; i++) {
      auto input = block->inputs()[i];
      auto name = input->uniqueName();
      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
        res_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
      }
      batch_map[input] =
          std::vector<Value*>(res_block->inputs()
                                  .slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE)
                                  .vec());
    }
  }

  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
    auto n = *it;
    if (n->kind().is_aten()) {
      visitAten(n, block, res_block);
    } else if (n->kind().is_prim()) {
      switch (n->kind()) {
        case prim::Constant:
          visitConstant(n, block, res_block);
          break;
        case prim::NumToTensor:
          visitNumToTensor(n, block, res_block);
          break;
        case prim::Bool:
        case prim::Float:
        case prim::Int:
          visitTensorToNum(n, block, res_block);
          break;
        case prim::ListConstruct:
          visitListConstruct(n, block, res_block);
          break;
        case prim::If:
          visitIf(n, block, res_block);
          break;
        case prim::Loop:
          visitLoop(n, block, res_block);
          break;
        default:
          throw std::runtime_error(
              "NYI: node of prim kind other than [Constant, NumToTensor, TensorToNum, If, Loop] is not supported yet");
      }
    } else {
      throw std::runtime_error(
          "NYI: node that is not aten or prim kind is not supported yet");
    }
  }
  // change outputs of the block - expand each tensor to a batchtensor
  // (data, mask, dims);
  // for a block in prim::Loop, outputs are registered separately to deal with
  // cond and cond_any;
  // for a block in prim::If, outputs are registered separately by combining
  // the outputs from the two paths before returning
  if (!block->owningNode() ||
      (block->owningNode()->kind() != prim::Loop &&
       block->owningNode()->kind() != prim::If)) {
    for (Value* output : block->outputs()) {
      auto r_output = batch_map.at(output);
      for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
        res_block->registerOutput(r_output[i]);
      }
    }
  }
}

std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph) {
  // lower the tuple before the pass
  if (graph->outputs().at(0)->type()->kind() == TupleType::Kind) {
    graph = graph->copy();
    auto outs = createTupleUnpack(graph->outputs().at(0));
    graph->eraseOutput(0);
    for (auto o : outs)
      graph->registerOutput(o);
    EliminateDeadCode(graph->block());
  }
  std::shared_ptr<Graph> res_graph = std::make_shared<Graph>();
  ToBatch to_batch;
  to_batch.toBatch(graph->block(), res_graph->block());

  // methods should only have a single output, so we pack everything into a
  // tuple
  auto tup =
      res_graph->insertNode(res_graph->createTuple(res_graph->outputs()));
  while (res_graph->outputs().size() > 0)
    res_graph->eraseOutput(res_graph->outputs().size() - 1);
  res_graph->registerOutput(tup->output());
  EliminateDeadCode(res_graph->block());

  return res_graph;
}

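// Illustrative usage sketch (assuming `graph` is a std::shared_ptr<Graph>
// compiled elsewhere, e.g. from a scripted function decorated with
// @torch.jit.batch(batch_size=4) as in the examples above):
//
//   std::shared_ptr<Graph> batched = to_batch_graph(graph);
//   // `batched` takes a {data, mask, dims} triple in place of each tensor
//   // input and returns a single tuple output.
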
void initRegisterBatchOpsBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  m.def("to_batch_graph", to_batch_graph);
  m.def(
      "register_batch_operator",
      [](std::string name, std::shared_ptr<Graph> graph) {
        ToBatch::batch_operator_table[name].push_back(graph);
      });
}

} // namespace jit
} // namespace torch