// TorchScriptTest.CanCompileMultipleFunctions: compiles a single TorchScript
// source string that defines test_mul, test_relu, test_while and test_len,
// then invokes each one through module->run_method. test_mul/test_relu/
// test_while receive two torch::ones(1) tensors; test_len receives an IValue
// holding the int list {3, 4}. Expected results: test_mul -> 1,
// test_relu -> 2, test_while -> 0x200, test_len -> 2.
// NOTE(review): this listing looks like a lossy extraction of the real file.
// Original source line numbers are fused into the code (1, 4, 8, 9, 13, ...),
// and the raw-string prefix `R"JIT(` is split across two lines. The bodies of
// test_mul/test_while and the `)JIT"` terminator are absent. The assertion
// macro that should wrap `0x200 == ...` is missing, as is the closing brace.
// Restore from the upstream source before building; as-is this does not
// compile.
1 #include <gtest/gtest.h> 4 #include <torch/types.h> 8 TEST(TorchScriptTest, CanCompileMultipleFunctions) {
9 auto module = torch::jit::compile(R
"JIT( 13 return torch.relu(a + b) 19 def test_len(a : List[int]): 22 auto a = torch::ones(1);
23 auto b = torch::ones(1);
25 ASSERT_EQ(1, module->run_method(
"test_mul", a, b).toTensor().item<int64_t>());
27 ASSERT_EQ(2, module->run_method(
"test_relu", a, b).toTensor().item<int64_t>());
30 0x200 == module->run_method(
"test_while", a, b).toTensor().item<int64_t>());
32 at::IValue list = std::vector<int64_t>({3, 4});
33 ASSERT_EQ(2, module->run_method(
"test_len", list).toInt());
// TorchScriptTest.TestNestedIValueModuleArgMatching: exercises how run_method
// matches nested std::vector<IValue> arguments against the TorchScript
// signature `nested_loop(a: List[List[Tensor]], b: int)`.
//  - a list holding one list of Tensors (list_of_lists) is passed;
//  - a list holding one empty generic list (empty_generic_list) is passed;
//  - a three-level list (too_many_lists) and a list mixing a Tensor list with
//    an int list (gen_list) are each paired with an expected error message.
//    The check `.find(...) == 0` means the message must START with the text.
// NOTE(review): lossy extraction. Source line numbers are fused into the code,
// and the `R"JIT(` prefix is split across lines. Several pieces are not
// visible here:
//  - the declaration of the C++ variable `b`;
//  - the error-handling code around the too_many_lists and gen_list calls
//    (e.g. original lines 59, 61-64 and 76, 78-83), including the receiver of
//    each `.find(`;
//  - the closing brace.
// Restore from upstream before building.
38 TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
39 auto module = torch::jit::compile(R
"JIT( 40 def nested_loop(a: List[List[Tensor]], b: int): 41 return torch.tensor(1.0) + b 46 std::vector<torch::Tensor> list = {torch::rand({4, 4})};
48 std::vector<torch::jit::IValue> list_of_lists;
49 list_of_lists.push_back(list);
50 module->run_method(
"nested_loop", list_of_lists, b);
// Empty inner generic list. No error handling is visible around this call, so
// it is presumably expected to match List[List[Tensor]] without throwing.
52 std::vector<torch::jit::IValue> generic_list;
53 std::vector<torch::jit::IValue> empty_generic_list;
54 empty_generic_list.push_back(generic_list);
55 module->run_method(
"nested_loop", empty_generic_list, b);
// One nesting level too many. The expected error reports type t[][][] for
// argument 'a' in position 0.
57 std::vector<torch::jit::IValue> too_many_lists;
58 too_many_lists.push_back(empty_generic_list);
60 module->run_method(
"nested_loop", too_many_lists, b);
65 .find(
"Expected value of type Tensor[][] for argument 'a' in " 66 "position 0, but instead got value of type t[][][]") == 0);
// Inner lists of different element types (Tensors and int64_t).
// NOTE(review): the expected message reports the received type as Tensor[][],
// the same as the declared type. It therefore gives the user no hint about the
// actual mismatch; consider improving the diagnostic and this expectation.
70 std::vector<torch::jit::IValue> gen_list;
71 std::vector<int64_t> int_list = {1, 2, 3};
73 gen_list.emplace_back(list);
74 gen_list.emplace_back(int_list);
77 module->run_method(
"nested_loop", gen_list, b);
84 .find(
"Expected value of type Tensor[][] for argument 'a' in " 85 "position 0, but instead got value of type Tensor[][]") == 0);
// TorchScriptTest.TestDictArgMatching: builds a C++ map with key "hello"
// mapped to torch::ones({2}). It passes that map as the
// `a: Dict[str, Tensor]` argument of dict_op, with std::string("hello") as
// `b`, and expects element 0 of the returned tensor to equal 1.
// NOTE(review): lossy extraction. Source line numbers are fused into the code,
// and the `R"JIT(` prefix is split across lines. The dict_op body, the
// `)JIT"` terminator and the closing brace are not visible.
// The map is a c10::ivalue::UnorderedMap. Confirm that this type exists in the
// targeted PyTorch version; newer releases appear to use c10::Dict instead.
91 TEST(TorchScriptTest, TestDictArgMatching) {
92 auto module = torch::jit::compile(R
"JIT( 93 def dict_op(a: Dict[str, Tensor], b: str): 96 c10::ivalue::UnorderedMap dict; 97 dict[std::string("hello")] = torch::ones({2});
98 auto output = module->run_method(
"dict_op", dict, std::string(
"hello"));
99 ASSERT_EQ(1, output.toTensor()[0].item<int64_t>());
// TorchScriptTest.TestTupleArgMatching: wraps the int list {1} in a
// torch::jit::Tuple and passes it as the `a: Tuple[List[int]]` argument of
// tuple_op. The return value is discarded, and no assertion on it is visible
// in this listing. The test therefore appears to check only that argument
// matching does not throw.
// NOTE(review): lossy extraction. Source line numbers are fused into the code,
// and the `R"JIT(` prefix is split across lines. The tuple_op body, the
// `)JIT"` terminator and the closing brace are missing.
102 TEST(TorchScriptTest, TestTupleArgMatching) {
103 auto module = torch::jit::compile(R
"JIT( 104 def tuple_op(a: Tuple[List[int]]): 108 std::vector<int64_t> int_list = {1}; 109 auto tuple_generic_list = torch::jit::Tuple::create({ int_list });
112 module->run_method(
"tuple_op", tuple_generic_list);
The primary ATen error class.
const char * what_without_backtrace() const noexcept
Returns only the error message string, without source location.