Caffe2 - C++ API
A deep learning, cross-platform ML framework
jit.cpp
#include <gtest/gtest.h>

#include <torch/jit.h>
#include <torch/types.h>

#include <string>
#include <vector>

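// Compiles several TorchScript functions from a single source string and
// invokes each of them from C++ via run_method, checking the returned values.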
TEST(TorchScriptTest, CanCompileMultipleFunctions) {
  auto module = torch::jit::compile(R"JIT(
      def test_mul(a, b):
        return a * b
      def test_relu(a, b):
        return torch.relu(a + b)
      def test_while(a, i):
        while bool(i < 10):
          a += a
          i += 1
        return a
      def test_len(a : List[int]):
        return len(a)
    )JIT");
  auto a = torch::ones(1);
  auto b = torch::ones(1);

  // 1 * 1 == 1
  ASSERT_EQ(1, module->run_method("test_mul", a, b).toTensor().item<int64_t>());

  // relu(1 + 1) == 2
  ASSERT_EQ(2, module->run_method("test_relu", a, b).toTensor().item<int64_t>());

  // a starts at 1 and doubles once for each of the nine loop iterations
  // (i = 1..9), so the result is 2^9 = 512 = 0x200.
  ASSERT_TRUE(
      0x200 == module->run_method("test_while", a, b).toTensor().item<int64_t>());

  // A std::vector<int64_t> is converted to a TorchScript List[int].
  at::IValue list = std::vector<int64_t>({3, 4});
  ASSERT_EQ(2, module->run_method("test_len", list).toInt());
}
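// Exercises argument matching for nested lists of IValues against a
// List[List[Tensor]] parameter: correctly nested inputs are accepted, while
// over-nested or mixed-type inputs raise a c10::Error.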
TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def nested_loop(a: List[List[Tensor]], b: int):
        return torch.tensor(1.0) + b
    )JIT");

  auto b = 3;

  std::vector<torch::Tensor> list = {torch::rand({4, 4})};

  // A vector of tensor lists matches List[List[Tensor]] directly.
  std::vector<torch::jit::IValue> list_of_lists;
  list_of_lists.push_back(list);
  module->run_method("nested_loop", list_of_lists, b);

  // An empty generic list nested inside another generic list also matches.
  std::vector<torch::jit::IValue> generic_list;
  std::vector<torch::jit::IValue> empty_generic_list;
  empty_generic_list.push_back(generic_list);
  module->run_method("nested_loop", empty_generic_list, b);

  // One level of nesting too many must be rejected.
  std::vector<torch::jit::IValue> too_many_lists;
  too_many_lists.push_back(empty_generic_list);
  try {
    module->run_method("nested_loop", too_many_lists, b);
    AT_ASSERT(false);
  } catch (const c10::Error& error) {
    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("Expected value of type Tensor[][] for argument 'a' in "
                  "position 0, but instead got value of type t[][][]") == 0);
  }

  // A generic list mixing a tensor list and an int list must also be rejected.
  std::vector<torch::jit::IValue> gen_list;
  std::vector<int64_t> int_list = {1, 2, 3};

  gen_list.emplace_back(list);
  gen_list.emplace_back(int_list);

  try {
    module->run_method("nested_loop", gen_list, b);
    AT_ASSERT(false);
  } catch (const c10::Error& error) {
    // TODO: types are currently not unified across the encountered generic
    // lists, so the reported type in the error message is not helpful here.
    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("Expected value of type Tensor[][] for argument 'a' in "
                  "position 0, but instead got value of type Tensor[][]") == 0);
  }
}
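// Exercises argument matching of a C++ map with string keys and tensor
// values against a Dict[str, Tensor] parameter.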
TEST(TorchScriptTest, TestDictArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def dict_op(a: Dict[str, Tensor], b: str):
        return a[b]
    )JIT");
  c10::ivalue::UnorderedMap dict;
  dict[std::string("hello")] = torch::ones({2});
  auto output = module->run_method("dict_op", dict, std::string("hello"));
  ASSERT_EQ(1, output.toTensor()[0].item<int64_t>());
}
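// Exercises argument matching of a tuple holding a generic list against a
// Tuple[List[int]] parameter.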
TEST(TorchScriptTest, TestTupleArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def tuple_op(a: Tuple[List[int]]):
        return a
    )JIT");

  std::vector<int64_t> int_list = {1};
  auto tuple_generic_list = torch::jit::Tuple::create({int_list});

  // Argument matching alone should succeed; the return value is not checked.
  module->run_method("tuple_op", tuple_generic_list);
}