Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_misc.h
#pragma once

#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"

#include <torch/csrc/jit/passes/canonicalize.h>
#include "ATen/core/interned_strings.h"
#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/argument_spec.h"
#include "torch/csrc/jit/attributes.h"
#include "torch/csrc/jit/autodiff.h"
#include "torch/csrc/jit/code_template.h"
#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/jit/dynamic_dag.h"
#include "torch/csrc/jit/fuser/interface.h"
#include "torch/csrc/jit/import.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/passes/alias_analysis.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
#include "torch/csrc/jit/passes/lower_tuples.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
#include "torch/csrc/jit/symbolic_script.h"
#include "torch/csrc/jit/symbolic_variable.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/utils/hash.h"
#include "torch/csrc/utils/memory.h"

#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/variable.h"

#include <torch/csrc/jit/testing/file_check.h>
#include "ATen/core/ivalue.h"
#include "torch/csrc/jit/script/compiler.h"
#include "torch/csrc/jit/script/module.h"

#include "onnx/onnx_pb.h"

#include <ATen/ATen.h>

#include <c10/util/Exception.h>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <map> // used by testTopologicalIndex (std::map)
#include <memory>
#include <sstream> // used by testAutogradProfiler (std::stringstream)
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace test {

using Var = SymbolicVariable;

using namespace torch::autograd;

template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) {
  size_t i = 0;
  out << "{";
  for (auto&& e : list) {
    if (i++ > 0)
      out << ", ";
    out << e;
  }
  out << "}";
  return out;
}

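// Checks that interned symbols round-trip between their integer values and
// their qualified/unqualified string forms, and that newly interned symbols
// receive consecutive indices.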
void testInternedStrings() {
  ASSERT_EQ(prim::Param, Symbol::prim("Param"));
  ASSERT_EQ(prim::Return, Symbol::prim("Return"));
  ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return"));
  ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return"));
  Symbol newsym = Symbol::aten("__NEW_SYMBOL");
  size_t symstart = newsym;
  ASSERT_EQ(newsym.toQualString(), std::string("aten::__NEW_SYMBOL"));
  // TODO: This test is a bit too close to the implementation details.
  ASSERT_EQ(Symbol::aten("What"), symstart + 1);
  ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
  ASSERT_EQ(Symbol::aten("What"), symstart + 1);
  ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
  ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2"));
}

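// Exercises Symbol::fromQualString for each builtin namespace and verifies
// that malformed qualified strings throw.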
void testFromQualString() {
  ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param"));
  ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm"));
  ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM"));
  ASSERT_EQ(Symbol::fromQualString("attr::value"), Symbol::attr("value"));
  ASSERT_EQ(Symbol::fromQualString("scope::"), Symbol::scope(""));
  ASSERT_EQ(Symbol::fromQualString("::").toUnqualString(), std::string(""));
  ASSERT_EQ(
      Symbol::fromQualString("::").ns().toQualString(),
      std::string("namespaces::"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").toUnqualString(),
      std::string("param"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").ns().toUnqualString(),
      std::string("new_ns"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").ns(),
      Symbol::fromQualString("namespaces::new_ns"));

  auto bad_inputs = {"scope", ":", ""};
  for (auto input : bad_inputs) {
    try {
      Symbol::fromQualString(input);
      ASSERT_TRUE(0);
    } catch (const std::exception& c) {
    }
  }
}

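// Runs thnn_conv2d forward and backward eagerly, then builds the same op as
// a JIT graph, differentiates it, and checks the interpreter's outputs and
// gradients against the eager results.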
void testTHNNConv() {
  std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
  std::vector<int64_t> kernel_size = {3, 5};
  std::vector<int64_t> stride = {1, 2};
  std::vector<int64_t> padding = {2, 1};
  constexpr int out_channels = 5;

  // make inputs
  at::Tensor input = torch::randn(input_size);
  at::Tensor weight = torch::randn(
      {out_channels, input_size[1], kernel_size[0], kernel_size[1]});
  at::Tensor bias = torch::randn({out_channels});

  // run forward eagerly
  at::Tensor output, finput, fgradinput;
  std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(
      input, weight, kernel_size, bias, stride, padding);

  // make grad_outputs
  at::Tensor grad_output = torch::randn_like(output);
  at::Tensor grad_finput = torch::zeros_like(finput);
  at::Tensor grad_fgradinput = torch::zeros_like(fgradinput);

  // run backward eagerly
  at::Tensor grad_input, grad_weight, grad_bias;
  std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(
      grad_output,
      input,
      weight,
      kernel_size,
      stride,
      padding,
      finput,
      fgradinput,
      {true, true, true});

  // make JIT graph
  auto graph = std::make_shared<Graph>();
  auto ksz_val = graph->insertConstant(IValue(kernel_size));
  auto kst_val = graph->insertConstant(IValue(stride));
  auto pad_val = graph->insertConstant(IValue(padding));

  auto inputg = graph->addInput("self");
  auto weightg = graph->addInput("weight");
  auto biasg = graph->addInput("bias");

  Value* conv = graph->insert(
      aten::thnn_conv2d_forward,
      {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
  auto outputs = conv->node()->outputs();
  for (auto output : outputs) {
    graph->registerOutput(output);
  }
  LowerAllTuples(graph);
  graph->lint();

  // differentiate JIT graph
  EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
  ConstantPropagation(graph);
  auto grad_spec = differentiate(graph);
  LowerGradOf(*grad_spec.df);

  // prepare JIT inputs / gradients
  tensor_list tensors_in;
  tensors_in.push_back(input);
  tensors_in.push_back(weight);
  tensors_in.push_back(bias);

  tensor_list tensor_grads_in;
  tensor_grads_in.push_back(grad_output);
  tensor_grads_in.push_back(grad_finput);
  tensor_grads_in.push_back(grad_fgradinput);

  // Get outputs from the interpreter
  tensor_list tensors_out, tensor_grads_out;
  std::tie(tensors_out, tensor_grads_out) =
      runGradient(grad_spec, tensors_in, tensor_grads_in);

  // prepare expected structs
  tensor_list expected_tensors_out, expected_tensor_grads_out;
  expected_tensors_out.push_back(output);
  expected_tensors_out.push_back(finput);
  expected_tensors_out.push_back(fgradinput);
  expected_tensor_grads_out.push_back(grad_input);
  expected_tensor_grads_out.push_back(grad_weight);
  expected_tensor_grads_out.push_back(grad_bias);

  // Compare results
  assertAllClose(tensors_out, expected_tensors_out);
  assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}

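// Same structure as testTHNNConv, but for aten::native_batch_norm, which
// additionally mutates running_mean and running_var in place (hence the
// separate clones for the eager and JIT runs).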
void testATenNativeBatchNorm() {
  // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
  // running_mean, Tensor running_var, bool training, float momentum, float
  // eps) -> (Tensor, Tensor, Tensor)
  std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
  bool training = true;
  float momentum = 0.9;
  float eps = 1e-5;

  // make inputs
  at::Tensor input = torch::randn(input_size);
  at::Tensor weight = torch::randn({input_size[1]});
  at::Tensor bias = torch::randn({input_size[1]});
  at::Tensor running_mean = torch::randn({input_size[1]});
  at::Tensor running_var = torch::randn({input_size[1]});

  // running_mean and running_var are changed in-place, so clone separate
  // copies for the eager and JIT runs
  at::Tensor running_mean_eager = running_mean.clone();
  at::Tensor running_var_eager = running_var.clone();
  at::Tensor running_mean_jit = running_mean.clone();
  at::Tensor running_var_jit = running_var.clone();

  // run forward eagerly
  at::Tensor output, savemean, saveinvstd;
  std::tie(output, savemean, saveinvstd) = at::native_batch_norm(
      input,
      weight,
      bias,
      running_mean_eager,
      running_var_eager,
      training,
      momentum,
      eps);

  // make grad_outputs
  at::Tensor grad_output = torch::randn_like(output);
  at::Tensor grad_savemean = torch::zeros_like(savemean);
  at::Tensor grad_saveinvstd = torch::zeros_like(saveinvstd);

  // run backward eagerly
  at::Tensor grad_input, grad_weight, grad_bias;
  // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
  // weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
  // save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
  // Tensor, Tensor)
  std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(
      grad_output,
      input,
      weight,
      running_mean_eager,
      running_var_eager,
      savemean,
      saveinvstd,
      training,
      eps,
      {true, true, true});

  // make JIT graph
  auto graph = std::make_shared<Graph>();
  auto training_val = graph->insertConstant(IValue(training));
  auto momentum_val = graph->insertConstant(IValue(momentum));
  auto eps_val = graph->insertConstant(IValue(eps));

  auto inputg = graph->addInput("self");
  auto weightg = graph->addInput("weight");
  auto biasg = graph->addInput("bias");
  auto running_meang = graph->addInput("running_mean");
  auto running_varg = graph->addInput("running_var");

  Value* bn = graph->insert(
      aten::native_batch_norm,
      {inputg,
       weightg,
       biasg,
       running_meang,
       running_varg,
       training_val,
       momentum_val,
       eps_val});
  auto outputs = bn->node()->outputs();
  for (auto output : outputs) {
    graph->registerOutput(output);
  }
  LowerAllTuples(graph);
  graph->lint();

  // differentiate JIT graph
  EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
  ConstantPropagation(graph);
  auto grad_spec = differentiate(graph);
  LowerGradOf(*grad_spec.df);

  // prepare JIT inputs / gradients
  tensor_list tensors_in;
  tensors_in.push_back(input);
  tensors_in.push_back(weight);
  tensors_in.push_back(bias);
  tensors_in.push_back(running_mean_jit);
  tensors_in.push_back(running_var_jit);

  tensor_list tensor_grads_in;
  tensor_grads_in.push_back(grad_output);
  tensor_grads_in.push_back(grad_savemean);
  tensor_grads_in.push_back(grad_saveinvstd);

  // Get outputs from the interpreter
  tensor_list tensors_out, tensor_grads_out;
  std::tie(tensors_out, tensor_grads_out) =
      runGradient(grad_spec, tensors_in, tensor_grads_in);

  // prepare expected structs
  tensor_list expected_tensors_out, expected_tensor_grads_out;
  expected_tensors_out.push_back(output);
  expected_tensors_out.push_back(savemean);
  expected_tensors_out.push_back(saveinvstd);
  expected_tensors_out.push_back(running_mean_eager);
  expected_tensors_out.push_back(running_var_eager);
  expected_tensor_grads_out.push_back(grad_input);
  expected_tensor_grads_out.push_back(grad_weight);
  expected_tensor_grads_out.push_back(grad_bias);

  tensors_out.push_back(running_mean_jit);
  tensors_out.push_back(running_var_jit);

  // Compare results
  assertAllClose(tensors_out, expected_tensors_out);
  assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}

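// Script source for the control-flow tests below: an if/else, an if without
// an else, and a while loop.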
static const auto cf_examples = R"JIT(
  def if_test(a, b):
    # FIXME: use 0 instead of a.
    # c = 0
    c = a
    if bool(a < b):
      c = b
    else:
      c = a
    return c
  def if_one(a, b):
    c = b
    if bool(a < b):
      c = a
    return c
  def while_test(a, i):
    while bool(i < 3):
      a *= a
      i += 1
    return a
)JIT";

void testControlFlow() {
  auto cu = std::make_shared<script::Module>();
  script::defineMethodsInModule(
      cu, cf_examples, script::nativeResolver, c10::nullopt);
  auto run = [&](const std::string& name, std::vector<IValue> stack) {
    auto graph = cu->get_method(name).graph();
    Code code(graph);
    InterpreterState interp(code);
    interp.run(stack);
    return stack;
  };

  auto L = [](int64_t l) {
    return IValue(autograd::make_variable(scalar_to_tensor(at::Scalar(l))));
  };
  auto V = [](IValue t) { return std::move(t).toTensor().item<int64_t>(); };
  auto run_binary = [&](const std::string& name, int64_t a, int64_t b) {
    return V(run(name, {L(a), L(b)})[0]);
  };
  ASSERT_EQ(2, run_binary("if_test", 1, 2));
  ASSERT_EQ(3, run_binary("if_test", 3, 2));
  ASSERT_EQ(2, run_binary("if_one", 2, 3));
  ASSERT_EQ(2, run_binary("if_one", 3, 2));
  ASSERT_EQ(256, run_binary("while_test", 2, 0));
}

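// Smoke test that the bundled ONNX protobuf types compile and link.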
void testProto() {
  ::ONNX_NAMESPACE::ModelProto proto;
  proto.set_producer_name("foo");
}

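// Loads a serialized module containing a dropout submodule and checks that
// eval()/train() toggle its training flag.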
void testEvalModeForLoadedModule() {
  if (isSandcastle())
    return; // The module file to load is not generated in Sandcastle
  std::string module_path = "dropout_model.pt";
  std::shared_ptr<torch::jit::script::Module> module =
      torch::jit::load(module_path);
  AT_ASSERT(module->get_module("dropout")->is_training());
  module->eval();
  AT_ASSERT(!module->get_module("dropout")->is_training());
  module->train();
  AT_ASSERT(module->get_module("dropout")->is_training());
}

// test a few features that are not directly used in schemas yet
void testSchemaParser() {
  // nested arrays
  auto s = parseSchema("at::what(int[][4] foo) -> ()");
  ASSERT_TRUE(s.arguments().at(0).N() == 4);
  ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments()
                                              .at(0)
                                              .type()
                                              ->expect<ListType>()
                                              ->getElementType()
                                              ->expect<ListType>()
                                              ->getElementType()));
  auto s2 = parseSchema("at::what(int[][] foo) -> ()");
  ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments()
                                              .at(0)
                                              .type()
                                              ->expect<ListType>()
                                              ->getElementType()
                                              ->expect<ListType>()
                                              ->getElementType()));

  // named returns
  parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
  auto s3 =
      parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
  ASSERT_TRUE(s3.returns().at(0).name() == "the_return");
  ASSERT_TRUE(s3.returns().at(1).name() == "the_return2");

  // futures
  auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
  ASSERT_TRUE(IntType::get()->isSubtypeOf(
      s4.arguments().at(0).type()->expect<FutureType>()->getElementType()));

  // test tensor with annotated alias sets
  parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))");

  {
    const auto s = parseSchema(
        "at::what(Tensor(b|c)[](a!) list, Tensor(c) element)"
        " -> (Tensor(b|c)[](a!))");

    // The list itself is annotated with `a`
    const auto& aliasInfo = *s.arguments().at(0).alias_info();
    ASSERT_TRUE(
        aliasInfo.beforeSets() ==
        std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
    ASSERT_TRUE(aliasInfo.isWrite());

    // Check the contained types
    ASSERT_TRUE(!aliasInfo.containedTypes().empty());
    const auto& containedAliasInfo = aliasInfo.containedTypes()[0];
    const auto expected = std::unordered_set<Symbol>{
        Symbol::fromQualString("alias::b"),
        Symbol::fromQualString("alias::c"),
    };
    ASSERT_TRUE(containedAliasInfo.beforeSets() == expected);
    ASSERT_TRUE(containedAliasInfo.afterSets() == expected);
    ASSERT_FALSE(containedAliasInfo.isWrite());
  }
  {
    const auto s = parseSchema(
        "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
        " -> (Tensor(b|c)[](a!))");

    // The list itself is annotated with `a`
    const auto& aliasInfo = *s.arguments().at(0).alias_info();
    ASSERT_EQ(
        aliasInfo.beforeSets(),
        std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
    ASSERT_EQ(
        aliasInfo.afterSets(),
        std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
    ASSERT_TRUE(aliasInfo.isWrite());
    ASSERT_EQ(aliasInfo.containedTypes().size(), 1);

    // Check the contained types
    ASSERT_TRUE(!aliasInfo.containedTypes().empty());
    const auto& containedAliasInfo = aliasInfo.containedTypes()[0];
    const auto expectedBefore = std::unordered_set<Symbol>{
        Symbol::fromQualString("alias::b"),
    };
    const auto expectedAfter = std::unordered_set<Symbol>{
        Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
    ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
    ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
    ASSERT_FALSE(containedAliasInfo.isWrite());
  }
}

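// Validates Node::isBefore/isAfter across insertions, nested blocks,
// deletions, and the reindexing path triggered by repeated inserts at the
// same point.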
void testTopologicalIndex() {
  {
    Graph graph;
    auto node1 = graph.create(prim::AutogradZero);
    auto node2 = graph.create(prim::AutogradZero);
    auto node3 = graph.create(prim::AutogradZero);
    auto node4 = graph.create(prim::AutogradZero);

    graph.appendNode(node4);
    graph.prependNode(node1);
    node2->insertAfter(node1);
    node3->insertBefore(node4);

    // nodes should be in numerical order
    ASSERT_TRUE(node1->isBefore(node2));
    ASSERT_TRUE(node1->isBefore(node3));
    ASSERT_TRUE(node1->isBefore(node4));
    ASSERT_TRUE(node2->isAfter(node1));
    ASSERT_TRUE(node2->isBefore(node3));
    ASSERT_TRUE(node2->isBefore(node4));
    ASSERT_FALSE(node3->isBefore(node1));
    ASSERT_FALSE(node3->isBefore(node2));
    ASSERT_FALSE(node3->isAfter(node4));

    // Build up a block structure
    //  node3
    //   / \        ...
    //  A   B     block1
    //       \      ...
    //        C   block2
    auto block1 = node3->addBlock();
    auto A = graph.create(prim::AutogradZero);
    block1->appendNode(A);
    auto B = graph.create(prim::AutogradZero);
    block1->appendNode(B);
    auto block2 = B->addBlock();
    auto C = graph.create(prim::AutogradZero);
    block2->appendNode(C);

    // Check isAfter on different block levels
    ASSERT_TRUE(node1->isBefore(A));
    ASSERT_TRUE(A->isBefore(B));
    ASSERT_TRUE(A->isBefore(C));

    // make sure things don't blow up on deletions
    node2->destroy();
    auto node2p = graph.create(prim::AutogradZero);
    node2p->insertAfter(node1);
    ASSERT_TRUE(node1->isBefore(node2p));
    ASSERT_TRUE(node2p->isBefore(node3));
  }
  {
    // Induce reindexing to test that path
    Graph graph;
    std::map<size_t, Node*> nodes;

    auto anchor = graph.create(prim::AutogradZero);
    graph.appendNode(anchor);
    // Inserting at the same spot repeatedly will trigger reindexing
    for (auto i = 0; i < 100; ++i) {
      auto n = graph.create(prim::AutogradZero);
      n->insertAfter(anchor);
      nodes[i] = n;
    }

    // Nodes should be in reverse order
    for (auto i = 0; i < 100; ++i) {
      for (auto j = i + 1; j < 100; ++j) {
        ASSERT_TRUE(nodes[i]->isAfter(nodes[j]));
      }
    }
  }
}

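// Profiles 100 LSTM cell iterations under RecordProfile and checks that the
// emitted trace contains the expected number of "tanh" events (two per
// iteration).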
void testAutogradProfiler() {
  constexpr int batch_size = 4;
  constexpr int input_size = 256;
  constexpr int seq_len = 32;

  int hidden_size = 2 * input_size;
  auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU);
  auto hx = torch::randn({batch_size, hidden_size}, at::kCPU);
  auto cx = torch::randn({batch_size, hidden_size}, at::kCPU);
  auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU));
  auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU));

  std::stringstream ss;
  {
    autograd::profiler::RecordProfile guard(ss);
    for (size_t i = 0; i < 100; ++i) {
      std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
    }
  }

  std::string result = ss.str();
  size_t count = 0;
  for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos;
       count++, pos++) {
  }
  AT_CHECK(count == 200);
}

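// Registers two custom operators that produce and consume an optional int,
// then checks that constant propagation folds a None value without a schema
// mismatch.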
void testNoneSchemaMatch() {
  RegisterOperators reg({
      Operator(
          "test::test_none() -> int?",
          [](const Node* node) {
            return [](Stack& stack) {
              push(stack, IValue());
              return 0;
            };
          }),
      Operator(
          "test::is_none(int? a) -> bool",
          [](const Node* node) {
            return [](Stack& stack) {
              IValue a = pop(stack);
              if (a.isNone()) {
                push(stack, true);
              } else {
                push(stack, false);
              }
              return 0;
            };
          }),
  });

  // Constant propagation will run test_none and produce a None,
  // testing that its type is set appropriately and schema matching doesn't
  // fail when running is_none

  auto r = std::make_shared<Graph>();
  auto& g = *r;
  auto opt_int = g.insert(Symbol::fromQualString("test::test_none"), {});
  auto out_bool = g.insert(Symbol::fromQualString("test::is_none"), {opt_int});
  g.registerOutput(out_bool);
  ConstantPropagation(r);

  auto nodes = r->block()->nodes();
  // check that constant propagation ran without failure
  AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
}
} // namespace test
} // namespace jit
} // namespace torch