1 #include <torch/script.h> 2 #include <torch/cuda.h> 13 template <
typename Predicate>
14 void check_all_parameters(
16 Predicate predicate) {
17 for (
const auto& parameter : module.get_parameters()) {
18 AT_ASSERT(predicate(parameter->slot()->toTensor()));
20 for (
const auto& child : module.get_modules()) {
21 check_all_parameters(*child->module, predicate);
26 void get_operator_from_registry_and_execute() {
27 auto& ops = torch::jit::getAllOperatorsFor(
28 torch::jit::Symbol::fromQualString(
"custom::op"));
29 AT_ASSERT(ops.size() == 1);
31 auto& op = ops.front();
32 AT_ASSERT(op->schema().name() ==
"custom::op");
34 torch::jit::Stack stack;
35 torch::jit::push(stack, torch::ones(5), 2.0, 3);
36 op->getOperation()(stack);
37 std::vector<torch::Tensor> output;
38 torch::jit::pop(stack, output);
40 const auto manual = custom_op(torch::ones(5), 2.0, 3);
42 AT_ASSERT(output.size() == 3);
43 for (
size_t i = 0; i < output.size(); ++i) {
44 AT_ASSERT(output[i].allclose(torch::ones(5) * 2));
45 AT_ASSERT(output[i].allclose(manual[i]));
49 void load_serialized_module_with_custom_op_and_execute(
50 const std::string& path_to_exported_script_module) {
51 std::shared_ptr<torch::jit::script::Module> module =
52 torch::jit::load(path_to_exported_script_module);
53 AT_ASSERT(module !=
nullptr);
55 std::vector<torch::jit::IValue> inputs;
56 inputs.push_back(torch::ones(5));
57 auto output = module->forward(inputs).toTensor();
59 AT_ASSERT(output.allclose(torch::ones(5) + 1));
62 void test_argument_checking_for_serialized_modules(
63 const std::string& path_to_exported_script_module) {
64 std::shared_ptr<torch::jit::script::Module> module =
65 torch::jit::load(path_to_exported_script_module);
66 AT_ASSERT(module !=
nullptr);
74 .find(
"Expected at most 1 argument(s) for operator 'forward', " 75 "but received 2 argument(s)") == 0);
84 .find(
"Expected value of type Tensor for argument 'input' in " 85 "position 0, but instead got value of type int") == 0);
94 .find(
"forward() is missing value for argument 'input'") == 0);
98 void test_move_to_device(
const std::string& path_to_exported_script_module) {
99 std::shared_ptr<torch::jit::script::Module> module =
100 torch::jit::load(path_to_exported_script_module);
101 AT_ASSERT(module !=
nullptr);
103 helpers::check_all_parameters(*module, [](
const torch::Tensor& tensor) {
107 module->to(torch::kCUDA);
109 helpers::check_all_parameters(*module, [](
const torch::Tensor& tensor) {
113 module->to(torch::kCPU);
115 helpers::check_all_parameters(*module, [](
const torch::Tensor& tensor) {
120 void test_move_to_dtype(
const std::string& path_to_exported_script_module) {
121 std::shared_ptr<torch::jit::script::Module> module =
122 torch::jit::load(path_to_exported_script_module);
123 AT_ASSERT(module !=
nullptr);
125 module->to(torch::kInt);
127 helpers::check_all_parameters(*module, [](
const torch::Tensor& tensor) {
128 return tensor.
dtype() == torch::kInt;
131 module->to(torch::kDouble);
133 helpers::check_all_parameters(*module, [](
const torch::Tensor& tensor) {
134 return tensor.
dtype() == torch::kDouble;
138 int main(
int argc,
const char* argv[]) {
140 std::cerr <<
"usage: test_custom_ops <path-to-exported-script-module>\n";
143 const std::string path_to_exported_script_module = argv[1];
145 get_operator_from_registry_and_execute();
146 load_serialized_module_with_custom_op_and_execute(
147 path_to_exported_script_module);
148 test_argument_checking_for_serialized_modules(path_to_exported_script_module);
149 test_move_to_dtype(path_to_exported_script_module);
151 if (torch::cuda::device_count() > 0) {
152 test_move_to_device(path_to_exported_script_module);
// Reference notes (editor/doc-site tooltips accidentally pasted into this
// file; preserved here as comments so they do not break compilation):
//   bool is_cuda() const noexcept
//     Returns true if the device is of CUDA type.
//   caffe2::TypeMeta dtype() const noexcept
//     Returns a Tensor's dtype (TypeMeta). Defined in TensorMethods.h.
//   c10::Error
//     The primary ATen error class.
//   Device device() const
//     Returns a Tensor's device.
//   bool is_cpu() const noexcept
//     Returns true if the device is of CPU type.
//   const char* what_without_backtrace() const noexcept
//     Returns only the error message string, without source location.