blob: 1f13b49420b6063d2a0394f76edea964451bae84 [file] [log] [blame]
Alex Ford7a1b6682018-10-12 13:33:43 -07001import unittest
Alex Ford7a1b6682018-10-12 13:33:43 -07002
Pritam Damaniaf050b162020-01-22 21:05:28 -08003import torch.testing._internal.common_utils as common
4from torch.testing._internal.common_utils import TEST_NUMBA, TEST_NUMPY
5from torch.testing._internal.common_cuda import TEST_NUMBA_CUDA, TEST_CUDA, TEST_MULTIGPU
Alex Ford7a1b6682018-10-12 13:33:43 -07006
7import torch
8
9if TEST_NUMPY:
10 import numpy
11
12if TEST_NUMBA:
13 import numba
14
15if TEST_NUMBA_CUDA:
16 import numba.cuda
17
18
class TestNumbaIntegration(common.TestCase):
    """Interop tests between torch CUDA tensors and numba.cuda.

    Covers both directions of the ``__cuda_array_interface__`` protocol:
    torch tensors consumed by numba (``numba.cuda.as_cuda_array``) and numba
    device arrays consumed by torch (``torch.as_tensor`` / ``torch.tensor``).
    Each test is skipped when its prerequisites (numpy, CUDA, numba.cuda,
    multiple GPUs) are not available.
    """

    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_cuda_array_interface(self):
        """torch.Tensor exposes __cuda_array_interface__ for cuda tensors.

        An object t is considered a cuda-tensor if:
            hasattr(t, '__cuda_array_interface__')

        A cuda-tensor provides a tensor description dict:
            shape: (integer, ...) Tensor shape.
            strides: (integer, ...) Tensor strides, in bytes; reported as
                None for the contiguous tensors checked here.
            typestr: (str) A numpy-style typestr.
            data: (int, boolean) A (data_ptr, read-only) tuple.
            version: (int) Version 2

        See:
        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        """

        # Legacy tensor types paired positionally with the numpy dtype whose
        # typestr the exported interface is expected to report.
        types = [
            torch.DoubleTensor,
            torch.FloatTensor,
            torch.HalfTensor,
            torch.LongTensor,
            torch.IntTensor,
            torch.ShortTensor,
            torch.CharTensor,
            torch.ByteTensor,
        ]
        dtypes = [
            numpy.float64,
            numpy.float32,
            numpy.float16,
            numpy.int64,
            numpy.int32,
            numpy.int16,
            numpy.int8,
            numpy.uint8,
        ]
        for tp, npt in zip(types, dtypes):

            # CPU tensors do not implement the interface.
            cput = tp(10)

            self.assertFalse(hasattr(cput, "__cuda_array_interface__"))
            self.assertRaises(AttributeError, lambda: cput.__cuda_array_interface__)

            # Sparse CPU/CUDA tensors do not implement the interface
            # (the half-precision type is excluded from the sparse checks).
            if tp not in (torch.HalfTensor,):
                indices_t = torch.empty(1, cput.size(0), dtype=torch.long).clamp_(min=0)
                sparse_t = torch.sparse_coo_tensor(indices_t, cput)

                self.assertFalse(hasattr(sparse_t, "__cuda_array_interface__"))
                self.assertRaises(
                    AttributeError, lambda: sparse_t.__cuda_array_interface__
                )

                sparse_cuda_t = torch.sparse_coo_tensor(indices_t, cput).cuda()

                self.assertFalse(hasattr(sparse_cuda_t, "__cuda_array_interface__"))
                self.assertRaises(
                    AttributeError, lambda: sparse_cuda_t.__cuda_array_interface__
                )

            # CUDA tensors have the attribute and v2 interface
            cudat = tp(10).cuda()

            self.assertTrue(hasattr(cudat, "__cuda_array_interface__"))

            ar_dict = cudat.__cuda_array_interface__

            self.assertEqual(
                set(ar_dict.keys()), {"shape", "strides", "typestr", "data", "version"}
            )

            self.assertEqual(ar_dict["shape"], (10,))
            self.assertIs(ar_dict["strides"], None)
            # typestr from numpy, cuda-native little-endian
            self.assertEqual(ar_dict["typestr"], numpy.dtype(npt).newbyteorder("<").str)
            self.assertEqual(ar_dict["data"], (cudat.data_ptr(), False))
            self.assertEqual(ar_dict["version"], 2)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_array_adaptor(self):
        """Torch __cuda_array_interface__ exposes tensor data to numba.cuda."""

        torch_dtypes = [
            torch.complex64,
            torch.complex128,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
        ]

        for dt in torch_dtypes:

            # CPU tensors of all types do not register as cuda arrays,
            # attempts to convert raise a type error.
            cput = torch.arange(10).to(dt)
            npt = cput.numpy()

            self.assertFalse(numba.cuda.is_cuda_array(cput))
            with self.assertRaises(TypeError):
                numba.cuda.as_cuda_array(cput)

            # Any cuda tensor is a cuda array.
            cudat = cput.to(device="cuda")
            self.assertTrue(numba.cuda.is_cuda_array(cudat))

            numba_view = numba.cuda.as_cuda_array(cudat)
            self.assertIsInstance(numba_view, numba.cuda.devicearray.DeviceNDArray)

            # The reported type of the cuda array matches the numpy type of the cpu tensor.
            self.assertEqual(numba_view.dtype, npt.dtype)
            self.assertEqual(numba_view.strides, npt.strides)
            self.assertEqual(numba_view.shape, cudat.shape)

            # Pass back to cuda from host for all equality checks below, needed for
            # float16 comparisons, which aren't supported cpu-side.

            # The data is identical in the view.
            self.assertEqual(cudat, torch.tensor(numba_view.copy_to_host()).to("cuda"))

            # Writes to the torch.Tensor are reflected in the numba array.
            cudat[:5] = 11
            self.assertEqual(cudat, torch.tensor(numba_view.copy_to_host()).to("cuda"))

            # Strided tensors are supported.
            strided_cudat = cudat[::2]
            strided_npt = cput[::2].numpy()
            strided_numba_view = numba.cuda.as_cuda_array(strided_cudat)

            self.assertEqual(strided_numba_view.dtype, strided_npt.dtype)
            self.assertEqual(strided_numba_view.strides, strided_npt.strides)
            self.assertEqual(strided_numba_view.shape, strided_cudat.shape)

            # As of numba 0.40.0 support for strided views is ...limited...
            # Cannot verify correctness of strided view operations.

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_conversion_errors(self):
        """Numba properly detects array interface for torch.Tensor variants."""

        # CPU tensors are not cuda arrays.
        cput = torch.arange(100)

        self.assertFalse(numba.cuda.is_cuda_array(cput))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(cput)

        # Sparse tensors are not cuda arrays, regardless of device.
        sparset = torch.sparse_coo_tensor(cput[None, :], cput)

        self.assertFalse(numba.cuda.is_cuda_array(sparset))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(sparset)

        # Check the CUDA-resident sparse tensor itself, not the CPU one again.
        sparse_cuda_t = sparset.cuda()

        self.assertFalse(numba.cuda.is_cuda_array(sparse_cuda_t))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(sparse_cuda_t)

        # Device-status overrides gradient status.
        # CPU+gradient isn't a cuda array.
        cpu_gradt = torch.zeros(100).requires_grad_(True)

        self.assertFalse(numba.cuda.is_cuda_array(cpu_gradt))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(cpu_gradt)

        # CUDA+gradient raises a RuntimeError on check or conversion.
        #
        # Use of hasattr for interface detection causes interface change in
        # python2; it swallows all exceptions not just AttributeError.
        cuda_gradt = torch.zeros(100).requires_grad_(True).cuda()

        # conversion raises RuntimeError
        with self.assertRaises(RuntimeError):
            numba.cuda.is_cuda_array(cuda_gradt)
        with self.assertRaises(RuntimeError):
            numba.cuda.as_cuda_array(cuda_gradt)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    @unittest.skipIf(not TEST_MULTIGPU, "No multigpu")
    def test_active_device(self):
        """'as_cuda_array' tensor device must match active numba context."""

        # Both torch/numba default to device 0 and can interop freely
        cudat = torch.arange(10, device="cuda")
        self.assertEqual(cudat.device.index, 0)
        self.assertIsInstance(
            numba.cuda.as_cuda_array(cudat), numba.cuda.devicearray.DeviceNDArray
        )

        # Tensors on non-default device raise api error if converted
        cudat = torch.arange(10, device=torch.device("cuda", 1))

        with self.assertRaises(numba.cuda.driver.CudaAPIError):
            numba.cuda.as_cuda_array(cudat)

        # but can be converted when switching to the device's context
        with numba.cuda.devices.gpus[cudat.device.index]:
            self.assertIsInstance(
                numba.cuda.as_cuda_array(cudat), numba.cuda.devicearray.DeviceNDArray
            )

    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_from_cuda_array_interface(self):
        """torch.as_tensor() and torch.tensor() supports the __cuda_array_interface__ protocol.

        If an object exposes the __cuda_array_interface__, .as_tensor() and .tensor()
        will use the exposed device memory.

        See:
        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        """

        dtypes = [
            numpy.complex64,
            numpy.complex128,
            numpy.float64,
            numpy.float32,
            numpy.int64,
            numpy.int32,
            numpy.int16,
            numpy.int8,
            numpy.uint8,
        ]
        for dtype in dtypes:
            numpy_arys = [
                numpy.arange(6).reshape(2, 3).astype(dtype),
                numpy.arange(6).reshape(2, 3).astype(dtype)[1:],  # View offset should be ignored
                numpy.arange(6).reshape(2, 3).astype(dtype)[:, None],  # change the strides but still contiguous
            ]
            # Zero-copy when using `torch.as_tensor()`
            for numpy_ary in numpy_arys:
                numba_ary = numba.cuda.to_device(numpy_ary)
                torch_ary = torch.as_tensor(numba_ary, device="cuda")
                self.assertEqual(numba_ary.__cuda_array_interface__, torch_ary.__cuda_array_interface__)
                self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype))

                # Check that `torch_ary` and `numba_ary` points to the same device memory
                torch_ary += 42
                self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype))

            # Implicit-copy because `torch_ary` is a CPU array
            for numpy_ary in numpy_arys:
                numba_ary = numba.cuda.to_device(numpy_ary)
                torch_ary = torch.as_tensor(numba_ary, device="cpu")
                self.assertEqual(torch_ary.data.numpy(), numpy.asarray(numba_ary, dtype=dtype))

                # Check that `torch_ary` and `numba_ary` points to different memory
                torch_ary += 42
                self.assertEqual(torch_ary.data.numpy(), numpy.asarray(numba_ary, dtype=dtype) + 42)

            # Explicit-copy when using `torch.tensor()`
            for numpy_ary in numpy_arys:
                numba_ary = numba.cuda.to_device(numpy_ary)
                torch_ary = torch.tensor(numba_ary, device="cuda")
                self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype))

                # Check that `torch_ary` and `numba_ary` points to different memory
                torch_ary += 42
                self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype) + 42)

    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_from_cuda_array_interface_inferred_strides(self):
        """torch.as_tensor(numba_ary) should have correct inferred (contiguous) strides"""
        # This could, in theory, be combined with test_from_cuda_array_interface but that test
        # is overly strict: it checks that the exported protocols are exactly the same, which
        # cannot handle differing exported protocol versions.
        dtypes = [
            numpy.float64,
            numpy.float32,
            numpy.int64,
            numpy.int32,
            numpy.int16,
            numpy.int8,
            numpy.uint8,
        ]
        for dtype in dtypes:
            # A plain 2x3 array; no trailing comma, which would wrap it in a tuple.
            numpy_ary = numpy.arange(6).reshape(2, 3).astype(dtype)
            numba_ary = numba.cuda.to_device(numpy_ary)
            self.assertTrue(numba_ary.is_c_contiguous())
            torch_ary = torch.as_tensor(numba_ary, device="cuda")
            self.assertTrue(torch_ary.is_contiguous())

    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_from_cuda_array_interface_lifetime(self):
        """torch.as_tensor(obj) tensor grabs a reference to obj so that the lifetime of obj exceeds the tensor"""
        numba_ary = numba.cuda.to_device(numpy.arange(6))
        torch_ary = torch.as_tensor(numba_ary, device="cuda")
        self.assertEqual(torch_ary.__cuda_array_interface__, numba_ary.__cuda_array_interface__)  # No copy
        del numba_ary
        self.assertEqual(torch_ary.cpu().data.numpy(), numpy.arange(6))  # `torch_ary` is still alive

    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    @unittest.skipIf(not TEST_MULTIGPU, "No multigpu")
    def test_from_cuda_array_interface_active_device(self):
        """torch.as_tensor() tensor device must match active numba context."""

        # Zero-copy: both torch/numba default to device 0 and can interop freely
        numba_ary = numba.cuda.to_device(numpy.arange(6))
        torch_ary = torch.as_tensor(numba_ary, device="cuda")
        self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))
        self.assertEqual(torch_ary.__cuda_array_interface__, numba_ary.__cuda_array_interface__)

        # Implicit-copy: when the Numba and Torch device differ
        numba_ary = numba.cuda.to_device(numpy.arange(6))
        torch_ary = torch.as_tensor(numba_ary, device=torch.device("cuda", 1))
        self.assertEqual(torch_ary.get_device(), 1)
        self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))
        if1 = torch_ary.__cuda_array_interface__
        if2 = numba_ary.__cuda_array_interface__
        # Only the device pointers should differ after the cross-device copy;
        # drop them and compare the remaining description fields.
        self.assertNotEqual(if1["data"], if2["data"])
        del if1["data"]
        del if2["data"]
        self.assertEqual(if1, if2)
Mads R. B. Kristensen5d8879c2019-05-21 12:29:23 -0700355
Alex Ford7a1b6682018-10-12 13:33:43 -0700356
# Script entry point: hand control to the shared runner from the torch
# testing utilities (imported above as `common`).
if __name__ == "__main__":
    common.run_tests()