Caffe2 - C++ API
A deep learning, cross platform ML framework
test_custom_ops.cpp
1 #include <torch/script.h>
2 #include <torch/cuda.h>
3 
4 #include "op.h"
5 
6 #include <memory>
7 #include <string>
8 #include <vector>
9 
10 #include <iostream>
11 
12 namespace helpers {
13 template <typename Predicate>
14 void check_all_parameters(
15  const torch::jit::script::Module& module,
16  Predicate predicate) {
17  for (const auto& parameter : module.get_parameters()) {
18  AT_ASSERT(predicate(parameter->slot()->toTensor()));
19  }
20  for (const auto& child : module.get_modules()) {
21  check_all_parameters(*child->module, predicate);
22  }
23 }
24 } // namespace helpers
25 
26 void get_operator_from_registry_and_execute() {
27  auto& ops = torch::jit::getAllOperatorsFor(
28  torch::jit::Symbol::fromQualString("custom::op"));
29  AT_ASSERT(ops.size() == 1);
30 
31  auto& op = ops.front();
32  AT_ASSERT(op->schema().name() == "custom::op");
33 
34  torch::jit::Stack stack;
35  torch::jit::push(stack, torch::ones(5), 2.0, 3);
36  op->getOperation()(stack);
37  std::vector<torch::Tensor> output;
38  torch::jit::pop(stack, output);
39 
40  const auto manual = custom_op(torch::ones(5), 2.0, 3);
41 
42  AT_ASSERT(output.size() == 3);
43  for (size_t i = 0; i < output.size(); ++i) {
44  AT_ASSERT(output[i].allclose(torch::ones(5) * 2));
45  AT_ASSERT(output[i].allclose(manual[i]));
46  }
47 }
48 
49 void load_serialized_module_with_custom_op_and_execute(
50  const std::string& path_to_exported_script_module) {
51  std::shared_ptr<torch::jit::script::Module> module =
52  torch::jit::load(path_to_exported_script_module);
53  AT_ASSERT(module != nullptr);
54 
55  std::vector<torch::jit::IValue> inputs;
56  inputs.push_back(torch::ones(5));
57  auto output = module->forward(inputs).toTensor();
58 
59  AT_ASSERT(output.allclose(torch::ones(5) + 1));
60 }
61 
62 void test_argument_checking_for_serialized_modules(
63  const std::string& path_to_exported_script_module) {
64  std::shared_ptr<torch::jit::script::Module> module =
65  torch::jit::load(path_to_exported_script_module);
66  AT_ASSERT(module != nullptr);
67 
68  try {
69  module->forward({torch::jit::IValue(1), torch::jit::IValue(2)});
70  AT_ASSERT(false);
71  } catch (const c10::Error& error) {
72  AT_ASSERT(
73  std::string(error.what_without_backtrace())
74  .find("Expected at most 1 argument(s) for operator 'forward', "
75  "but received 2 argument(s)") == 0);
76  }
77 
78  try {
79  module->forward({torch::jit::IValue(5)});
80  AT_ASSERT(false);
81  } catch (const c10::Error& error) {
82  AT_ASSERT(
83  std::string(error.what_without_backtrace())
84  .find("Expected value of type Tensor for argument 'input' in "
85  "position 0, but instead got value of type int") == 0);
86  }
87 
88  try {
89  module->forward({});
90  AT_ASSERT(false);
91  } catch (const c10::Error& error) {
92  AT_ASSERT(
93  std::string(error.what_without_backtrace())
94  .find("forward() is missing value for argument 'input'") == 0);
95  }
96 }
97 
98 void test_move_to_device(const std::string& path_to_exported_script_module) {
99  std::shared_ptr<torch::jit::script::Module> module =
100  torch::jit::load(path_to_exported_script_module);
101  AT_ASSERT(module != nullptr);
102 
103  helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
104  return tensor.device().is_cpu();
105  });
106 
107  module->to(torch::kCUDA);
108 
109  helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
110  return tensor.device().is_cuda();
111  });
112 
113  module->to(torch::kCPU);
114 
115  helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
116  return tensor.device().is_cpu();
117  });
118 }
119 
120 void test_move_to_dtype(const std::string& path_to_exported_script_module) {
121  std::shared_ptr<torch::jit::script::Module> module =
122  torch::jit::load(path_to_exported_script_module);
123  AT_ASSERT(module != nullptr);
124 
125  module->to(torch::kInt);
126 
127  helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
128  return tensor.dtype() == torch::kInt;
129  });
130 
131  module->to(torch::kDouble);
132 
133  helpers::check_all_parameters(*module, [](const torch::Tensor& tensor) {
134  return tensor.dtype() == torch::kDouble;
135  });
136 }
137 
138 int main(int argc, const char* argv[]) {
139  if (argc != 2) {
140  std::cerr << "usage: test_custom_ops <path-to-exported-script-module>\n";
141  return -1;
142  }
143  const std::string path_to_exported_script_module = argv[1];
144 
145  get_operator_from_registry_and_execute();
146  load_serialized_module_with_custom_op_and_execute(
147  path_to_exported_script_module);
148  test_argument_checking_for_serialized_modules(path_to_exported_script_module);
149  test_move_to_dtype(path_to_exported_script_module);
150 
151  if (torch::cuda::device_count() > 0) {
152  test_move_to_device(path_to_exported_script_module);
153  }
154 
155  std::cout << "ok\n";
156 }
bool is_cuda() const noexcept
Return true if the device is of CUDA type.
Definition: Device.h:80
caffe2::TypeMeta dtype() const noexcept
Returns a Tensor's dtype (TypeMeta). Defined in TensorMethods.h.
The primary ATen error class.
Definition: Exception.h:27
Device device() const
Returns a Tensor's device.
bool is_cpu() const noexcept
Return true if the device is of CPU type.
Definition: Device.h:85
const char * what_without_backtrace() const noexcept
Returns only the error message string, without source location.
Definition: Exception.h:79