Caffe2 - Python API
A deep learning, cross-platform ML framework
test_cpp_extensions.py
import os
import shutil
import sys
import unittest
import warnings

import common_utils as common
import torch
import torch.backends.cudnn
import torch.utils.cpp_extension
from torch.utils.cpp_extension import CUDA_HOME


try:
    import torch_test_cpp_extension.cpp as cpp_extension
    import torch_test_cpp_extension.msnpu as msnpu_extension
except ImportError:
    warnings.warn(
        "test_cpp_extensions.py cannot be invoked directly. Run "
        "`python run_test.py -i cpp_extensions` instead."
    )


TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
TEST_CUDNN = False
if TEST_CUDA:
    CUDNN_HEADER_EXISTS = os.path.isfile(os.path.join(CUDA_HOME, "include/cudnn.h"))
    TEST_CUDNN = (
        TEST_CUDA and CUDNN_HEADER_EXISTS and torch.backends.cudnn.is_available()
    )
IS_WINDOWS = sys.platform == "win32"

# This effectively allows re-using the same extension (compiled once) in
# multiple tests, just to split up the tested properties.
def dont_wipe_extensions_build_folder(func):
    func.dont_wipe = True
    return func

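# Usage sketch: decorating a test method opts it out of the build-folder wipe
# performed in TestCppExtension.setUp() below, e.g.:
#
#     @dont_wipe_extensions_build_folder
#     def test_reusing_extension(self):  # hypothetical test name
#         ...
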
class TestCppExtension(common.TestCase):
    def setUp(self):
        test_name = self.id().split(".")[-1]
        dont_wipe = hasattr(getattr(self, test_name), "dont_wipe")
        if dont_wipe:
            print(
                "Test case {} has the 'dont_wipe' attribute set, ".format(test_name)
                + "therefore not wiping the extensions build folder before running it"
            )
            return
        if sys.platform == "win32":
            print("Not wiping extensions build folder because Windows")
            return
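        # The default build root is the per-user extensions cache (typically
        # ~/.cache/torch_extensions, or $TORCH_EXTENSIONS_DIR when set);
        # wiping it forces each test to JIT-compile its extension from scratch.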
        default_build_root = torch.utils.cpp_extension.get_default_build_root()
        if os.path.exists(default_build_root):
            shutil.rmtree(default_build_root)

    def test_extension_function(self):
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)
        z = cpp_extension.sigmoid_add(x, y)
        self.assertEqual(z, x.sigmoid() + y.sigmoid())

    def test_extension_module(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4)
        expected = mm.get().mm(weights)
        result = mm.forward(weights)
        self.assertEqual(expected, result)

    def test_backward(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, requires_grad=True)
        result = mm.forward(weights)
        result.sum().backward()
        tensor = mm.get()

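        # For result = tensor.mm(weights) and an all-ones upstream gradient,
        # d(result.sum())/d(weights) = tensor.t().mm(ones) and
        # d(result.sum())/d(tensor) = ones.mm(weights.t()), as asserted below.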
        expected_weights_grad = tensor.t().mm(torch.ones([4, 4]))
        self.assertEqual(weights.grad, expected_weights_grad)

        expected_tensor_grad = torch.ones([4, 4]).mm(weights.t())
        self.assertEqual(tensor.grad, expected_tensor_grad)

    def test_jit_compile_extension(self):
        module = torch.utils.cpp_extension.load(
            name="jit_extension",
            sources=[
                "cpp_extensions/jit_extension.cpp",
                "cpp_extensions/jit_extension2.cpp",
            ],
            extra_include_paths=["cpp_extensions"],
            extra_cflags=["-g"],
            verbose=True,
        )
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        z = module.tanh_add(x, y)
        self.assertEqual(z, x.tanh() + y.tanh())

        # Checking we can call a function not defined in the main C++ file.
        z = module.exp_add(x, y)
        self.assertEqual(z, x.exp() + y.exp())

        # Checking we can use this JIT-compiled class.
        doubler = module.Doubler(2, 2)
        self.assertIsNone(doubler.get().grad)
        self.assertEqual(doubler.get().sum(), 4)
        self.assertEqual(doubler.forward().sum(), 8)
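
    # load() JIT-compiles the listed sources into the extensions build folder
    # (using ninja) and imports the resulting shared library as a regular
    # Python module exposing the bound functions and classes.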

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cuda_extension(self):
        import torch_test_cpp_extension.cuda as cuda_extension

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)

        z = cuda_extension.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_jit_cuda_extension(self):
        # NOTE: The name of the extension must equal the name of the module.
        module = torch.utils.cpp_extension.load(
            name="torch_test_cuda_extension",
            sources=[
                "cpp_extensions/cuda_extension.cpp",
                "cpp_extensions/cuda_extension.cu",
            ],
            extra_cuda_cflags=["-O2"],
            verbose=True,
        )

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)

        z = module.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(not TEST_CUDNN, "CuDNN not found")
    def test_jit_cudnn_extension(self):
        # Implementation of CuDNN ReLU.
        if IS_WINDOWS:
            extra_ldflags = ["cudnn.lib"]
        else:
            extra_ldflags = ["-lcudnn"]
        module = torch.utils.cpp_extension.load(
            name="torch_test_cudnn_extension",
            sources=["cpp_extensions/cudnn_extension.cpp"],
            extra_ldflags=extra_ldflags,
            verbose=True,
            with_cuda=True,
        )

        x = torch.randn(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)
        module.cudnn_relu(x, y)  # y = relu(x)
        self.assertEqual(torch.nn.functional.relu(x), y)
        with self.assertRaisesRegex(RuntimeError, "same size"):
            y_incorrect = torch.zeros(20, device="cuda", dtype=torch.float32)
            module.cudnn_relu(x, y_incorrect)
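
    # with_cuda=True is needed above because `sources` contains no .cu file;
    # it forces the CUDA headers and libraries into the build explicitly.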

    def test_optional(self):
        has_value = cpp_extension.function_taking_optional(torch.ones(5))
        self.assertTrue(has_value)
        has_value = cpp_extension.function_taking_optional(None)
        self.assertFalse(has_value)

    def test_inline_jit_compile_extension_with_functions_as_list(self):
        cpp_source = """
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_with_functions_list",
            cpp_sources=cpp_source,
            functions="tanh_add",
            verbose=True,
        )

        self.assertEqual(module.tanh_add.__doc__.split("\n")[2], "tanh_add")

        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        z = module.tanh_add(x, y)
        self.assertEqual(z, x.tanh() + y.tanh())
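
    # Passing `functions` makes load_inline() generate the pybind11 bindings
    # automatically; the generated docstring is just the function name, which
    # is what the __doc__ check above relies on. A dict (next test) supplies a
    # custom docstring per function instead.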

    def test_inline_jit_compile_extension_with_functions_as_dict(self):
        cpp_source = """
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_with_functions_dict",
            cpp_sources=cpp_source,
            functions={"tanh_add": "Tanh and then sum :D"},
            verbose=True,
        )

        self.assertEqual(module.tanh_add.__doc__.split("\n")[2], "Tanh and then sum :D")

    def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
        cpp_source1 = """
        torch::Tensor sin_add(torch::Tensor x, torch::Tensor y) {
          return x.sin() + y.sin();
        }
        """

        cpp_source2 = """
        #include <torch/extension.h>
        torch::Tensor sin_add(torch::Tensor x, torch::Tensor y);
        PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
          m.def("sin_add", &sin_add, "sin(x) + sin(y)");
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension",
            cpp_sources=[cpp_source1, cpp_source2],
            verbose=True,
        )

        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        z = module.sin_add(x, y)
        self.assertEqual(z, x.sin() + y.sin())
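
    # With `functions` omitted, no bindings are generated automatically, so
    # one of the sources must define PYBIND11_MODULE itself, as cpp_source2
    # above does.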

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_inline_jit_compile_extension_cuda(self):
        cuda_source = """
        __global__ void cos_add_kernel(
            const float* __restrict__ x,
            const float* __restrict__ y,
            float* __restrict__ output,
            const int size) {
          const auto index = blockIdx.x * blockDim.x + threadIdx.x;
          if (index < size) {
            output[index] = __cosf(x[index]) + __cosf(y[index]);
          }
        }

        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
          auto output = torch::zeros_like(x);
          const int threads = 1024;
          const int blocks = (output.numel() + threads - 1) / threads;
          cos_add_kernel<<<blocks, threads>>>(x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
          return output;
        }
        """

        # Here, the C++ source need only declare the function signature.
        cpp_source = "torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);"

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_cuda",
            cpp_sources=cpp_source,
            cuda_sources=cuda_source,
            functions=["cos_add"],
            verbose=True,
        )

        self.assertEqual(module.cos_add.__doc__.split("\n")[2], "cos_add")

        x = torch.randn(4, 4, device="cuda", dtype=torch.float32)
        y = torch.randn(4, 4, device="cuda", dtype=torch.float32)

        z = module.cos_add(x, y)
        self.assertEqual(z, x.cos() + y.cos())
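
    # load_inline() compiles cuda_sources with nvcc and automatically prepends
    # torch/extension.h, cuda.h and cuda_runtime.h, which is why the CUDA
    # source above needs no #include lines of its own.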

    def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
        with self.assertRaises(ValueError):
            torch.utils.cpp_extension.load_inline(
                name="invalid_jit_extension", cpp_sources="", functions=5
            )

    def test_lenient_flag_handling_in_jit_extensions(self):
        cpp_source = """
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="lenient_flag_handling_extension",
            cpp_sources=cpp_source,
            functions="tanh_add",
            extra_cflags=["-g\n\n", "-O0 -Wall"],
            extra_include_paths=[" cpp_extensions\n"],
            verbose=True,
        )

        x = torch.zeros(100, dtype=torch.float32)
        y = torch.zeros(100, dtype=torch.float32)
        z = module.tanh_add(x, y).cpu()
        self.assertEqual(z, x.tanh() + y.tanh())
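
    # The flags above deliberately contain stray newlines, leading whitespace,
    # and two options packed into one string; the build succeeds only if the
    # extension mechanism normalizes them before invoking the compiler.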

    def test_complex_registration(self):
        module = torch.utils.cpp_extension.load(
            name="complex_registration_extension",
            sources="cpp_extensions/complex_registration_extension.cpp",
            verbose=True,
        )

        # Make sure that the empty tensor is of the desired shape and type.
        # Refer to https://github.com/pytorch/pytorch/issues/14829
        t = torch.empty(2, 2, dtype=torch.complex64)
        self.assertEqual(t.size(), torch.Size([2, 2]))
        self.assertEqual(t.type(), 'torch.ComplexFloatTensor')

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_half_support(self):
        """
        Checks for an issue with operator< ambiguity for half when certain
        THC headers are included.

        See https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333
        for the corresponding issue.
        """
        cuda_source = """
        #include <THC/THCNumerics.cuh>

        template<typename T, typename U>
        __global__ void half_test_kernel(const T* input, U* output) {
          if (input[0] < input[1] || input[0] >= input[1]) {
            output[0] = 123;
          }
        }

        torch::Tensor half_test(torch::Tensor input) {
          auto output = torch::empty(1, input.options().dtype(torch::kFloat));
          AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "half_test", [&] {
            half_test_kernel<scalar_t><<<1, 1>>>(
                input.data<scalar_t>(),
                output.data<float>());
          });
          return output;
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="half_test_extension",
            cpp_sources="torch::Tensor half_test(torch::Tensor input);",
            cuda_sources=cuda_source,
            functions=["half_test"],
            verbose=True,
        )

        x = torch.randn(3, device="cuda", dtype=torch.half)
        result = module.half_test(x)
        self.assertEqual(result[0], 123)

    def test_reload_jit_extension(self):
        def compile(code):
            return torch.utils.cpp_extension.load_inline(
                name="reloaded_jit_extension",
                cpp_sources=code,
                functions="f",
                verbose=True,
            )

        module = compile("int f() { return 123; }")
        self.assertEqual(module.f(), 123)

        module = compile("int f() { return 456; }")
        self.assertEqual(module.f(), 456)
        module = compile("int f() { return 456; }")
        self.assertEqual(module.f(), 456)

        module = compile("int f() { return 789; }")
        self.assertEqual(module.f(), 789)
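
    # Each compile() with changed source must yield a rebuilt module rather
    # than a stale cached one; the repeated identical compile in the middle
    # presumably exercises the cached-build path.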

    @dont_wipe_extensions_build_folder
    @common.skipIfRocm
    def test_cpp_frontend_module_has_same_output_as_python(self):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        input = torch.randn(2, 5)
        cpp_linear = extension.Net(5, 2)
        cpp_linear.to(torch.float64)
        python_linear = torch.nn.Linear(5, 2)

        # First make sure they have the same parameters.
        cpp_parameters = dict(cpp_linear.named_parameters())
        with torch.no_grad():
            python_linear.weight.copy_(cpp_parameters["fc.weight"])
            python_linear.bias.copy_(cpp_parameters["fc.bias"])

        cpp_output = cpp_linear.forward(input)
        python_output = python_linear(input)
        self.assertEqual(cpp_output, python_output)

        cpp_output.sum().backward()
        python_output.sum().backward()

        for p in cpp_linear.parameters():
            self.assertIsNotNone(p.grad)

        self.assertEqual(cpp_parameters["fc.weight"].grad, python_linear.weight.grad)
        self.assertEqual(cpp_parameters["fc.bias"].grad, python_linear.bias.grad)

    @dont_wipe_extensions_build_folder
    @common.skipIfRocm
    def test_cpp_frontend_module_python_inter_op(self):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        # Create a torch.nn.Module which uses the C++ module as a submodule.
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.x = torch.nn.Parameter(torch.tensor(1.0))
                self.net = extension.Net(3, 5)

            def forward(self, input):
                return self.net.forward(input) + self.x

        net = extension.Net(5, 2)
        net.double()
        net.to(torch.get_default_dtype())
        self.assertEqual(str(net), "Net")

        # Further embed the torch.nn.Module into a Sequential, and also add the
        # C++ module as an element of the Sequential.
        sequential = torch.nn.Sequential(M(), torch.nn.Tanh(), net, torch.nn.Sigmoid())

        input = torch.randn(2, 3)
        # Try calling the module!
        output = sequential.forward(input)
        # The call operator is bound to forward too.
        self.assertEqual(output, sequential(input))
        self.assertEqual(list(output.shape), [2, 2])

        # Make changes to the module hierarchy.
        old_dtype = torch.get_default_dtype()
        sequential.to(torch.float64)
        sequential.to(torch.float32)
        sequential.to(old_dtype)
        self.assertEqual(sequential[2].parameters()[0].dtype, old_dtype)

        # Make sure we can access these methods recursively.
        self.assertEqual(
            len(list(sequential.parameters())), len(net.parameters()) * 2 + 1
        )
        self.assertEqual(
            len(list(sequential.named_parameters())),
            len(net.named_parameters()) * 2 + 1,
        )
        self.assertEqual(len(list(sequential.buffers())), len(net.buffers()) * 2)
        self.assertEqual(len(list(sequential.modules())), 8)

        # Test clone().
        net2 = net.clone()
        self.assertEqual(len(net.parameters()), len(net2.parameters()))
        self.assertEqual(len(net.buffers()), len(net2.buffers()))
        self.assertEqual(len(net.modules()), len(net2.modules()))

        # Try differentiating through the whole module.
        for parameter in net.parameters():
            self.assertIsNone(parameter.grad)
        output.sum().backward()
        for parameter in net.parameters():
            self.assertIsNotNone(parameter.grad)
            self.assertGreater(parameter.grad.sum(), 0)

        # Try calling zero_grad().
        net.zero_grad()
        for p in net.parameters():
            self.assertEqual(p.grad, torch.zeros_like(p))

        # Test train(), eval() and the `training` property.
        self.assertTrue(net.training)
        net.eval()
        self.assertFalse(net.training)
        net.train()
        self.assertTrue(net.training)
        net.eval()

        # Try calling the additional methods we registered.
        biased_input = torch.randn(4, 5)
        output_before = net.forward(biased_input)
        bias = net.get_bias().clone()
        self.assertEqual(list(bias.shape), [2])
        net.set_bias(bias + 1)
        self.assertEqual(net.get_bias(), bias + 1)
        output_after = net.forward(biased_input)

        self.assertNotEqual(output_before, output_after)

        # Try accessing parameters.
        self.assertEqual(len(net.parameters()), 2)
        np = net.named_parameters()
        self.assertEqual(len(np), 2)
        self.assertIn("fc.weight", np)
        self.assertIn("fc.bias", np)

        self.assertEqual(len(net.buffers()), 1)
        nb = net.named_buffers()
        self.assertEqual(len(nb), 1)
        self.assertIn("buf", nb)
        self.assertEqual(nb[0][1], torch.eye(5))
    @dont_wipe_extensions_build_folder
    @common.skipIfRocm
    def test_cpp_frontend_module_has_up_to_date_attributes(self):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        net = extension.Net(5, 2)

        self.assertEqual(len(net._parameters), 0)
        net.add_new_parameter("foo", torch.eye(5))
        self.assertEqual(len(net._parameters), 1)

        self.assertEqual(len(net._buffers), 1)
        net.add_new_buffer("bar", torch.eye(5))
        self.assertEqual(len(net._buffers), 2)

        self.assertEqual(len(net._modules), 1)
        net.add_new_submodule("fc2")
        self.assertEqual(len(net._modules), 2)

    @dont_wipe_extensions_build_folder
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @common.skipIfRocm
    def test_cpp_frontend_module_python_inter_op_with_cuda(self):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        net = extension.Net(5, 2)
        for p in net.parameters():
            self.assertTrue(p.device.type == "cpu")
        cpu_parameters = [p.clone() for p in net.parameters()]

        device = torch.device("cuda", 0)
        net.to(device)

        for i, p in enumerate(net.parameters()):
            self.assertTrue(p.device.type == "cuda")
            self.assertTrue(p.device.index == 0)
            self.assertEqual(cpu_parameters[i], p)

        net.cpu()
        net.add_new_parameter("a", torch.eye(5))
        net.add_new_parameter("b", torch.eye(5))
        net.add_new_buffer("c", torch.eye(5))
        net.add_new_buffer("d", torch.eye(5))
        net.add_new_submodule("fc2")
        net.add_new_submodule("fc3")

        for p in net.parameters():
            self.assertTrue(p.device.type == "cpu")

        net.cuda()

        for p in net.parameters():
            self.assertTrue(p.device.type == "cuda")

    def test_returns_shared_library_path_when_is_python_module_is_false(self):
        source = """
        #include <torch/script.h>
        torch::Tensor func(torch::Tensor x) { return x; }
        static torch::jit::RegisterOperators r("test::func", &func);
        """
        torch.utils.cpp_extension.load_inline(
            name="is_python_module",
            cpp_sources=source,
            functions="func",
            verbose=True,
            is_python_module=False,
        )
        self.assertEqual(torch.ops.test.func(torch.eye(5)), torch.eye(5))

    @unittest.skipIf(IS_WINDOWS, "Not available on Windows")
    def test_no_python_abi_suffix_sets_the_correct_library_name(self):
        # For this test, run_test.py will call `python setup.py install` in the
        # cpp_extensions/no_python_abi_suffix_test folder, where the
        # `BuildExtension` class has a `no_python_abi_suffix` option set to
        # `True`. This *should* mean that on Python 3, the produced shared
        # library does not have an ABI suffix like
        # "cpython-37m-x86_64-linux-gnu" before the library suffix, e.g. "so".
        # On Python 2 there is no ABI suffix anyway.
        root = os.path.join("cpp_extensions", "no_python_abi_suffix_test", "build")
        matches = [f for _, _, fs in os.walk(root) for f in fs if f.endswith("so")]
        self.assertEqual(len(matches), 1, str(matches))
        self.assertEqual(matches[0], "no_python_abi_suffix_test.so", str(matches))

    def test_set_default_type_also_changes_aten_default_type(self):
        module = torch.utils.cpp_extension.load_inline(
            name="test_set_default_type",
            cpp_sources="torch::Tensor get() { return torch::empty({}); }",
            functions="get",
            verbose=True,
        )

        initial_default = torch.get_default_dtype()
        try:
            self.assertEqual(module.get().dtype, initial_default)
            torch.set_default_dtype(torch.float64)
            self.assertEqual(module.get().dtype, torch.float64)
            torch.set_default_dtype(torch.float32)
            self.assertEqual(module.get().dtype, torch.float32)
            torch.set_default_dtype(torch.float16)
            self.assertEqual(module.get().dtype, torch.float16)
        finally:
            torch.set_default_dtype(initial_default)


class TestMSNPUTensor(common.TestCase):
    @classmethod
    def setUpClass(cls):
        msnpu_extension.init_msnpu_extension()
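
    # get_test_int(), used throughout these tests, reads back an integer that
    # the MSNPU test extension records when one of its registered kernels
    # runs; each overridden op appears to store a distinct value, so the
    # assertions can tell which backend function was dispatched. (Inferred
    # from the usage below, not from the extension's source.)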

    def test_unregistered(self):
        a = torch.empty(5, 5, device='cpu')
        with self.assertRaisesRegex(RuntimeError, "No function registered"):
            b = torch.empty(5, 5, device='msnpu')

    def test_zeros(self):
        a = torch.zeros(5, 5, device='cpu')
        self.assertEqual(a.device, torch.device('cpu'))
        self.assertEqual(a.sum(), 0)

        b = torch.zeros(5, 5, device='msnpu')
        self.assertEqual(b.device, torch.device('msnpu', 0))
        self.assertEqual(msnpu_extension.get_test_int(), 0)
        self.assertEqual(torch.get_default_dtype(), b.dtype)

        c = torch.zeros((5, 5), dtype=torch.int64, device='msnpu')
        self.assertEqual(msnpu_extension.get_test_int(), 0)
        self.assertEqual(torch.int64, c.dtype)

    def test_add(self):
        a = torch.zeros(5, 5, device='msnpu')
        self.assertEqual(msnpu_extension.get_test_int(), 0)

        b = torch.zeros(5, 5, device='msnpu')
        self.assertEqual(msnpu_extension.get_test_int(), 0)

        c = torch.add(a, b)
        self.assertEqual(msnpu_extension.get_test_int(), 1)

    def test_backwards(self):
        a = torch.zeros(5, 5, device='msnpu', requires_grad=True)
        self.assertEqual(msnpu_extension.get_test_int(), 0)

        b = torch.zeros(5, 5, device='msnpu')
        self.assertEqual(msnpu_extension.get_test_int(), 0)

        c = torch.kl_div(a, b)
        self.assertEqual(msnpu_extension.get_test_int(), 3)

        d = c.sum()
        self.assertEqual(msnpu_extension.get_test_int(), 2)

        d.backward()
        self.assertEqual(msnpu_extension.get_test_int(), 4)


if __name__ == "__main__":
    common.run_tests()