Caffe2 - C++ API
A deep learning, cross-platform ML framework
peephole.cpp
#include <c10/util/Exception.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>

#include <c10/util/Optional.h>

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

namespace torch {
namespace jit {

namespace onnx {
using namespace ::c10::onnx;
}

bool isRNN(const Node* node) {
  auto k = node->kind();
  return k == onnx::RNN || k == onnx::LSTM || k == onnx::GRU;
}

bool isNopTranspose(const std::vector<int64_t>& perm) {
  for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++)
    if (perm[i] != i)
      return false;
  return true;
}

// Returns a vector `ret` such that transposing by `ret` is equivalent
// to transposing by `t1` and then by `t2`.
//
// This fires in the case that we have transpose ops T1 -> T2. We are
// fusing the transpose op T1 into T2 and discarding T1. We assume the elements
// of the permutation in `t1` are raw indices into its input, since a previous
// iteration would have folded all the transposes up to that point. Thus,
// `ret[i] = t1[t2[i]]` says "the output of t2 at position i takes the value of
// the input tensor index contained in t1 at position `t2[i]`".
std::vector<int64_t> composeTransposes(
    const std::vector<int64_t>& t1,
    const std::vector<int64_t>& t2) {
  AT_ASSERT(t1.size() == t2.size());
  std::vector<int64_t> ret;
  ret.reserve(t1.size());
  for (const auto& i : t2) {
    AT_ASSERT(i < int64_t(t1.size()));
    ret.push_back(t1[i]);
  }
  return ret;
}
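
// A minimal worked example (illustrative values only, not part of the pass):
//
//   composeTransposes({1, 2, 0}, {2, 0, 1}) == {0, 1, 2}
//
// i.e. the second transpose undoes the first, so the fused permutation is
// the identity, and eliminateNopTranspose (below) can then remove the node.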

const std::vector<size_t>& getBroadcastPositions(Node* node) {
  // Most of the element-wise ops in ONNX support numpy broadcasting.
  // Only Gemm supports one-directional broadcasting, which broadcasts the
  // bias to the product.
  static std::unordered_map<NodeKind, std::vector<size_t>> broadcast_positions = {
      {onnx::Add, {0, 1}},
      {onnx::Div, {0, 1}},
      {onnx::Mul, {0, 1}},
      {onnx::Pow, {0, 1}},
      {onnx::Sub, {0, 1}},
      {onnx::Gemm, {2}},
      {onnx::Equal, {0, 1}},
      {onnx::Greater, {0, 1}},
      {onnx::Less, {0, 1}},
  };
  static std::vector<size_t> no_positions;

  auto iter = broadcast_positions.find(node->kind());
  if (iter != broadcast_positions.end()) {
    return iter->second;
  }
  return no_positions;
}

// Determine whether `from` can broadcast to `to`, and if so, at which
// position. `from` must be a suffix of `to`, except that any
// occurrences of 1 in `from` are treated as wildcards.
c10::optional<size_t> fusibleExpandTo(at::IntArrayRef from, at::IntArrayRef to) {
  if (from.size() > to.size()) {
    return c10::nullopt;
  }

  for (size_t i = 0; i < from.size(); i++) {
    auto fdim = from[from.size() - 1 - i];
    auto tdim = to[to.size() - 1 - i];
    if (fdim != 1 && fdim != tdim) {
      return c10::nullopt;
    }
  }

  return to.size() - from.size();
}
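
// Worked examples (illustrative shapes only):
//
//   fusibleExpandTo(/*from=*/{1, 5}, /*to=*/{2, 3, 5}) returns 1:
//   the trailing 5 matches, the 1 is a wildcard, and the broadcast
//   starts at axis 1 of `to`.
//
//   fusibleExpandTo(/*from=*/{4, 5}, /*to=*/{2, 3, 5}) returns c10::nullopt,
//   since 4 matches neither 3 nor the wildcard 1.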

void fuseBroadcast(Block* b) {
  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseBroadcast(child_block);
    }

    auto& broadcast_positions = getBroadcastPositions(n);
    if (!broadcast_positions.empty()) {
      AT_ASSERT(!n->hasAttribute(attr::axis));
    }

    for (size_t position : broadcast_positions) {
      auto* expand_node = n->input(position)->node();

      // Confirm it is an expand node.
      if (expand_node->kind() != aten::expand ||
          expand_node->input(1)->node()->kind() != onnx::Constant ||
          expand_node->input(2)->node()->kind() != onnx::Constant) {
        continue;
      }

      auto* unexpanded_input = expand_node->input(0);

      // We need to know what the type pre-expand is. We should basically
      // always have this information (because expands are only ever traced,
      // not generated from symbolic), but if for some reason we don't
      // have it, we need to skip.
      if (!unexpanded_input->isTensor() || !n->output()->isTensor())
        continue;

      // Not all broadcasts are supported by ONNX broadcast.
      c10::optional<size_t> axis = fusibleExpandTo(
          unexpanded_input->type()
              ->expect<CompleteTensorType>()
              ->sizes(), // from
          n->output()->type()->expect<CompleteTensorType>()->sizes()); // to
      if (axis == c10::nullopt)
        continue;

      n->replaceInput(position, unexpanded_input);
      if (!expand_node->hasUses()) {
        expand_node->destroy();
      }
    }
  }
}
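
// Schematic rewrite performed above (IR names are made up):
//
//   %e = aten::expand(%x, %shape, %implicit)   // %shape, %implicit constant
//   %z = onnx::Add(%e, %y)
//
// becomes, when fusibleExpandTo accepts the shapes,
//
//   %z = onnx::Add(%x, %y)
//
// relying on the ONNX consumer's implicit numpy-style broadcasting.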

void fuseConsecutiveTransposes(Block* b) {
  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseConsecutiveTransposes(child_block);
    }
    if (n->kind() == onnx::Transpose &&
        n->input()->node()->kind() == onnx::Transpose) {
      auto origInput = n->input();
      n->is_(
          attr::perm,
          composeTransposes(
              origInput->node()->is(attr::perm), n->is(attr::perm)));
      n->replaceInput(0, origInput->node()->input());
      if (origInput->uses().size() == 0) {
        origInput->node()->destroy();
      }
      continue;
    }
  }
}
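
// Schematic rewrite (IR names are made up):
//
//   %t1 = onnx::Transpose[perm={1, 0, 2}](%x)
//   %t2 = onnx::Transpose[perm={2, 0, 1}](%t1)
//
// becomes a single node carrying the composed permutation:
//
//   %t2 = onnx::Transpose[perm={2, 1, 0}](%x)
//
// since composeTransposes({1, 0, 2}, {2, 0, 1}) == {2, 1, 0}.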

void eliminateNopTranspose(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto n = *it;
    for (auto* child_block : n->blocks()) {
      eliminateNopTranspose(child_block);
    }
    if (n->kind() == onnx::Transpose) {
      if (isNopTranspose(n->is(attr::perm))) {
        n->output()->replaceAllUsesWith(n->input());
        it.destroyCurrent();
        continue;
      }
    }
  }
}

void fuseTransposeIntoGemm(Block* b) {
  static const std::vector<int64_t> simpleTransPerm({1, 0});

  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseTransposeIntoGemm(child_block);
    }
    if (n->kind() == onnx::Gemm) {
      for (size_t i : {0, 1}) {
        auto inp = n->inputs()[i];
        auto trans = i == 0 ? attr::transA : attr::transB;
        if (inp->node()->kind() == onnx::Transpose &&
            inp->node()->is(attr::perm) == simpleTransPerm) {
          n->replaceInput(i, inp->node()->input());
          n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
          if (inp->uses().size() == 0) {
            inp->node()->destroy();
          }
        }
      }
    }
  }
}
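
// Schematic rewrite (IR names are made up):
//
//   %at = onnx::Transpose[perm={1, 0}](%a)
//   %y  = onnx::Gemm(%at, %b, %c)
//
// becomes
//
//   %y = onnx::Gemm[transA=1](%a, %b, %c)
//
// and a second 2-D transpose feeding the same Gemm input would simply
// toggle the transA/transB attribute back off.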
203 
204 // Why this is here:
205 //
206 // Pytorch has a "packed" representation of sequences, as well as a
207 // "padded" representation. ONNX has only one representation,
208 // corresponding to pytorch's "padded". Therefore, we need to remove
209 // any use of packed sequences before exporting.
210 //
211 // What this does:
212 //
213 // This code uses the observation that
214 // RNN(PackPadded(x)) == PackPadded(RNN(x))
215 // and converts the first form to the second whenever possible,
216 // "pushing" the packing operation past the RNN operation. Then,
217 // the removeNopPacking pass removes the packing operations
218 // entirely by pairing them with their inverse PadPacked. If the
219 // input graph does not pair the operations, export will fail.
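
// Schematic rewrite performed below (IR names are made up; output and
// attribute lists are abbreviated):
//
//   %data, %bsizes = prim::PackPadded(%padded, %lengths)
//   %y, %h         = onnx::LSTM(%data, ...)
//   %sq            = onnx::Squeeze(%y)
//
// becomes
//
//   %y, %h           = onnx::LSTM(%padded, ...)
//   %sq              = onnx::Squeeze(%y)
//   %data2, %bsizes2 = prim::PackPadded(%sq, %lengths)
//
// with downstream uses rewired to %data2 and %bsizes2, so that a later
// prim::PadPacked can cancel against the moved PackPadded.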

void pushPackingPastRnn(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      pushPackingPastRnn(child_block);
    }

    if (n->kind() != prim::PackPadded) {
      continue;
    }
    if (n->outputs().at(0)->uses().size() != 1) {
      // For now, only handle the case where there is one consumer.
      continue;
    }
    Node* rnn = n->outputs()[0]->uses()[0].user;
    if (!isRNN(rnn)) {
      continue;
    }

    if (rnn->owningBlock() != n->owningBlock())
      continue;

    // Packing only has an effect on a network when its outputs are actually
    // used, so we can remove it here.
    if (rnn->outputs().at(0)->uses().empty() &&
        n->outputs().at(1)->uses().size() == 1) {
      n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));
      n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
      it.destroyCurrent();
      continue;
    }

    // The RNN is followed by a transpose and a reshape (if
    // bidirectional), or by a squeeze (if unidirectional).
    Node* next = rnn->outputs().at(0)->uses().at(0).user;
    if (next->kind() == onnx::Transpose) {
      next = next->outputs().at(0)->uses().at(0).user;
      if (next->kind() != onnx::Reshape) {
        continue;
      }
    } else if (next->kind() != onnx::Squeeze) {
      continue;
    }

    // Remove PackPadded from in front of the RNN.
    n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));

    // Note there can be multiple uses of the length blob. If we are
    // translating a multi-level RNN it will be an input to each level.
    n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));

    // And insert a new PackPadded after the RNN.
    Node* newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
    newPackPadded->insertAfter(next);

    // Make things consume from the new PackPadded.
    next->outputs().at(0)->replaceAllUsesWith(newPackPadded->outputs().at(0));
    n->outputs().at(1)->replaceAllUsesWith(newPackPadded->outputs().at(1));

    // Set up the new PackPadded's inputs.
    newPackPadded->addInput(next->outputs().at(0));
    newPackPadded->addInput(n->inputs().at(1));

    // See https://github.com/pytorch/pytorch/issues/9043 for a full
    // description. Since PackPadded is for now treated in an
    // unhygienic way, PyTorch ends up propagating an incorrect type.
    // Until a long-term cleanup comes around, we can fix this by
    // resetting the size to the correct value.
    CompleteTensorTypePtr oldType =
        rnn->inputs().at(0)->type()->cast<CompleteTensorType>();
    if (oldType) {
      std::vector<int64_t> new_sizes;
      new_sizes.push_back(oldType->sizes().at(0));
      new_sizes.push_back(oldType->sizes().at(1));
      new_sizes.push_back(rnn->i(attr::hidden_size));
      CompleteTensorTypePtr newType = CompleteTensorType::create(
          oldType->scalarType(), oldType->device(), new_sizes);
      next->outputs().at(0)->setType(newType);
    }

    it.destroyCurrent();
  }
}

void removeNopPacking(Block* graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      removeNopPacking(child_block);
    }

    if (n->kind() != prim::PadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    if (input->kind() != prim::PackPadded) {
      continue;
    }
    if (input->outputs()[0] != n->inputs()[0]) {
      continue;
    }
    if (input->outputs()[1] != n->inputs()[1]) {
      continue;
    }
    n->outputs()[0]->replaceAllUsesWith(input->inputs()[0]);
    n->outputs()[1]->replaceAllUsesWith(input->inputs()[1]);

    n->removeAllInputs();
    it.destroyCurrent();
  }
}

void hackFixupPadPackedShapes(Block* graph) {
  // FIXME: the input to the fictional PadPacked node has an incorrect shape.
  // For now, just copy the shape of PadPacked to the shape of its input.
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      hackFixupPadPackedShapes(child_block);
    }

    if (n->kind() != prim::PadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    input->outputs()[0]->setType(n->outputs()[0]->type());
  }
}

void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
  auto initial_state = n->inputs()[input_index];

  // The RNN code in PyTorch accepts an optional hidden state. When it
  // is provided, everything works great. When it is not provided, it
  // is default-initialized by constructing a new Variable, which gets
  // traced as a Constant. Recognize that pattern here and replace it
  // with something that doesn't hard-code the batch size. Note that for
  // multi-layer RNNs there will be a Slice operation between the
  // Constant and the RNN.
  bool needsFixing = initial_state->node()->kind() == onnx::Constant ||
      (initial_state->node()->kind() == onnx::Slice &&
       initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);

  if (!needsFixing) {
    return;
  }
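
  // Schematically, the subgraph constructed below computes the default
  // state shape dynamically from the RNN input (a sketch of the ops built
  // by the code that follows, not additional ops):
  //
  //   Shape(input) -> Gather(index 1)                        // batch_size
  //   Concat[axis=0](num_directions, batch_size, hidden_size)
  //       -> ConstantOfShape                                 // zero-filled state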
  Node* shape_of_input = graph->create(onnx::Shape, 1);
  shape_of_input->insertBefore(n);
  shape_of_input->addInput(n->inputs()[0]);

  Node* gather_indices = graph->create(onnx::Constant, 1);
  gather_indices->insertBefore(n);
  gather_indices->t_(
      attr::value,
      autograd::make_variable(at::scalar_to_tensor(at::Scalar(1))));

  Node* batch_size = graph->create(onnx::Gather, 1);
  batch_size->insertBefore(n);
  batch_size->addInput(shape_of_input->outputs()[0]);
  batch_size->addInput(gather_indices->outputs()[0]);

  Node* unsqueezed_batch_size = graph->create(onnx::Unsqueeze, 1);
  unsqueezed_batch_size->insertBefore(n);
  unsqueezed_batch_size->addInput(batch_size->outputs()[0]);
  unsqueezed_batch_size->is_(attr::axes, {0});

  Node* hidden_size = graph->create(onnx::Constant, 1);
  hidden_size->insertBefore(n);
  hidden_size->t_(
      attr::value,
      autograd::make_variable(
          at::full({1}, n->i(attr::hidden_size), at::kLong)));

  Node* num_directions = graph->create(onnx::Constant, 1);
  num_directions->insertBefore(n);
  num_directions->t_(
      attr::value,
      autograd::make_variable(scalar_to_tensor(at::Scalar(
          n->hasAttribute(attr::direction) &&
                  n->s(attr::direction) == "bidirectional"
              ? 2
              : 1))));

  Node* unsqueezed_num_directions = graph->create(onnx::Unsqueeze, 1);
  unsqueezed_num_directions->insertBefore(n);
  unsqueezed_num_directions->addInput(num_directions->outputs()[0]);
  unsqueezed_num_directions->is_(attr::axes, {0});

  Node* concated_dims = graph->create(onnx::Concat, 1);
  concated_dims->insertBefore(n);
  concated_dims->i_(attr::axis, 0);
  concated_dims->addInput(unsqueezed_num_directions->outputs()[0]);
  concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
  concated_dims->addInput(hidden_size->outputs()[0]);

  Node* constant_of_shape = graph->create(onnx::ConstantOfShape, 1);
  constant_of_shape->insertBefore(n);
  constant_of_shape->addInput(concated_dims->outputs()[0]);
  n->replaceInput(input_index, constant_of_shape->outputs()[0]);

  if (initial_state->uses().size() == 0) {
    initial_state->node()->destroy();
  }
}

void fixDefaultRnnHiddenState(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultRnnHiddenState(child_block);
    }

    if (!isRNN(n)) {
      continue;
    }
    // Hidden state is the sixth input for RNN, LSTM, GRU.
    // See https://pytorch.org/docs/master/nn.html#torch.nn.RNN
    if (n->inputs().size() < 6) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 5);
  }
}

void fixDefaultLstmCellState(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultLstmCellState(child_block);
    }

    if (n->kind() != onnx::LSTM) {
      continue;
    }
    // Cell state is the seventh input for LSTM.
    // See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM
    if (n->inputs().size() < 7) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 6);
  }
}

static bool isSafeToSpeculate(Node* n) {
  return n->kind() == onnx::Transpose;
}

static void speculateOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it; // note: increment first so that it is safe to move the node if needed

    for (auto b : n->blocks()) {
      speculateOps(b);
    }
    if (!isSafeToSpeculate(n))
      continue;
    // XXX - only works for nodes with a single input
    // Move node n outside of the control flow it is nested in.
    auto node_input = n->input()->node();
    if (node_input->owningBlock() == n->owningBlock())
      continue;
    // Find the control flow node in the same block as node_input that
    // contains node n.
    auto control_flow_node = n->owningBlock()->owningNode();
    while (control_flow_node->owningBlock() != node_input->owningBlock())
      control_flow_node = control_flow_node->owningBlock()->owningNode();
    // Put the node right before this flow node.
    n->moveBefore(control_flow_node);
  }
}
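
// Schematic effect (IR names are made up): a Transpose whose input is
// defined outside the control-flow block it lives in is hoisted out:
//
//   %y = prim::If(%cond)
//     block0():
//       %t = onnx::Transpose[perm={1, 0}](%x)   // %x defined outside
//       ...
//
// becomes
//
//   %t = onnx::Transpose[perm={1, 0}](%x)
//   %y = prim::If(%cond)
//     ...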

static void replaceInputWithList(Node* node, size_t i, ArrayRef<Value*> to) {
  node->removeInput(i);
  for (auto* to_val : to) {
    AT_ASSERT(to_val->owningGraph() == node->owningGraph());
    node->insertInput(i++, to_val);
  }
}

static void eraseListConstruct(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it;

    for (auto b : n->blocks()) {
      eraseListConstruct(b);
    }
    std::vector<std::tuple<size_t, std::vector<Value*>>> replacements;

    size_t i = 0;
    for (auto* input : n->inputs()) {
      if (input->node()->kind() == prim::ListConstruct) {
        auto* lc_node = input->node();
        TypePtr elem =
            lc_node->output()->type()->cast<ListType>()->getElementType();
        if (elem->cast<IntType>()) {
          // In the ListConstruct Int[] output case, we need to transform it
          // into an ONNX Concat to ensure the output is a single tensor
          // (dynamic) type that can be consumed as an input.
          std::vector<Value*> unsqueezed;
          Graph* g = block->owningGraph();
          for (auto* input : lc_node->inputs()) {
            Node* unsqueezed_node = g->create(onnx::Unsqueeze, 1);
            unsqueezed_node->insertBefore(lc_node);
            unsqueezed_node->addInput(input);
            unsqueezed_node->is_(attr::axes, {0});
            unsqueezed.emplace_back(unsqueezed_node->output());
          }
          Node* concat_node = g->create(onnx::Concat, 1);
          concat_node->i_(attr::axis, 0);
          for (auto v : unsqueezed) {
            concat_node->addInput(v);
          }
          concat_node->insertBefore(lc_node);

          // Make the Concat node's output the new input; the ListConstruct
          // should then become dead.
          replacements.emplace_back(
              i, std::vector<Value*>({concat_node->output()}));

        } else {
          // Tensor lists are used mostly for inputs to cat/stack. They are
          // already handled in those symbolics, and should become dead
          // afterwards.
          replacements.emplace_back(
              i,
              std::vector<Value*>(
                  lc_node->inputs().begin(), lc_node->inputs().end()));
        }
      }
      i++;
    }

    for (auto ritr = replacements.rbegin(); ritr != replacements.rend();
         ++ritr) {
      replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr));
    }
  }
}
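
// Schematic rewrite for the Int[] case (IR names are made up):
//
//   %l = prim::ListConstruct(%i0, %i1)
//   %y = SomeOp(%l, ...)
//
// becomes
//
//   %u0 = onnx::Unsqueeze[axes={0}](%i0)
//   %u1 = onnx::Unsqueeze[axes={0}](%i1)
//   %c  = onnx::Concat[axis=0](%u0, %u1)
//   %y  = SomeOp(%c, ...)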

static void fuseSplitListUnpack(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseSplitListUnpack(child_block);
    }
    if (it->kind() == prim::ListUnpack &&
        it->input()->node()->kind() == onnx::Split) {
      auto origSplitNode = it->input()->node();

      Node* splitNode =
          b->owningGraph()->create(onnx::Split, it->outputs().size());
      for (size_t i = 0; i < splitNode->outputs().size(); ++i) {
        splitNode->outputs()[i]->copyMetadata(it->outputs()[i]);
      }
      splitNode->copyAttributes(*origSplitNode);
      splitNode->insertBefore(origSplitNode);
      splitNode->addInput(origSplitNode->input());
      it->replaceAllUsesWith(splitNode);
      it->removeAllInputs();
      origSplitNode->destroy();
      it.destroyCurrent();
      continue;
    }
  }
}
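
// Schematic rewrite (IR names are made up; attributes abbreviated):
//
//   %l     = onnx::Split[...](%x)      // single list-like output
//   %a, %b = prim::ListUnpack(%l)
//
// becomes a Split with one output per unpacked value:
//
//   %a, %b = onnx::Split[...](%x)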

void removeMaxPoolUnusedOutput(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto n = *it;
    for (auto* child_block : n->blocks()) {
      removeMaxPoolUnusedOutput(child_block);
    }
    if (strcmp(n->kind().toQualString(), "onnx::MaxPool") == 0) {
      if (n->outputs().size() == 2 && n->outputs().at(1)->uses().empty()) {
        it->eraseOutput(1);
      }
    }
  }
}

// This pass performs ONNX-specific peephole optimizations.
//
// At the moment, here are the optimizations it does:
// - Fusing of expand calls into ONNX operators, because it is easier for
//   non-strided backends to do broadcasts efficiently when this information
//   is local. This optimization is not useful for PyTorch, as 'expand' is
//   free.
// - Fusing of consecutive transposes
// - Elimination of NOP transposes
// - Fusing of transposes into Gemm
// - Elimination of PaddedSequences
//
// Before you write an optimization here, ask yourself, "Could I do this
// optimization on ATen operators?" If so, you should seriously consider
// writing your optimization in jit/passes/peephole.cpp rather than
// here, as it will be generally applicable to the JIT as well. The
// optimizations here are ONLY applied on ONNX export.
void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
  // TODO: decide on fixpoint strategy
  // TODO: make it easier not to do O(k) iterations over the graph, where
  // k is the number of distinct peephole optimizations
  hackFixupPadPackedShapes(graph->block());
  pushPackingPastRnn(graph->block());
  removeNopPacking(graph->block());
  fixDefaultRnnHiddenState(graph->block());
  fixDefaultLstmCellState(graph->block());
  fuseBroadcast(graph->block());
  fuseConsecutiveTransposes(graph->block());
  eliminateNopTranspose(graph->block());
  fuseTransposeIntoGemm(graph->block());
  speculateOps(graph->block());
  eraseListConstruct(graph->block());
  fuseSplitListUnpack(graph->block());
  removeMaxPoolUnusedOutput(graph->block());
}

} // namespace jit
} // namespace torch