import unittest

import torch
import torch.utils.cpp_extension
import torch_test_cpp_extension as cpp_extension

import common

TEST_CUDA = torch.cuda.is_available()


class TestCppExtension(common.TestCase):
    def test_extension_function(self):
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)
        z = cpp_extension.sigmoid_add(x, y)
        self.assertEqual(z, x.sigmoid() + y.sigmoid())

    def test_extension_module(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4)
        expected = mm.get().mm(weights)
        result = mm.forward(weights)
        self.assertEqual(expected, result)

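        # Per the assertions below, MatrixMultiplier stores a 4 x 8 tensor
        # (returned by get()), and forward() multiplies that tensor with the
        # given weights.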
    def test_backward(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, requires_grad=True)
        result = mm.forward(weights)
        result.sum().backward()
        tensor = mm.get()

        expected_weights_grad = tensor.t().mm(torch.ones([4, 4]))
        self.assertEqual(weights.grad, expected_weights_grad)

        expected_tensor_grad = torch.ones([4, 4]).mm(weights.t())
        self.assertEqual(tensor.grad, expected_tensor_grad)

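        # With A = mm.get() and loss = (A @ W).sum(), the chain rule gives
        # dloss/dW = A.t() @ ones(4, 4) and dloss/dA = ones(4, 4) @ W.t(),
        # which the two assertions below verify.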
    def test_jit_compile_extension(self):
        module = torch.utils.cpp_extension.load(
            name='jit_extension',
            sources=[
                'cpp_extensions/jit_extension.cpp',
                'cpp_extensions/jit_extension2.cpp'
            ],
            extra_include_paths=['cpp_extensions'],
            extra_cflags=['-g'],
            verbose=True)
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        z = module.tanh_add(x, y)
        self.assertEqual(z, x.tanh() + y.tanh())

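        # load() compiles the listed sources on the fly, links them into a
        # shared library, and imports the result as a Python module; builds
        # are cached, so later runs only recompile when the sources change.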
        # Check that we can call a function that is not defined in the main
        # C++ file of the extension.
        z = module.exp_add(x, y)
        self.assertEqual(z, x.exp() + y.exp())

        # Check that we can use a class exposed by the JIT-compiled extension.
        doubler = module.Doubler(2, 2)
        self.assertIsNone(doubler.get().grad)
        self.assertEqual(doubler.get().sum(), 4)
        self.assertEqual(doubler.forward().sum(), 8)

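        # (These sums are consistent with Doubler holding a 2 x 2 tensor of
        # ones that forward() doubles; no backward pass has run, so the
        # tensor's grad is still None.)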
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cuda_extension(self):
        import torch_test_cuda_extension as cuda_extension

        x = torch.zeros(100, device='cuda')
        y = torch.zeros(100, device='cuda')

        z = cuda_extension.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_jit_cuda_extension(self):
        # NOTE: The name passed to load() must equal the name of the module
        # defined in the sources, since the built library is imported under
        # that name (test_cuda_extension above imports it by the same name).
        module = torch.utils.cpp_extension.load(
            name='torch_test_cuda_extension',
            sources=[
                'cpp_extensions/cuda_extension.cpp',
                'cpp_extensions/cuda_extension.cu'
            ],
            extra_cuda_cflags=['-O2'],
            verbose=True)

        x = torch.zeros(100, device='cuda')
        y = torch.zeros(100, device='cuda')

        z = module.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))


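        # (.cu files are compiled with nvcc, which receives extra_cuda_cflags;
        # plain .cpp files go through the host C++ compiler as usual.)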
if __name__ == '__main__':
    common.run_tests()