# -*- coding: utf-8 -*-
# Owner(s): ["module: tests"]

import torch
import torch.utils.data
import numpy as np

import contextlib
import gc
import io
import inspect
import itertools
import math
import random
import re
import copy
import os
import tempfile
import unittest
import warnings
import types
import pickle
import textwrap
import subprocess
import weakref
import sys
from torch.utils.dlpack import from_dlpack, to_dlpack
from torch._six import inf, nan, string_classes
from itertools import product, combinations, permutations
from functools import partial
from torch import multiprocessing as mp
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
    TestCase, TEST_WITH_ROCM, run_tests,
    IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
    IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest,
    TEST_WITH_CROSSREF, skipIfTorchDynamo,
    skipCUDAMemoryLeakCheckIf, BytesIOContext,
    skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
    wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
    skipIfNotRegistered, bytes_to_scalar, parametrize, skipIfMps)
from multiprocessing.reduction import ForkingPickler
from torch.testing._internal.common_device_type import (
    expectedFailureMeta,
    expectedFailureXLA,
    instantiate_device_type_tests,
    onlyCUDA, onlyCPU,
    dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
    skipMeta,
    PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes,
    expectedAlertNondeterministic, get_all_device_types, skipXLA)
from typing import Tuple
import torch.backends.quantized
import torch.testing._internal.data
from torch.testing._internal.common_cuda import (
    tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN)
from torch.testing._internal.common_dtype import (
    floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types,
    all_types_and, floating_types, floating_and_complex_types,
)

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()

@contextlib.contextmanager
def torch_vital_set(value):
    stash = None
    if 'TORCH_VITAL' in os.environ:
        stash = os.environ['TORCH_VITAL']
    os.environ['TORCH_VITAL'] = value
    try:
        yield
    finally:
        # Restore the original value (including an empty string) if there was
        # one, otherwise remove the variable entirely.
        if stash is not None:
            os.environ['TORCH_VITAL'] = stash
        else:
            del os.environ['TORCH_VITAL']

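# Illustrative usage (a minimal sketch, not exercised by the suite): the helper
# temporarily overrides the TORCH_VITAL environment variable for the duration
# of the block, e.g.:
#
#     with torch_vital_set('ON'):
#         assert torch.vitals_enabled()
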
# Tests Vital Signs for Torch
# FIXME: document or deprecate whatever this is
class TestBasicVitalSigns(TestCase):
    def test_basic_vitals(self):
        with torch_vital_set(''):
            self.assertFalse(torch.vitals_enabled())
        with torch_vital_set('ON'):
            self.assertTrue(torch.vitals_enabled())

    def test_basic_vitals_read_write(self):
        with torch_vital_set('ON'):
            self.assertTrue(torch.vitals_enabled())
            # This tests the code path of setting a vital
            self.assertTrue(torch.set_vital('Dataloader', 'basic_unit_test', 'TEST_VALUE_STRING'))
            self.assertIn('TEST_VALUE_STRING', torch.read_vitals())
            self.assertIn('CUDA.used', torch.read_vitals())

    def test_dataloader_vitals(self):
        with torch_vital_set('ON'):
            inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
            tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
            dataset = torch.utils.data.TensorDataset(inps, tgts)
            loader = torch.utils.data.DataLoader(dataset, batch_size=2)
            self.assertIn('Dataloader.enabled\t\t True', torch.read_vitals())

# FIXME: document or deprecate whatever this is
class TestVitalSignsCuda(TestCase):
    @onlyCUDA
    def test_cuda_vitals_gpu_only(self, device):
        with torch_vital_set('ON'):
            self.assertIn('CUDA.used\t\t true', torch.read_vitals())


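# Generic device-type tests for torch. These are written once against a
# `device` argument and are assumed to be instantiated per device type
# (CPU, CUDA, ...) via instantiate_device_type_tests (imported above).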
class TestTorchDeviceType(TestCase):
    exact_dtype = True

    # TODO: move all tensor creation to common ops
    def _rand_shape(self, dim, min_size, max_size):
        shape = []
        for i in range(dim):
            shape.append(random.randint(min_size, max_size))
        return tuple(shape)

    # Validates that mathematical constants are defined properly, as required by
    # the Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html)
    @onlyCPU
    def test_constants(self, device):
        self.assertIsInstance(torch.e, float)
        self.assertEqual(torch.e, math.e, atol=0, rtol=0)

        self.assertIsInstance(torch.pi, float)
        self.assertEqual(torch.pi, math.pi, atol=0, rtol=0)

        self.assertIsInstance(torch.nan, float)
        self.assertEqual(torch.nan, math.nan, equal_nan=True)

        self.assertIsInstance(torch.inf, float)
        self.assertEqual(torch.inf, math.inf)

    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
            torch.bool, torch.float32, torch.complex64, torch.float64,
            torch.complex128)
    def test_bytes_to_scalar(self, device, dtype):
        def rand_byte():
            if dtype == torch.bool:
                return torch.randint(0, 2, ()).item()
            else:
                return torch.randint(0, 256, ()).item()

        element_size = torch._utils._element_size(dtype)

        for i in range(10):
            bytes_list = [rand_byte() for _ in range(element_size)]
            scalar = bytes_to_scalar(bytes_list, dtype, device)
            self.assertEqual(scalar.storage()._untyped().tolist(), bytes_list)

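    # A minimal sketch (added for illustration; this helper is an assumption
    # and is not part of the upstream suite) of the round trip that
    # test_bytes_to_scalar checks: bytes_to_scalar packs a list of raw bytes
    # into a 0-dim tensor, and the tensor's untyped storage yields the same
    # bytes back.
    def _bytes_to_scalar_roundtrip_example(self, dtype=torch.float32, device='cpu'):
        element_size = torch._utils._element_size(dtype)
        byte_values = list(range(element_size))
        scalar = bytes_to_scalar(byte_values, dtype, device)
        # The scalar's raw storage holds exactly the bytes it was built from.
        return scalar.storage()._untyped().tolist() == byte_values
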
    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
            torch.bool, torch.float32, torch.complex64, torch.float64,
            torch.complex128)
    def test_storage(self, device, dtype):
        v = make_tensor((3, 5), dtype=dtype, device=device, low=-9, high=9)
        self.assertEqual(v.storage()[0], v[0][0])
        self.assertEqual(v.storage()[14], v[2][4])
        v_s = v.storage()

        for el_num in range(v.numel()):
            dim0 = el_num // v.size(1)
            dim1 = el_num % v.size(1)
            self.assertEqual(
                v_s[el_num],
                v[dim0][dim1])

        v_s_byte = v.storage()._untyped()
        el_size = v.element_size()

        for el_num in range(v.numel()):
            start = el_num * el_size
            end = start + el_size
            dim0 = el_num // v.size(1)
            dim1 = el_num % v.size(1)
            self.assertEqual(
                bytes_to_scalar(v_s_byte[start:end], dtype, device),
                v[dim0][dim1])

    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
            torch.bool, torch.float32, torch.complex64, torch.float64,
            torch.complex128, torch.quint8, torch.qint8, torch.qint32,
            torch.quint4x2)
    def test_storage_setitem(self, device, dtype):
        # Skip quantized dtypes for CUDA, since they're not supported
        if torch.device(device).type == 'cuda':
            if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.quint4x2]:
                return

        storage_type_name = torch.storage._dtype_to_storage_type_map()[dtype]
        if torch.device(device).type == 'cuda':
            storage_type = eval('torch.cuda.' + storage_type_name)
        else:
            storage_type = eval('torch.' + storage_type_name)

        N = 10

        s = storage_type(N)
        s[:] = 0
        l = [0] * N
        self.assertEqual(s, storage_type(l))

        for i in range(N):
            s[i] = i
            l[i] = i

        self.assertEqual(s, storage_type(l))

        l[2:7] = [1] * 5
        s[2:7] = 1
        self.assertEqual(s, storage_type(l))

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_tensor_storage_type(self, device, dtype):
        a = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)

        module = torch.cuda if (torch.device(device).type == 'cuda') else torch
        expected_storage_type = getattr(module, torch.storage._dtype_to_storage_type_map()[dtype])

        self.assertEqual(a.storage_type(), expected_storage_type)

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_tensor_from_storage(self, device, dtype):
        a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
        a_s = a.storage()
        b = torch.tensor(a_s, device=device, dtype=dtype).reshape(a.size())
        self.assertEqual(a, b)
        c = torch.tensor(a_s._untyped(), device=device, dtype=dtype).reshape(a.size())
        self.assertEqual(a, c)

        for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
            if error_dtype == dtype:
                continue
            with self.assertRaisesRegex(RuntimeError, r'Expected a Storage of type'):
                error_storage = a.to(error_dtype).storage()
                torch.tensor(error_storage, device=device, dtype=dtype)

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_set_storage(self, device, dtype):
        a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
        a_s = a.storage()
        b = torch.tensor([], device=device, dtype=dtype).set_(a_s).reshape(a.size())
        self.assertEqual(a, b)
        c = torch.tensor([], device=device, dtype=dtype).set_(a_s._untyped()).reshape(a.size())
        self.assertEqual(a, c)

        for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
            if error_dtype == dtype:
                continue
            with self.assertRaisesRegex(RuntimeError, r'Expected a Storage of type'):
                error_storage = a.to(error_dtype).storage()
                b = torch.tensor([], device=device, dtype=dtype).set_(error_storage)

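    # The checks below compare a storage allocated on the 'meta' device against
    # a real one: a meta storage reports the same size/nbytes (and dtype, when
    # typed), but holds no data (its data_ptr() is 0) and cannot be indexed.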
    def _check_storage_meta(self, s, s_check):
        self.assertTrue(
            isinstance(s, (torch._UntypedStorage, torch._TypedStorage)) and
            isinstance(s_check, type(s)),
            (
                's and s_check must both be one of _UntypedStorage or '
                '_TypedStorage, but got'
                f' {type(s).__name__} and {type(s_check).__name__}'))

        self.assertEqual(s.device.type, 'meta')
        self.assertEqual(s.nbytes(), s_check.nbytes())
        self.assertEqual(s.size(), s_check.size())
        self.assertEqual(s.data_ptr(), 0)

        with self.assertRaisesRegex(NotImplementedError, r'Not available'):
            s[0]

        if isinstance(s, torch._TypedStorage):
            self.assertEqual(s.dtype, s_check.dtype)
            self._check_storage_meta(s._untyped(), s_check._untyped())

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_typed_storage_meta(self, device, dtype):
        args_list = [
            [],
            [0],
            [100],
            [[1, 2, 3, 4, 5, 6]],
        ]
        for args in args_list:
            s_check = torch._TypedStorage(*args, dtype=dtype, device=device)
            s = torch._TypedStorage(*args, dtype=dtype, device='meta')
            self._check_storage_meta(s, s_check)

    @onlyNativeDeviceTypes
    def test_untyped_storage_meta(self, device):
        args_list = [
            [],
            [0],
            [100],
            [[1, 2, 3, 4, 5, 6]],
        ]
        for args in args_list:
            s_check = torch._UntypedStorage(*args, device=device)
            s = torch._UntypedStorage(*args, device='meta')
            self._check_storage_meta(s, s_check)

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_storage_meta_from_tensor(self, device, dtype):
        t_check = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
        t = t_check.to('meta')

        s_check = t_check.storage()
        s = t.storage()
        self._check_storage_meta(s, s_check)

    @onlyCPU
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_storage_meta_errors(self, device, dtype):
        s0 = torch._TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)

        with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
            s0.cpu()

        with self.assertRaisesRegex(RuntimeError, r'only available on CPU'):
            s0._share_fd_cpu_()

        with self.assertRaisesRegex(RuntimeError, r'only available on CPU'):
            s0._share_filename_cpu_()

        if torch.cuda.is_available():
            with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
                s0.cuda()

            with self.assertRaisesRegex(RuntimeError, r'only available on CUDA'):
                s0._share_cuda_()

            with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
                s0.pin_memory()

        with self.assertRaisesRegex(RuntimeError, r'got unexpected device type'):
            s0.resize_(10)

        with self.assertRaisesRegex(RuntimeError, r'only available on CPU'):
            s0.share_memory_()

        with self.assertRaisesRegex(NotImplementedError, r'Not available'):
            s0.tolist()

        with tempfile.NamedTemporaryFile() as f:
            with self.assertRaisesRegex(RuntimeError, r'Device not recognized'):
                s0._write_file(f, True, True, s0.element_size())

        for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
            s1 = torch._TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)

            with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
                s1.copy_(s0)

    @onlyCUDA
    def test_module_share_memory(self):
        # Test fix for issue #80733
        # See https://github.com/pytorch/pytorch/issues/80733
        model = torch.nn.Linear(3, 1)
        model_cuda = model.to('cuda')
        model.share_memory()

    @dtypes(torch.float32, torch.complex64)
    def test_deepcopy(self, device, dtype):
        from copy import deepcopy
        a = torch.randn(5, 5, dtype=dtype, device=device)
        b = torch.randn(5, 5, dtype=dtype, device=device)
        c = a.view(25)
        q = [a, [a.storage(), b.storage()], b, c]
        w = deepcopy(q)
        self.assertEqual(w[0], q[0], atol=0, rtol=0)
        self.assertEqual(w[1][0], q[1][0], atol=0, rtol=0)
        self.assertEqual(w[1][1], q[1][1], atol=0, rtol=0)
        self.assertEqual(w[1], q[1], atol=0, rtol=0)
        self.assertEqual(w[2], q[2], atol=0, rtol=0)

        # Check that deepcopy preserves sharing
        w[0].add_(1)
        for i in range(a.numel()):
            self.assertEqual(w[1][0][i], q[1][0][i] + 1)
        self.assertEqual(w[3], c + 1)
        w[2].sub_(1)
        for i in range(a.numel()):
            self.assertEqual(w[1][1][i], q[1][1][i] - 1)

        # Check that deepcopy preserves attributes
        a.foo = 3
        self.assertEqual(deepcopy(a).foo, 3)

    @dtypes(torch.float32, torch.complex64)
    def test_deepcopy_scalar(self, device, dtype):
        from copy import deepcopy
        a = torch.tensor(5, dtype=dtype, device=device)
        self.assertEqual(a.size(), deepcopy(a).size())
        self.assertEqual(a, deepcopy(a))

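    # The three helpers below are shared by many memory-overlap tests: they
    # build inputs whose memory regions alias each other and check that the op
    # under test reports the overlap ('single memory location' for internal
    # overlap of an in-place op, 'unsupported operation' for partial
    # input/output overlap). Passing expected_failure=True inverts the check
    # for ops that are known not to detect the overlap yet.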
    def check_internal_mem_overlap(self, inplace_op, num_inputs,
                                   dtype, device,
                                   expected_failure=False):
        if isinstance(inplace_op, str):
            inplace_op = getattr(torch.Tensor, inplace_op)
        input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
        inputs = [input] + [torch.randn_like(input)
                            for i in range(num_inputs - 1)]
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, 'single memory location'):
                inplace_op(*inputs)
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, 'single memory location'):
                    inplace_op(*inputs)

    def unary_check_input_output_mem_overlap(self, data, sz, op,
                                             expected_failure=False):

        def _test(op, output, input):
            output_exp = torch.empty_like(output)
            op(input, out=output_exp)
            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)

        # output is identical to input:
        _test(op, output=data[0:sz], input=data[0:sz])
        # output and input are independent:
        _test(op, output=data[0:sz], input=data[sz:2 * sz])
        # output partially overlaps with input:
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
                _test(op, data[0:sz], data[1:sz + 1])
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
                    _test(op, data[0:sz], data[1:sz + 1])
        # output is transpose of input:
        length = int(math.sqrt(sz))
        input = data[:length**2].view([length, length])
        out = input.t()
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
                _test(op, out, input)
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
                    _test(op, out, input)

    def ternary_check_input_output_mem_overlap(self, op, device,
                                               expected_failure=False):
        sz = 9
        data = torch.randn(2 * sz, device=device)
        other1 = torch.randn(sz, device=device)
        other2 = torch.randn(sz, device=device)

        self.unary_check_input_output_mem_overlap(
            data, sz, lambda input, out:
            op(input, other1.view(input.shape), other2.view(input.shape), out=out),
            expected_failure=expected_failure)

        self.unary_check_input_output_mem_overlap(
            data, sz, lambda input, out:
            op(other1.view(input.shape), input, other2.view(input.shape), out=out),
            expected_failure=expected_failure)

        self.unary_check_input_output_mem_overlap(
            data, sz, lambda input, out:
            op(other1.view(input.shape), other2.view(input.shape), input, out=out),
            expected_failure=expected_failure)

    def _select_broadcastable_dims(self, dims_full=None):
        # select full dimensionality
        if dims_full is None:
            dims_full = []
            ndims = random.randint(1, 4)
            dims_full = [random.randint(1, 8) for _ in range(ndims)]
        else:
            ndims = len(dims_full)

        # select actual dimensions for ops:
        # larger: full ndims, individual sizes may be reduced
        # smaller: possibly reduced ndims, sizes may be reduced
        smaller_ndims = random.randint(1, ndims)
        dims_small = []
        dims_large = []
        for i in range(ndims - 1, -1, -1):
            j = random.randint(1, 3)
            if j == 1:  # no reduced singleton dimension
                ds = dims_full[i]
                dl = dims_full[i]
            elif j == 2:  # larger may have reduced singleton dimension
                ds = dims_full[i]
                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
            elif j == 3:  # smaller may have reduced singleton dimension
                ds = 1
                dl = dims_full[i]
            dims_large = [dl] + dims_large
            if len(dims_small) < smaller_ndims:
                dims_small = [ds] + dims_small
        return (dims_small, dims_large, dims_full)

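    # Note (an informal invariant of the helper above, stated here for
    # clarity): every dimension of dims_full is kept at full size in at least
    # one of dims_small / dims_large, so the two shapes broadcast against each
    # other to dims_full. For example dims_small=(1, 8) and dims_large=(4, 1)
    # can only be produced together with dims_full=(4, 8).
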
    # collected tests of ops that used scalar_check in Declarations.cwrap for
    # correctness
    def test_scalar_check(self, device):
        zero_d = torch.randn((), device=device)
        one_d = torch.randn((1,), device=device)

        # remainder
        self.assertEqual((), torch.remainder(zero_d, zero_d).shape)
        self.assertEqual((), torch.remainder(zero_d, 2).shape)
        self.assertEqual((1,), torch.remainder(zero_d, one_d).shape)
        self.assertEqual((1,), torch.remainder(one_d, zero_d).shape)

        # fmod
        self.assertEqual((), torch.fmod(zero_d, zero_d).shape)
        self.assertEqual((), torch.fmod(zero_d, 2).shape)
        self.assertEqual((1,), torch.fmod(zero_d, one_d).shape)
        self.assertEqual((1,), torch.fmod(one_d, zero_d).shape)

        # exp, cos, cosh, tan, atan, acosh, asinh, atanh, tanh, erf, erfc, reciprocal
        self.assertEqual((), torch.exp(zero_d).shape)
        self.assertEqual((), torch.cos(zero_d).shape)
        self.assertEqual((), torch.cosh(zero_d).shape)
        self.assertEqual((), torch.tan(zero_d).shape)
        self.assertEqual((), torch.atan(zero_d).shape)
        self.assertEqual((), torch.acosh(zero_d).shape)
        self.assertEqual((), torch.asinh(zero_d).shape)
        self.assertEqual((), torch.atanh(zero_d).shape)
        self.assertEqual((), torch.tanh(zero_d).shape)
        self.assertEqual((), torch.erf(zero_d).shape)
        self.assertEqual((), torch.erfc(zero_d).shape)
        self.assertEqual((), torch.reciprocal(zero_d).shape)
        self.assertEqual((1,), torch.exp(one_d).shape)
        self.assertEqual((1,), torch.cos(one_d).shape)
        self.assertEqual((1,), torch.cosh(one_d).shape)
        self.assertEqual((1,), torch.tan(one_d).shape)
        self.assertEqual((1,), torch.atan(one_d).shape)
        self.assertEqual((1,), torch.acosh(one_d).shape)
        self.assertEqual((1,), torch.asinh(one_d).shape)
        self.assertEqual((1,), torch.atanh(one_d).shape)
        self.assertEqual((1,), torch.tanh(one_d).shape)
        self.assertEqual((1,), torch.erf(one_d).shape)
        self.assertEqual((1,), torch.erfc(one_d).shape)
        self.assertEqual((1,), torch.reciprocal(one_d).shape)

        # clamp
        self.assertEqual((), torch.clamp(zero_d, min=0, max=1).shape)
        self.assertEqual((), torch.clamp(zero_d, min=0).shape)
        self.assertEqual((), torch.clamp(zero_d, max=1).shape)
        self.assertEqual((1,), torch.clamp(one_d, min=0, max=1).shape)
        self.assertEqual((1,), torch.clamp(one_d, min=0).shape)
        self.assertEqual((1,), torch.clamp(one_d, max=1).shape)

        # logcumsumexp, cumsum, cumprod, cummax, cummin
        self.assertEqual((), torch.logcumsumexp(zero_d, 0).shape)
        self.assertEqual((), torch.cumsum(zero_d, 0).shape)
        self.assertEqual((), torch.cumprod(zero_d, 0).shape)
        self.assertEqual((), torch.cummax(zero_d, 0)[0].shape)
        self.assertEqual((), torch.cummin(zero_d, 0)[0].shape)

        # sort, topk
        self.assertEqual([(), ()], [x.shape for x in torch.sort(zero_d, 0, False)])
        self.assertEqual([(), ()], [x.shape for x in torch.sort(zero_d, 0, True)])
        self.assertEqual([(), ()], [x.shape for x in torch.topk(zero_d, 1, 0, False)])
        self.assertEqual([(), ()], [x.shape for x in torch.topk(zero_d, 1, 0, True)])

        # max, min
        self.assertEqual((), torch.max(zero_d, zero_d).shape)
        self.assertEqual((1,), torch.max(one_d, zero_d).shape)
        self.assertEqual((1,), torch.max(zero_d, one_d).shape)
        self.assertEqual((), torch.min(zero_d, zero_d).shape)
        self.assertEqual((1,), torch.min(one_d, zero_d).shape)
        self.assertEqual((1,), torch.min(zero_d, one_d).shape)

        zero_d_int = torch.tensor(1, device=device)
        one_d_int = torch.tensor([1], device=device)

        # lshift, rshift
        self.assertEqual((), (zero_d_int >> zero_d_int).shape)
        self.assertEqual((), (zero_d_int >> 1).shape)
        self.assertEqual((1,), (one_d_int >> zero_d_int).shape)
        self.assertEqual((1,), (zero_d_int >> one_d_int).shape)
        self.assertEqual((1,), (one_d_int >> 1).shape)

        self.assertEqual((), (zero_d_int << zero_d_int).shape)
        self.assertEqual((), (zero_d_int << 1).shape)
        self.assertEqual((1,), (one_d_int << zero_d_int).shape)
        self.assertEqual((1,), (zero_d_int << one_d_int).shape)
        self.assertEqual((1,), (one_d_int << 1).shape)

        # or
        self.assertEqual((), (zero_d_int | zero_d_int).shape)
        self.assertEqual((), (zero_d_int | 1).shape)
        self.assertEqual((1,), (one_d_int | zero_d_int).shape)
        self.assertEqual((1,), (zero_d_int | one_d_int).shape)
        self.assertEqual((1,), (one_d_int | 1).shape)

        # and
        self.assertEqual((), (zero_d_int & zero_d_int).shape)
        self.assertEqual((), (zero_d_int & 1).shape)
        self.assertEqual((1,), (one_d_int & zero_d_int).shape)
        self.assertEqual((1,), (zero_d_int & one_d_int).shape)
        self.assertEqual((1,), (one_d_int & 1).shape)

        # clone
        self.assertEqual((), zero_d.clone().shape)

        zero_d_bool = torch.tensor(True, device=device)
        one_d_bool = torch.tensor([True], device=device)

        # masked_select
        self.assertEqual((1,), torch.masked_select(zero_d_bool, zero_d_bool).shape)
        self.assertEqual((1,), torch.masked_select(zero_d_bool, one_d_bool).shape)
        self.assertEqual((1,), torch.masked_select(one_d_bool, zero_d_bool).shape)

        zero_d_uint8 = torch.tensor(1, dtype=torch.uint8, device=device)
        one_d_uint8 = torch.tensor([1], dtype=torch.uint8, device=device)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.assertEqual((1,), torch.masked_select(zero_d_uint8, zero_d_uint8).shape)
            self.assertEqual((1,), torch.masked_select(zero_d_uint8, one_d_uint8).shape)
            self.assertEqual((1,), torch.masked_select(one_d_uint8, zero_d_uint8).shape)

        # mode
        self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=False)])
        self.assertEqual([(1,), (1,)], [x.shape for x in torch.mode(one_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.mode(one_d, dim=0, keepdim=False)])

        # max
        self.assertEqual([(), ()], [x.shape for x in torch.max(zero_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.max(zero_d, dim=0, keepdim=False)])
        self.assertEqual([(1,), (1,)], [x.shape for x in torch.max(one_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.max(one_d, dim=0, keepdim=False)])

        # amax
        self.assertEqual((), torch.amax(zero_d, dim=0, keepdim=True).shape)
        self.assertEqual((), torch.amax(zero_d, dim=0, keepdim=False).shape)
        self.assertEqual((1,), torch.amax(one_d, dim=0, keepdim=True).shape)
        self.assertEqual((), torch.amax(one_d, dim=0, keepdim=False).shape)

        # min
        self.assertEqual([(), ()], [x.shape for x in torch.min(zero_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.min(zero_d, dim=0, keepdim=False)])
        self.assertEqual([(1,), (1,)], [x.shape for x in torch.min(one_d, dim=0, keepdim=True)])
        self.assertEqual([(), ()], [x.shape for x in torch.min(one_d, dim=0, keepdim=False)])

        # amin
        self.assertEqual((), torch.amin(zero_d, dim=0, keepdim=True).shape)
        self.assertEqual((), torch.amin(zero_d, dim=0, keepdim=False).shape)
        self.assertEqual((1,), torch.amin(one_d, dim=0, keepdim=True).shape)
        self.assertEqual((), torch.amin(one_d, dim=0, keepdim=False).shape)

        # set_
        zero_d_clone = zero_d.clone()
        one_d_clone = one_d.clone()
        self.assertEqual((), zero_d_clone.set_(one_d.storage(), 0, (), ()).shape)
        self.assertEqual((1,), zero_d_clone.set_(one_d.storage(), 0, (1,), (1,)).shape)
        self.assertEqual((), one_d_clone.set_(one_d.storage(), 0, (), ()).shape)
        self.assertEqual((1,), one_d_clone.set_(one_d.storage(), 0, (1,), (1,)).shape)

        self.assertEqual((), zero_d.clone().set_(zero_d).shape)
        self.assertEqual((), one_d.clone().set_(zero_d).shape)
        self.assertEqual((1,), zero_d.clone().set_(one_d).shape)
        self.assertEqual((1,), one_d.clone().set_(one_d).shape)

        # take
        self.assertEqual((), torch.randn((2, 3), device=device).take(zero_d_int).shape)
        self.assertEqual((1,), torch.randn((2, 3), device=device).take(one_d_int).shape)

        # gather
        self.assertEqual((), torch.gather(zero_d, 0, torch.zeros((), dtype=torch.int64, device=device)).shape)
        self.assertEqual((1,), torch.gather(zero_d, 0, torch.zeros((1,), dtype=torch.int64, device=device)).shape)
        self.assertEqual((), torch.gather(one_d, 0, torch.zeros((), dtype=torch.int64, device=device)).shape)
        self.assertEqual((1,), torch.gather(one_d, 0, torch.zeros((1,), dtype=torch.int64, device=device)).shape)

        # normal
        # std must be >= 0
        zero_d_ge_0 = torch.rand((), device=device)
        # documentation says out shape matches shape of mean
        self.assertEqual((), torch.normal(zero_d, zero_d_ge_0).shape)
        self.assertEqual((1,), torch.normal(one_d, zero_d_ge_0).shape)
        self.assertEqual((), torch.normal(1, zero_d_ge_0).shape)
        self.assertEqual((), torch.normal(zero_d, 1).shape)
        self.assertEqual((1,), torch.normal(one_d, 1).shape)
        # TODO: this behavior differs on CPU and GPU, see https://github.com/pytorch/pytorch/issues/30480.
        # self.assertEqual((), torch.normal(zero_d, one_d).shape)
        # self.assertEqual((), torch.normal(1, one_d).shape)

        # convolutions. Yes, we are testing nn.functional here; seems justified
        # given it's similar to the other tests
        w = torch.randn(2, 1, 3, 3, device=device).div_(2).requires_grad_()
        self.assertRaises(RuntimeError, lambda: torch.nn.functional.conv2d(zero_d, w, groups=1))
        self.assertRaises(RuntimeError, lambda: torch.nn.functional.conv2d(zero_d, w, groups=2))

        # nll_loss -- verify input can't be 0-dimensional.
        self.assertRaises(ValueError, lambda: torch.nn.functional.nll_loss(zero_d, zero_d, reduction='none'))
        self.assertRaises(ValueError, lambda: torch.nn.functional.nll_loss(zero_d, one_d, reduction='none'))
        # verify output is 0-dimensional when reduction != 'none'
        for (input, target) in ((torch.randn(1, 1, device=device), torch.tensor([0], device=device)),
                                (torch.randn(1, 1, 1, 1, device=device), torch.tensor([[[0]]], device=device))):
            self.assertEqual((), torch.nn.functional.nll_loss(input, target, reduction='mean').shape)
            self.assertEqual((), torch.nn.functional.nll_loss(input, target, reduction='sum').shape)

        # multilabel_margin_loss
        for input in (zero_d, one_d, torch.randn(1, 1, device=device)):
            for target in (torch.tensor(0, device=device), torch.tensor([0], device=device), torch.tensor([[0]], device=device)):
                if (input.dim() <= 1 and target.dim() <= 1) or (input.dim() == 2 and target.dim() == 2):
                    output_shape = (target.shape[0],) if target.dim() == 2 else ()
                    self.assertEqual(output_shape,
                                     torch.nn.functional.multilabel_margin_loss(input, target, reduction='none').shape)
                    self.assertEqual((), torch.nn.functional.multilabel_margin_loss(input, target, reduction='mean').shape)
                    self.assertEqual((), torch.nn.functional.multilabel_margin_loss(input, target, reduction='sum').shape)
                else:
                    self.assertRaises(RuntimeError,
                                      lambda: torch.nn.functional.multilabel_margin_loss(input, target, reduction='none'))
                    self.assertRaises(RuntimeError,
                                      lambda: torch.nn.functional.multilabel_margin_loss(input, target, reduction='mean'))
                    self.assertRaises(RuntimeError,
                                      lambda: torch.nn.functional.multilabel_margin_loss(input, target, reduction='sum'))

        # multi_margin_loss
        for input in (zero_d, one_d, torch.randn(1, 1, device=device)):
            for target in (torch.tensor(0, device=device), torch.tensor([0], device=device)):
                self.assertEqual(target.shape, torch.nn.functional.multi_margin_loss(input, target, reduction='none').shape)
                self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='mean').shape)
                self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='sum').shape)

    # Uses mismatched arange out size to trigger a warning
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    @unittest.skipIf(TEST_WITH_CROSSREF, "crossref perturbs line numbering")
    def test_cpp_warnings_have_python_context(self, device):
        # Creates long string in advance to avoid a too-long Python line
        s = ".+Triggered internally at.+RangeFactories.+"

        def cpp_warn_fn():
            out = torch.empty((5,))
            torch.arange(0, 3, out=out)
            return out

        # Checks eager-mode cpp warning
        with warnings.catch_warnings(record=True) as w:
            cpp_warn_fn()
            frameinfo = inspect.getframeinfo(inspect.currentframe())
            warning = w[0]

        # Checks for cpp context in the warning message
        escaped_warning_message = str(warning.message).encode('unicode_escape')
        self.assertTrue(re.search(s, repr(escaped_warning_message), re.IGNORECASE) is not None)

        # Checks the Python features of the warning
        # Note: the eager mode warning refers to the line in the function that
        # throws the warning (the arange call), which sits six lines above the
        # getframeinfo() call captured above, hence the "- 6" offset.
        self.assertEqual(frameinfo.lineno - 6, warning.lineno)
        self.assertEqual(len(w), 1)

        # Checks jitted cpp warning
        with warnings.catch_warnings(record=True) as w:
            scripted_cpp_warn_fn = torch.jit.script(cpp_warn_fn)
            scripted_cpp_warn_fn()
            warning = w[0]

        # Checks for cpp context in the warning message
        escaped_warning_message = str(warning.message).encode('unicode_escape')
        self.assertTrue(re.search(s, repr(escaped_warning_message), re.IGNORECASE) is not None)

        # Checks the Python features of the warning
        # Note: the jitted warning's lineno refers to the call to the jitted
        # function, which in our test suite has a layer of indirection
        # that makes checking the Python lineno fragile
        self.assertEqual(len(w), 1)

        # Checks jitted Python warning
        def warn_fn():
            warnings.warn("Warning!")

        # The jit mimics an eager-mode Python warning in this case
        with warnings.catch_warnings(record=True) as w:
            scripted_warn_fn = torch.jit.script(warn_fn)
            scripted_warn_fn()
            frameinfo = inspect.getframeinfo(inspect.currentframe())
            warning = w[0]

        self.assertTrue(re.search('Warning!', str(warning.message)) is not None)

        # Checks the Python features of the warning
        self.assertEqual(frameinfo.lineno - 6, warning.lineno)
        self.assertEqual(len(w), 1)

    # FIXME: move to test_testing
    @onlyCPU
    def test_warn_always_caught(self, device):
        # Check that we can catch a TORCH_WARN_ONCE warning twice
        # since assertWarnsOnceRegex uses set_warn_always(True) which changes
        # TORCH_WARN_ONCE to TORCH_WARN
        a = np.arange(10)
        a.flags.writeable = False
        with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'):
            torch.from_numpy(a)

        # OK, got it once, now try again
        with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'):
            torch.from_numpy(a)

        # Make sure emitting two warnings will pass the assertWarnsOnceRegex
        # context manager
        with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'):
            torch.from_numpy(a)
            torch.from_numpy(a)

    @onlyNativeDeviceTypes
    def test_complex_half_experimental_warning(self, device):
        msg = 'ComplexHalf support is experimental'
        with self.assertWarnsOnceRegex(UserWarning, msg):
            t = torch.randn(3, dtype=torch.chalf, device=device)

        with self.assertWarnsOnceRegex(UserWarning, msg):
            torch.rand(3, dtype=torch.chalf, device=device)

        with self.assertWarnsOnceRegex(UserWarning, msg):
            torch.empty(3, dtype=torch.chalf, device=device)

        with self.assertWarnsOnceRegex(UserWarning, msg):
            torch.ones(3, dtype=torch.chalf, device=device)

        with self.assertWarnsOnceRegex(UserWarning, msg):
            torch.zeros(3, dtype=torch.chalf, device=device)

        with self.assertWarnsOnceRegex(UserWarning, msg):
            torch.randn_like(t)

        with self.assertWarnsOnceRegex(UserWarning, msg):
            torch.rand_like(t)

        with self.assertWarnsOnceRegex(UserWarning, msg):
            torch.empty_like(t)

        with self.assertWarnsOnceRegex(UserWarning, msg):
            torch.ones_like(t)

        with self.assertWarnsOnceRegex(UserWarning, msg):
            torch.zeros_like(t)

        with self.assertWarnsOnceRegex(UserWarning, msg):
            # t + 1 allocates a new tensor for result using empty
            t + 1

    # TODO: this test should be in test_nn.py
    def test_conv_transposed_backward_agnostic_to_memory_format(self, device):
        in_channels = 64
        out_channels = 128
        scale_factor = 8
        batch_size = 8
        length = 16

        conv = torch.nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size=scale_factor * 2, stride=scale_factor).to(device)
        layer_norm = torch.nn.LayerNorm(out_channels).to(device)

        input_ = torch.randn(batch_size, in_channels, length).to(device).contiguous()
        input_ = conv(input_).contiguous()
        input_ = layer_norm(input_.transpose(1, 2).contiguous()).contiguous()
        input_.sum().backward()

        # 3d
        conv = torch.nn.ConvTranspose3d(3, 3, kernel_size=3).to(device)
        input = torch.randn(batch_size, 3, length, length, length, device=device)
        out = conv(input)
        out.backward(torch.ones_like(out).transpose(-2, -1))

    # TODO: this test should be in test_nn.py
    @onlyCUDA
    @largeTensorTest('12GB')
    def test_conv_transposed_large(self, device):
        # ConvTranspose3d works for large input tensors (gh-32866)
        in_channels = 64
        out_channels = 128
        kernel_size = 5

        conv = torch.nn.ConvTranspose3d(
            in_channels, out_channels, kernel_size=kernel_size,
            stride=2, padding=2, output_padding=1).to(device)

        x = torch.rand([1, 64, 8, 128, 172]).to(device)
        y = conv(x)

    def test_is_set_to(self, device):
        t1 = torch.empty(3, 4, 9, 10, device=device)
        t2 = torch.empty(3, 4, 9, 10, device=device)
        t3 = torch.tensor([], device=device).set_(t1)
        t4 = t3.clone().resize_(12, 90)
        self.assertFalse(t1.is_set_to(t2))
        self.assertTrue(t1.is_set_to(t3))
        self.assertTrue(t3.is_set_to(t1), "is_set_to should be symmetric")
        self.assertFalse(t1.is_set_to(t4))
        self.assertFalse(torch.tensor([]).is_set_to(torch.tensor([])),
                         "Tensors with no storages should not appear to be set "
                         "to each other")

        t1 = torch.tensor([True, True], dtype=torch.bool, device=device)
        t2 = torch.tensor([0], dtype=torch.bool, device=device).set_(t1)
        self.assertTrue(t1.is_set_to(t2))

        # test that sizes must match
        t1 = torch.empty([2, 3, 4], device=device)
        t2 = t1.view(4, 3, 2)
        self.assertFalse(t1.is_set_to(t2))
        self.assertFalse(t2.is_set_to(t1))

        # test that the legacy empty-size behavior (where all empty tensors
        # were logically collapsed to size [0]) is no longer respected
        t1 = torch.empty([2, 5, 0], device=device)
        t2 = t1.view([0])
        self.assertFalse(t1.is_set_to(t2))
        self.assertFalse(t2.is_set_to(t1))

    # See https://github.com/pytorch/pytorch/issues/72650
    @skipIfMps
    @skipMeta
    @parametrize(
        "fn",
        [
            "dist", "atan2", "pow", "lerp", "add", "sub", "mul", "div", "fmod", "remainder", "eq", "ge", "gt", "le",
            "lt", "max", "min", "ne", "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill", "map",
            "map2", "copy",
        ],
    )
    def test_broadcast(self, fn, device):
        # functions with three tensor arguments
        fns_3_args = {"map2"}
        fns_value_kwarg = {"addcdiv", "addcmul"}

        (dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
        full1d = torch.randn(*dims_full, device=device).flatten().float()
        small = torch.randn(*dims_small, device=device).float()
        large = torch.randn(*dims_large, device=device).float()
        small_expanded = small.expand(*dims_full)
        large_expanded = large.expand(*dims_full)
        small2 = None
        small2_expanded = None
        if fn in fns_3_args or fn in fns_value_kwarg:
            # create another smaller tensor
            (dims_small2, _, _) = self._select_broadcastable_dims(dims_full)
            small2 = torch.randn(*dims_small2, device=device).float()
            small2_expanded = small2.expand(*dims_full)

        if small.is_cuda and fn in ['map', 'map2']:
            # map and map2 are not implemented on CUDA tensors
            return

        if hasattr(large_expanded, fn):
            # run through tensor versions of functions
            # and verify fully expanded inputs give same results
            expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}

            def tensorfn(myfn, t1, t2):
                if fn == "lerp":
                    return myfn(t1, 0.5)
                elif fn == "masked_select":
                    return myfn(t1 < 0)
                elif fn == "masked_scatter":
                    return myfn(t1 < 0.5, full1d)
                elif fn == "masked_fill":
                    return myfn(t1 < 0.5, 1.0)
                elif fn in fns_3_args:
                    return myfn(1, t1, t2)
                elif fn in fns_value_kwarg:
                    return myfn(t1, t2, value=1)
                else:
                    return myfn(t1)

            # test various orders
            for first, second, third in [(large, small, small2), (small, large, small2),
                                         (small2, small, large), (small2, large, small)]:
                if first is None:
                    break  # ignore last iter when small2 is None
                method_expanded = getattr(expanded[first], fn)
                method = getattr(first, fn)
                r1 = tensorfn(method_expanded, expanded[second], expanded[third])
                r2 = tensorfn(method, second, third)
                self.assertEqual(r1, r2)

        # now for torch. versions of functions
        if hasattr(torch, fn):
            fntorch = getattr(torch, fn)
            expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}

            def torchfn(t1, t2, t3):
                if fn == "lerp":
                    return fntorch(t1, t2, 0.5)
                elif fn == "masked_select":
                    return fntorch(t1, t2 < 0)
                elif fn == "masked_scatter":
                    return fntorch(t1, t2 < 0.5, full1d)
                elif fn == "masked_fill":
                    return fntorch(t1, t2 < 0.5, 1.0)
                elif fn in fns_3_args:
                    return fntorch(t1, 1.0, t2, t3)
                elif fn in fns_value_kwarg:
                    return fntorch(t1, t2, t3, value=1.0)
                else:
                    return fntorch(t1, t2)

            # test various orders
            for first, second, third in [(large, small, small2), (small, large, small2),
                                         (small2, small, large), (small2, large, small)]:
                if first is None:
                    break  # ignore last iter when small2 is None
                r1 = torchfn(expanded[first], expanded[second], expanded[third])
                r2 = torchfn(first, second, third)
                self.assertEqual(r1, r2)

        # now for in place functions
        # in-place tensor is not broadcastable; test only guaranteed
        # to work by broadcasting other argument(s)
        if not hasattr(large_expanded, fn + "_"):
            return

        # need to clone largeExpanded so we can reuse, since functions are in-place
        large_expanded_clone = large_expanded.clone()

        def tensorfn_inplace(t0, t1, t2=None):
            t0_fn = getattr(t0, fn + "_")
            if fn == "lerp":
                return t0_fn(t1, 0.5)
            elif fn == "masked_scatter":
                return t0_fn(t1 < 0.5, full1d)
            elif fn == "masked_fill":
                return t0_fn(t1 < 0.5, 1.0)
            elif fn == "map":
                return t0_fn(t1, lambda x, y: x + y)
            elif fn == "map2":
                return t0_fn(t1, t2, lambda x, y, z: x + y + z)
            elif fn in fns_3_args:
                return t0_fn(1.0, t1, t2)
            elif fn in fns_value_kwarg:
                return t0_fn(t1, t2, value=1.0)
            else:
                return t0_fn(t1)
        # in-place pointwise operations don't actually work if the in-place
        # tensor is 0-strided (numpy has the same issue)
        if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()):
            r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
            r2 = tensorfn_inplace(large_expanded_clone, small, small2)
            self.assertEqual(r1, r2)

        def broadcastable(t0, t1, t2=None):
            try:
                t1.expand_as(t0)
                if t2 is not None:
                    t2.expand_as(t0)
            except RuntimeError:
                return False
            return True

        def _test_in_place_broadcastable(t0, t1, t2=None):
            if not broadcastable(t0, t1, t2):
                same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True)
                if not same_size:
                    self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2))
            else:
                tensorfn_inplace(t0, t1, t2)

        if fn not in fns_3_args and fn not in fns_value_kwarg:
            _test_in_place_broadcastable(small, large_expanded)
            _test_in_place_broadcastable(small, large)
        else:
            _test_in_place_broadcastable(small2, small_expanded, large_expanded)
            _test_in_place_broadcastable(small2, small, large)

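    # Minimal illustrative sketch (an addition for exposition; this helper is
    # not called by the suite): the broadcasting rule the test above relies on.
    # Operands are virtually expanded to a common shape before the op runs, so
    # results on pre-expanded and non-expanded inputs must match.
    def _broadcast_shapes_example(self):
        small = torch.randn(1, 3)
        large = torch.randn(4, 1)
        common = torch.broadcast_shapes(small.shape, large.shape)  # (4, 3)
        # Adding the pre-expanded operands matches the broadcasted result.
        expected = small.expand(common) + large.expand(common)
        return torch.equal(small + large, expected) and common == (4, 3)
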
Kurt Mohler1f044942021-04-22 23:33:03 -07001082 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
1083 @onlyCUDA
1084 @wrapDeterministicFlagAPITest
1085 def test_cublas_config_nondeterministic_alert(self, device):
1086 test_cases = [
1087 # (function, (tensor sizes))
Shen Li10224432021-08-12 11:39:31 -07001088 ('mm', ((2, 2), (2, 2),)),
1089 ('mv', ((2, 2), (2,),)),
1090 ('bmm', ((1, 2, 2), (1, 2, 2),))]
Kurt Mohler1f044942021-04-22 23:33:03 -07001091
1092 test_configs = [
1093 # (CuBLAS workspace config, is deterministic)
Shen Li10224432021-08-12 11:39:31 -07001094 ('garbage', False),
Kurt Mohler1f044942021-04-22 23:33:03 -07001095 (None, False),
Shen Li10224432021-08-12 11:39:31 -07001096 (':4096:8', True),
1097 (':16:8', True)]
Kurt Mohler1f044942021-04-22 23:33:03 -07001098
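        # Per the cuBLAS documentation (and PyTorch's reproducibility notes), on
        # CUDA >= 10.2 cuBLAS only behaves deterministically when
        # CUBLAS_WORKSPACE_CONFIG is set to ':4096:8' or ':16:8', which is why
        # only those two configs are marked deterministic above.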
Shen Li10224432021-08-12 11:39:31 -07001099 cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'
1100 is_cuda10_2_or_higher = (
1101 (torch.version.cuda is not None)
1102 and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2]))
Kurt Mohler1f044942021-04-22 23:33:03 -07001103
1104 def test_case_info(fn_name, config):
Shen Li10224432021-08-12 11:39:31 -07001105 return f'function "{fn_name}" with config "{"" if config is None else config}"'
Kurt Mohler1f044942021-04-22 23:33:03 -07001106
1107 # Create processes to test each combination of test cases and config settings
1108 processes = []
1109 for fn_name, arg_sizes in test_cases:
1110 for config, is_config_deterministic in test_configs:
1111 env = os.environ.copy()
1112 if config is None:
1113 if env.get(cublas_var_name) is not None:
1114 del env[cublas_var_name]
1115 else:
1116 env[cublas_var_name] = config
Shen Li10224432021-08-12 11:39:31 -07001117 should_throw_error = is_cuda10_2_or_higher and not is_config_deterministic
Kurt Mohler1f044942021-04-22 23:33:03 -07001118 script = f"""
1119import torch
1120torch.use_deterministic_algorithms(True)
1121fn = torch.{fn_name}
1122arg_sizes = {arg_sizes}
1123device = '{device}'
1124should_throw_error = {should_throw_error}
1125args = []
1126for arg_size in arg_sizes:
1127 args.append(torch.randn(*arg_size, device=device))
1128try:
1129 fn(*args)
1130except RuntimeError as e:
1131 if not should_throw_error:
1132 raise RuntimeError('Did not expect any error to be raised')
1133 elif 'Deterministic behavior was enabled with either' not in str(e):
1134 raise RuntimeError('Expected a CuBLAS nondeterministic error, but got a different error')
1135else:
1136 if should_throw_error:
1137 raise RuntimeError('Expected a CuBLAS nondeterministic error, but it was not raised')
1138
1139"""
1140 try:
1141 subprocess.check_output(
Shen Li10224432021-08-12 11:39:31 -07001142 [sys.executable, '-c', script],
Kurt Mohler1f044942021-04-22 23:33:03 -07001143 stderr=subprocess.STDOUT,
1144 # On Windows, opening the subprocess with the default CWD makes `import torch`
1145 # fail, so just set CWD to this script's directory
1146 cwd=os.path.dirname(os.path.realpath(__file__)),
Shen Li10224432021-08-12 11:39:31 -07001147 env=env)
Kurt Mohler1f044942021-04-22 23:33:03 -07001148 except subprocess.CalledProcessError as e:
Shen Li10224432021-08-12 11:39:31 -07001149 self.fail(msg=(
1150 f'Subprocess exception while attempting to run {test_case_info(fn_name, config)}:\n'
1151 + e.output.decode("utf-8")))
Kurt Mohler1f044942021-04-22 23:33:03 -07001152
Mike Ruberrye0d829a2022-01-24 01:28:07 -08001153 # FIXME: update OpInfos to support "nondeterministic samples" and port these tests
1154 # to that architecture
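    # The test_nondeterministic_alert_* tests below share one pattern: run the
    # op's forward pass, then wrap only the step that reaches the nondeterministic
    # kernel in expectedAlertNondeterministic, which (as used here) expects a
    # deterministic-algorithms alert for the named kernel on the listed device
    # types (usually just CUDA) and a clean run on the other devices.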
Kulin Sethe011a8e2022-05-13 18:28:53 +00001155 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001156 def test_nondeterministic_alert_AvgPool3d(self, device):
1157 module = torch.nn.AvgPool3d(3)
1158 input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1159 res = module(input)
1160 grad = torch.ones_like(res)
1161
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001162 @expectedAlertNondeterministic('avg_pool3d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001163 def backward_func(slf, device):
1164 res.backward(grad)
1165
1166 backward_func(self, device)
1167
Kulin Sethe011a8e2022-05-13 18:28:53 +00001168 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001169 def test_nondeterministic_alert_AdaptiveAvgPool2d(self, device):
1170 module = torch.nn.AdaptiveAvgPool2d(3)
1171 input = torch.randn(2, 3, 3, requires_grad=True, device=device)
1172 res = module(input)
1173 grad = torch.ones_like(res)
1174
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001175 @expectedAlertNondeterministic('adaptive_avg_pool2d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001176 def backward_func(slf, device):
1177 res.backward(grad)
1178
1179 backward_func(self, device)
1180
Kulin Sethe011a8e2022-05-13 18:28:53 +00001181 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001182 def test_nondeterministic_alert_AdaptiveAvgPool3d(self, device):
1183 module = torch.nn.AdaptiveAvgPool3d(3)
1184 input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1185 res = module(input)
1186 grad = torch.ones_like(res)
1187
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001188 @expectedAlertNondeterministic('adaptive_avg_pool3d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001189 def backward_func(slf, device):
1190 res.backward(grad)
1191
1192 backward_func(self, device)
1193
Kulin Sethe011a8e2022-05-13 18:28:53 +00001194 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001195 def test_nondeterministic_alert_MaxPool3d(self, device):
1196 module = torch.nn.MaxPool3d(3)
1197 input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1198 res = module(input)
1199 grad = torch.ones_like(res)
1200
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001201 @expectedAlertNondeterministic('max_pool3d_with_indices_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001202 def backward_func(slf, device):
1203 res.backward(grad)
1204
1205 backward_func(self, device)
1206
Kulin Sethe011a8e2022-05-13 18:28:53 +00001207 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001208 def test_nondeterministic_alert_AdaptiveMaxPool2d(self, device):
1209 module = torch.nn.AdaptiveMaxPool2d(3)
1210 input = torch.randn(2, 3, 3, requires_grad=True, device=device)
1211 res = module(input)
1212 grad = torch.ones_like(res)
1213
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001214 @expectedAlertNondeterministic('adaptive_max_pool2d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001215 def backward_func(slf, device):
1216 res.backward(grad)
1217
1218 backward_func(self, device)
1219
Kulin Sethe011a8e2022-05-13 18:28:53 +00001220 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001221 def test_nondeterministic_alert_FractionalMaxPool2d(self, device):
1222 module = torch.nn.FractionalMaxPool2d(2, output_ratio=0.5)
1223 input = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
1224 res = module(input)
1225 grad = torch.ones_like(res)
1226
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001227 @expectedAlertNondeterministic('fractional_max_pool2d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001228 def backward_func(slf, device):
1229 res.backward(grad)
1230
1231 backward_func(self, device)
1232
Kulin Sethe011a8e2022-05-13 18:28:53 +00001233 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001234 def test_nondeterministic_alert_FractionalMaxPool3d(self, device):
1235 module = torch.nn.FractionalMaxPool3d(2, output_ratio=0.5)
1236 input = torch.randn(2, 3, 3, 3, 3, requires_grad=True, device=device)
1237 res = module(input)
1238 grad = torch.ones_like(res)
1239
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001240 @expectedAlertNondeterministic('fractional_max_pool3d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001241 def backward_func(slf, device):
1242 res.backward(grad)
1243
1244 backward_func(self, device)
1245
Kulin Sethe011a8e2022-05-13 18:28:53 +00001246 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001247 def test_nondeterministic_alert_interpolate_linear(self, device):
1248 input = torch.randn(1, 2, 4, device=device, requires_grad=True)
1249 res = torch.nn.functional.interpolate(
Shen Li10224432021-08-12 11:39:31 -07001250 input,
1251 size=12,
1252 mode='linear',
1253 align_corners=False)
Kurt Mohler1f044942021-04-22 23:33:03 -07001254 grad = torch.ones_like(res)
1255
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001256 @expectedAlertNondeterministic('upsample_linear1d_backward_out_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001257 def backward_func(slf, device):
1258 res.backward(grad)
1259
1260 backward_func(self, device)
1261
1262 def test_nondeterministic_alert_interpolate_bilinear(self, device):
1263 input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
1264 res = torch.nn.functional.interpolate(
Shen Li10224432021-08-12 11:39:31 -07001265 input,
1266 size=12,
1267 mode='bilinear',
1268 align_corners=False)
Kurt Mohler1f044942021-04-22 23:33:03 -07001269 grad = torch.ones_like(res)
1270
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001271 @expectedAlertNondeterministic('upsample_bilinear2d_backward_out_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001272 def backward_func(slf, device):
1273 res.backward(grad)
1274
1275 backward_func(self, device)
1276
Kulin Sethe011a8e2022-05-13 18:28:53 +00001277 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001278 def test_nondeterministic_alert_interpolate_bicubic(self, device):
1279 input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
1280 res = torch.nn.functional.interpolate(
Shen Li10224432021-08-12 11:39:31 -07001281 input,
1282 size=12,
1283 mode='bicubic',
1284 align_corners=False)
Kurt Mohler1f044942021-04-22 23:33:03 -07001285 grad = torch.ones_like(res)
1286
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001287 @expectedAlertNondeterministic('upsample_bicubic2d_backward_out_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001288 def backward_func(slf, device):
1289 res.backward(grad)
1290
1291 backward_func(self, device)
1292
Kulin Sethe011a8e2022-05-13 18:28:53 +00001293 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001294 def test_nondeterministic_alert_interpolate_trilinear(self, device):
1295 input = torch.randn(1, 2, 4, 4, 4, device=device, requires_grad=True)
1296 res = torch.nn.functional.interpolate(
Shen Li10224432021-08-12 11:39:31 -07001297 input,
1298 size=12,
1299 mode='trilinear',
1300 align_corners=False)
Kurt Mohler1f044942021-04-22 23:33:03 -07001301 grad = torch.ones_like(res)
1302
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001303 @expectedAlertNondeterministic('upsample_trilinear3d_backward_out_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001304 def backward_func(slf, device):
1305 res.backward(grad)
1306
1307 backward_func(self, device)
1308
Kulin Sethe011a8e2022-05-13 18:28:53 +00001309 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001310 def test_nondeterministic_alert_ReflectionPad1d(self, device):
1311 module = torch.nn.ReflectionPad1d((1, 2))
1312 input = torch.randn(2, 3, 8, device=device, requires_grad=True)
1313 res = module(input)
1314 grad = torch.ones_like(res)
1315
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001316 @expectedAlertNondeterministic('reflection_pad1d_backward_out_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001317 def backward_func(slf, device):
1318 res.backward(grad)
1319
1320 backward_func(self, device)
1321
1322 def test_nondeterministic_alert_ReflectionPad2d(self, device):
1323 module = torch.nn.ReflectionPad2d((1, 2, 3, 4))
1324 input = torch.randn(2, 3, 8, 8, device=device, requires_grad=True)
1325 res = module(input)
1326 grad = torch.ones_like(res)
1327
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001328 @expectedAlertNondeterministic('reflection_pad2d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001329 def backward_func(slf, device):
1330 res.backward(grad)
1331
1332 backward_func(self, device)
1333
Kulin Sethe011a8e2022-05-13 18:28:53 +00001334 @skipIfMps
Thomas J. Fanc16f8792021-06-21 10:51:49 -07001335 def test_nondeterministic_alert_ReflectionPad3d(self, device):
1336 module = torch.nn.ReflectionPad3d((1, 2, 3, 4, 5, 6))
1337 input = torch.randn(2, 3, 8, 8, 8, device=device, requires_grad=True)
1338 res = module(input)
1339 grad = torch.ones_like(res)
1340
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001341 @expectedAlertNondeterministic('reflection_pad3d_backward_out_cuda', ['cuda'])
Thomas J. Fanc16f8792021-06-21 10:51:49 -07001342 def backward_func(slf, device):
1343 res.backward(grad)
1344
1345 backward_func(self, device)
1346
Kulin Sethe011a8e2022-05-13 18:28:53 +00001347 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001348 def test_nondeterministic_alert_ReplicationPad1d(self, device):
1349 module = torch.nn.ReplicationPad1d((1, 2))
1350 input = torch.randn(2, 3, 4, device=device, requires_grad=True)
1351 res = module(input)
1352 grad = torch.ones_like(res)
1353
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001354 @expectedAlertNondeterministic('replication_pad1d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001355 def backward_func(slf, device):
1356 res.backward(grad)
1357
1358 backward_func(self, device)
1359
1360 def test_nondeterministic_alert_ReplicationPad2d(self, device):
1361 module = torch.nn.ReplicationPad2d((1, 2, 3, 4))
1362 input = torch.randn(2, 3, 4, 4, device=device, requires_grad=True)
1363 res = module(input)
1364 grad = torch.ones_like(res)
1365
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001366 @expectedAlertNondeterministic('replication_pad2d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001367 def backward_func(slf, device):
1368 res.backward(grad)
1369
1370 backward_func(self, device)
1371
Kulin Sethe011a8e2022-05-13 18:28:53 +00001372 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001373 def test_nondeterministic_alert_ReplicationPad3d(self, device):
1374 module = torch.nn.ReplicationPad3d((1, 2, 3, 4, 5, 6))
1375 input = torch.randn(2, 3, 4, 4, 4, device=device, requires_grad=True)
1376 res = module(input)
1377 grad = torch.ones_like(res)
1378
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001379 @expectedAlertNondeterministic('replication_pad3d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001380 def backward_func(slf, device):
1381 res.backward(grad)
1382
1383 backward_func(self, device)
1384
1385 def test_nondeterministic_alert_NLLLoss(self, device):
1386 module = torch.nn.NLLLoss()
1387 input = torch.randn(2, 3, 5, 5, device=device)
1388 target = torch.rand(2, 5, 5, device=device).mul(3).floor().long()
1389
Kurt Mohlera2564892021-10-15 13:49:42 -07001390 @expectedAlertNondeterministic('nll_loss2d_forward_out_cuda_template', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001391 def forward_func(slf, device):
1392 module(input, target)
1393
1394 forward_func(self, device)
1395
1396 def test_nondeterministic_alert_CTCLoss(self, device):
1397 module = torch.nn.CTCLoss()
1398 input = torch.randn(50, 3, 15, device=device, requires_grad=True)
1399 target = torch.randint(0, 14, (3, 30), device=device)
1400 input_lengths = [50, 50, 50]
1401 target_lengths = [30, 25, 20]
1402 res = module(input, target, input_lengths, target_lengths)
1403 grad = torch.ones_like(res)
1404
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001405 @expectedAlertNondeterministic('ctc_loss_backward_gpu', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001406 def backward_func(slf, device):
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001407 res.backward(grad, retain_graph=True)
Kurt Mohler1f044942021-04-22 23:33:03 -07001408
1409 backward_func(self, device)
1410
1411 def test_nondeterministic_alert_EmbeddingBag_max(self, device):
1412 module = torch.nn.EmbeddingBag(
Shen Li10224432021-08-12 11:39:31 -07001413 4, 3, None, 2., False, 'max',
1414 _weight=torch.randn(4, 3, device=device, requires_grad=True))
Kurt Mohler1f044942021-04-22 23:33:03 -07001415 input = torch.randint(0, 3, (4, 3), device=device)
1416 res = module(input)
1417 grad = torch.ones_like(res)
1418
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001419 @expectedAlertNondeterministic('embedding_bag_backward_cuda_max', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001420 def backward_func(slf, device):
1421 res.backward(grad)
1422
1423 backward_func(self, device)
1424
1425 def test_nondeterministic_alert_scatter_add(self, device):
1426 def test_func(op_call):
Yu Guo74c12da2021-05-23 21:34:55 -07001427 input = torch.randn(5, 4, device=device)
Kurt Mohler1f044942021-04-22 23:33:03 -07001428 dim = 0
Yu Guo74c12da2021-05-23 21:34:55 -07001429 index = torch.tensor([[3]], device=device)
1430 src = torch.tensor([[1.0]], device=device)
Kurt Mohler1f044942021-04-22 23:33:03 -07001431
Kurt Mohlera2564892021-10-15 13:49:42 -07001432 @expectedAlertNondeterministic('scatter_add_cuda_kernel', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001433 def forward_func(slf, device):
1434 op_call(input, dim, index, src)
1435
1436 forward_func(self, device)
1437
1438 test_func(torch.Tensor.scatter_add_)
1439 test_func(torch.Tensor.scatter_add)
1440 test_func(torch.scatter_add)
1441
kshitij12345885a8e52021-11-01 09:21:20 -07001442 @expectedFailureMeta # expected a nondeterministic error, but it was not raised
1443 @onlyNativeDeviceTypes
Kurt Mohler1f044942021-04-22 23:33:03 -07001444 def test_nondeterministic_alert_put(self, device):
1445 def test_func(op_call):
1446 a = torch.randn(10, device=device)
1447 indices = torch.tensor([0, 0], device=device)
Shen Li10224432021-08-12 11:39:31 -07001448 values = torch.tensor([0., 1.], device=device)
Kurt Mohler1f044942021-04-22 23:33:03 -07001449
Shen Li10224432021-08-12 11:39:31 -07001450 @expectedAlertNondeterministic('put_')
Kurt Mohler1f044942021-04-22 23:33:03 -07001451 def forward_func(slf, device):
1452 op_call(a, indices, values, accumulate=False)
1453
1454 forward_func(self, device)
1455
1456 test_func(torch.Tensor.put)
1457 test_func(torch.Tensor.put_)
1458
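    # Unlike plain put_ above (flagged on every device type exercised here), the
    # accumulate=True variant below only expects the alert on CUDA, presumably
    # because the CPU accumulation path is deterministic.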
1459 def test_nondeterministic_alert_put_accumulate(self, device):
1460 def test_func(op_call):
1461 a = torch.randn(10, device=device)
1462 indices = torch.tensor([0, 0], device=device)
Shen Li10224432021-08-12 11:39:31 -07001463 values = torch.tensor([0., 1.], device=device)
Kurt Mohler1f044942021-04-22 23:33:03 -07001464
Kurt Mohlera2564892021-10-15 13:49:42 -07001465 @expectedAlertNondeterministic('put_', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001466 def forward_func(slf, device):
1467 op_call(a, indices, values, accumulate=True)
1468
1469 forward_func(self, device)
1470
1471 test_func(torch.Tensor.put)
1472 test_func(torch.Tensor.put_)
1473
Kulin Sethe011a8e2022-05-13 18:28:53 +00001474 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001475 def test_nondeterministic_alert_histc(self, device):
1476 def test_func(op_call):
1477 a = torch.tensor([], device=device)
1478
Kurt Mohlera2564892021-10-15 13:49:42 -07001479 @expectedAlertNondeterministic('_histc_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001480 def forward_func(slf, device):
1481 res = op_call(a, min=0, max=3)
1482
1483 forward_func(self, device)
1484
1485 test_func(torch.histc)
1486 test_func(torch.Tensor.histc)
1487
Kulin Sethe011a8e2022-05-13 18:28:53 +00001488 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001489 def test_nondeterministic_alert_bincount(self, device):
1490 def test_func(op_call):
1491 a = torch.tensor([], device=device, dtype=torch.long)
1492
Kurt Mohlera2564892021-10-15 13:49:42 -07001493 @expectedAlertNondeterministic('_bincount_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001494 def forward_func(slf, device):
1495 res = op_call(a)
1496
1497 forward_func(self, device)
1498
1499 test_func(torch.bincount)
1500 test_func(torch.Tensor.bincount)
Kurt Mohler5a45b1b2021-04-13 14:23:54 -07001501
Kurt Mohler2cb92042020-12-03 10:55:52 -08001502 # Ensures that kthvalue throws nondeterministic alerts in the correct cases
1503 @dtypes(torch.double)
Kurt Mohler1f044942021-04-22 23:33:03 -07001504 def test_nondeterministic_alert_kthvalue(self, device, dtype):
Kurt Mohlera2564892021-10-15 13:49:42 -07001505 @expectedAlertNondeterministic('kthvalue CUDA', ['cuda'])
Kurt Mohler2cb92042020-12-03 10:55:52 -08001506 def test_func(slf, device, call_type):
1507 S = 10
1508 k = 5
1509 a = torch.randn(S, device=device)
Shen Li10224432021-08-12 11:39:31 -07001510 if call_type == 'function':
Kurt Mohler2cb92042020-12-03 10:55:52 -08001511 torch.kthvalue(a, k)
Shen Li10224432021-08-12 11:39:31 -07001512 elif call_type == 'method':
Kurt Mohler2cb92042020-12-03 10:55:52 -08001513 a.kthvalue(k)
Shen Li10224432021-08-12 11:39:31 -07001514 elif call_type == 'out':
Kurt Mohler2cb92042020-12-03 10:55:52 -08001515 values = torch.empty_like(a)
1516 indices = torch.empty((), device=device, dtype=torch.long)
1517 torch.kthvalue(a, k, out=(values, indices))
1518 else:
1519 self.fail(f"'{call_type}' is not a valid call type")
1520
Shen Li10224432021-08-12 11:39:31 -07001521 test_func(self, device, 'function')
1522 test_func(self, device, 'method')
1523 test_func(self, device, 'out')
Kurt Mohler2cb92042020-12-03 10:55:52 -08001524
kshitij12345885a8e52021-11-01 09:21:20 -07001525 @onlyNativeDeviceTypes
Kurt Mohler1f044942021-04-22 23:33:03 -07001526 def test_nondeterministic_alert_gather(self, device):
1527 def test_func(op_call):
1528 a = torch.randn(3, 3, device=device, requires_grad=True)
1529 dim = 0
1530 index = torch.tensor([[0]], device=device)
1531 res = op_call(a, dim, index)
1532 grad = torch.ones_like(res)
1533
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001534 @expectedAlertNondeterministic('scatter_add_cuda_kernel', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001535 def backward_func(slf, device):
1536 res.backward(grad)
1537
1538 backward_func(self, device)
1539
1540 test_func(torch.gather)
1541 test_func(torch.Tensor.gather)
1542
Kulin Sethe011a8e2022-05-13 18:28:53 +00001543 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001544 def test_nondeterministic_alert_grid_sample_2d(self, device):
1545 input = torch.empty(1, 1, 2, 2, device=device, requires_grad=True)
1546 grid = torch.empty(1, 1, 1, 2, device=device)
1547 res = torch.nn.functional.grid_sample(input, grid, align_corners=False)
1548 grad = torch.ones_like(res)
1549
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001550 @expectedAlertNondeterministic('grid_sampler_2d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001551 def backward_func(slf, device):
1552 res.backward(grad)
1553
1554 backward_func(self, device)
1555
Kulin Sethe011a8e2022-05-13 18:28:53 +00001556 @skipIfMps
Kurt Mohler1f044942021-04-22 23:33:03 -07001557 def test_nondeterministic_alert_grid_sample_3d(self, device):
1558 input = torch.empty(1, 1, 2, 2, 2, device=device, requires_grad=True)
1559 grid = torch.empty(1, 1, 1, 2, 3, device=device)
1560 res = torch.nn.functional.grid_sample(input, grid, align_corners=False)
1561 grad = torch.ones_like(res)
1562
Kurt Mohler94f4e9a2021-10-21 12:50:21 -07001563 @expectedAlertNondeterministic('grid_sampler_3d_backward_cuda', ['cuda'])
Kurt Mohler1f044942021-04-22 23:33:03 -07001564 def backward_func(slf, device):
1565 res.backward(grad)
1566
1567 backward_func(self, device)
1568
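    # Every grid_sampler entry point (the generic dispatcher, the 2d/3d kernels,
    # the 2d CPU fallback and, where available, the cuDNN path) is expected to
    # reject inputs whose batch size disagrees with the grid's.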
Nikita Karetnikov936a6502022-04-04 15:21:44 +00001569 def test_invalid_shapes_grid_sampler(self, device):
1570 make_arg = partial(
1571 make_tensor, device=device, dtype=torch.float64, requires_grad=True)
1572
1573 inputs = (
1574 # input, grid
1575 ((5, 5, 5, 5, 5,), (1, 1, 1, 4, 4,)), # 3d
1576 ((5, 5, 5, 5,), (1, 1, 4, 4,)), # 2d
1577 )
1578
1579 interpolation_mode = 0
1580 padding_mode = 0
1581 align_corners = True
1582
1583 err = "expected grid and input to have same batch size"
1584
1585 for input, grid in inputs:
1586 input = make_arg(input)
1587 grid = make_arg(grid, low=-1, high=1)
1588
1589 # Wrapper for the 2d, 3d, and cuDNN functions listed below.
1590 with self.assertRaisesRegex(RuntimeError, err):
1591 torch.grid_sampler(
1592 input, grid, interpolation_mode, padding_mode,
1593 align_corners)
1594
1595 # Expects 2d input.
1596 with self.assertRaisesRegex(RuntimeError, err):
1597 torch.grid_sampler_2d(
1598 input, grid, interpolation_mode, padding_mode,
1599 align_corners)
1600
1601 # Expects 3d input.
1602 with self.assertRaisesRegex(RuntimeError, err):
1603 torch.grid_sampler_3d(
1604 input, grid, interpolation_mode, padding_mode,
1605 align_corners)
1606
1607 # Expects 2d input.
1608 with self.assertRaisesRegex(RuntimeError, err):
1609 torch._grid_sampler_2d_cpu_fallback(
1610 input, grid, interpolation_mode, padding_mode,
1611 align_corners)
1612
1613 # Expects 2d input, on CUDA.
1614 # Doesn't work on CPU and ROCm.
1615 if device != 'cpu' and TEST_CUDNN and not TEST_WITH_ROCM:
1616 with self.assertRaisesRegex(RuntimeError, err):
1617 torch.cudnn_grid_sampler(input, grid)
1618
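    # torch.dist(x, y, p) should agree with torch.norm(x - y, p) for a range of
    # p values, including 0 and +/-inf, and for inputs whose difference is
    # (almost) all zeros.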
Mike Ruberryb4b8f532019-09-14 17:09:04 -07001619 def test_dist(self, device):
1620 def run_test(x, y):
1621 for p in [0, 1, 2, 3, 4, inf, -inf]:
1622 dist_xy = torch.dist(x, y, p)
1623 dist_xy_norm = torch.norm(x - y, p)
1624 self.assertEqual(dist_xy, dist_xy_norm)
1625
1626 run_test(torch.randn(5, device=device), torch.randn(5, device=device))
1627
1628 x = torch.zeros(3, device=device)
1629 y = torch.zeros(3, device=device)
Shen Li10224432021-08-12 11:39:31 -07001630 y[1] = 1.
Mike Ruberryb4b8f532019-09-14 17:09:04 -07001631 run_test(x, y)
1632
Kurt Mohler2cb92042020-12-03 10:55:52 -08001633 # Ensures that median throws nondeterministic alerts in the correct cases
1634 @dtypes(torch.double)
Kurt Mohler1f044942021-04-22 23:33:03 -07001635 def test_nondeterministic_alert_median(self, device, dtype):
Kurt Mohler2cb92042020-12-03 10:55:52 -08001636 def test_func(slf, device, call_type):
1637 S = 10
1638 a = torch.randn(S, device=device)
Shen Li10224432021-08-12 11:39:31 -07001639 if call_type == 'function':
Kurt Mohler2cb92042020-12-03 10:55:52 -08001640 torch.median(a)
Shen Li10224432021-08-12 11:39:31 -07001641 elif call_type == 'function with indices':
Kurt Mohler2cb92042020-12-03 10:55:52 -08001642 torch.median(a, 0)
Shen Li10224432021-08-12 11:39:31 -07001643 elif call_type == 'method':
Kurt Mohler2cb92042020-12-03 10:55:52 -08001644 a.median()
Shen Li10224432021-08-12 11:39:31 -07001645 elif call_type == 'method with indices':
Kurt Mohler2cb92042020-12-03 10:55:52 -08001646 a.median(0)
Shen Li10224432021-08-12 11:39:31 -07001647 elif call_type == 'out with indices':
Kurt Mohler2cb92042020-12-03 10:55:52 -08001648 result = torch.empty_like(a)
1649 indices = torch.empty((), dtype=torch.long, device=device)
1650 torch.median(a, 0, out=(result, indices))
1651 else:
1652 self.fail(f"'{call_type}' is not a valid call type")
1653
Kurt Mohlera2564892021-10-15 13:49:42 -07001654 @expectedAlertNondeterministic('median CUDA with indices output', ['cuda'])
Kurt Mohler2cb92042020-12-03 10:55:52 -08001655 def test_func_expect_error(slf, device, call_type):
1656 test_func(slf, device, call_type)
1657
Shen Li10224432021-08-12 11:39:31 -07001658 test_func(self, device, 'function')
1659 test_func_expect_error(self, device, 'function with indices')
1660 test_func(self, device, 'method')
1661 test_func_expect_error(self, device, 'method with indices')
1662 test_func_expect_error(self, device, 'out with indices')
Kurt Mohler2cb92042020-12-03 10:55:52 -08001663
Mike Ruberrye0d829a2022-01-24 01:28:07 -08001664 # FIXME: move to test_scatter_gather_ops
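    # Shared helper: gather from a 1-D tensor with many repeated indices and
    # backpropagate a random weight. On CUDA the gradient is recomputed twice and
    # required to match bitwise; otherwise it is checked against an explicit
    # index-by-index accumulation.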
Shen Li10224432021-08-12 11:39:31 -07001665 def _test_gather_backward_one_dim(self, device, deterministic: bool = False) -> None:
Yu Guo8596ac12021-04-13 15:16:33 -07001666 with DeterministicGuard(deterministic):
1667 m = random.randint(2000, 3000)
1668 elems = random.randint(10 * m, 20 * m)
1669 dim = 0
1670 src = torch.randn(m, device=device, requires_grad=True)
1671 idx = torch.randint(m, (elems,), device=device)
1672 res = torch.gather(src, dim, idx)
1673 weight = torch.rand_like(res, device=device) * 10 ** 6
1674 res.backward(weight)
Peter Bell7843a5e2022-06-09 16:31:59 +01001675 assert src.grad is not None
Yu Guo8596ac12021-04-13 15:16:33 -07001676 grad = src.grad.detach().clone()
1677
Shen Li10224432021-08-12 11:39:31 -07001678 if torch.device(device).type == 'cuda':
Yu Guo8596ac12021-04-13 15:16:33 -07001679 for _ in range(2):
1680 src.grad.data.zero_()
1681 res = torch.gather(src, dim, idx)
1682 res.backward(weight)
1683 self.assertEqual(src.grad, grad, atol=0, rtol=0)
1684 else:
1685 expected = torch.zeros_like(src, device=device)
1686 for i in range(elems):
1687 expected[idx[i]] += weight[i]
1688 self.assertEqual(grad, expected, atol=0, rtol=0)
1689
Mike Ruberrye0d829a2022-01-24 01:28:07 -08001690 # FIXME: move to test_scatter_gather_ops
kshitij12345885a8e52021-11-01 09:21:20 -07001691 @onlyNativeDeviceTypes
Yu Guo8596ac12021-04-13 15:16:33 -07001692 def test_gather_backward_deterministic_path(self, device) -> None:
1693 self._test_gather_backward_one_dim(device, True)
1694
Mike Ruberrye0d829a2022-01-24 01:28:07 -08001695 # FIXME: move to test_scatter_gather_ops
Yu Guo8596ac12021-04-13 15:16:33 -07001696 @onlyCPU
1697 def test_gather_backward_one_dim(self, device) -> None:
1698 self._test_gather_backward_one_dim(device, False)
1699
Mike Ruberrye0d829a2022-01-24 01:28:07 -08001700 # FIXME: move to test_scatter_gather_ops
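    # With deterministic algorithms enabled, scatter_add along dim 0 with heavily
    # repeated indices should match an explicit serial accumulation exactly
    # (atol=0, rtol=0).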
kshitij12345885a8e52021-11-01 09:21:20 -07001701 @onlyNativeDeviceTypes
Yu Guo74c12da2021-05-23 21:34:55 -07001702 def test_scatter_add_one_dim_deterministic(self, device) -> None:
1703 with DeterministicGuard(True):
1704 m = random.randint(20, 30)
1705 elems = random.randint(2000 * m, 3000 * m)
1706 dim = 0
1707 src = torch.randn(elems, device=device)
1708 idx = torch.randint(m, (elems,), device=device)
1709
1710 x = torch.zeros(m, device=device)
1711 res = x.scatter_add(dim, idx, src)
1712
1713 expected = torch.zeros(m, device=device)
1714 for i in range(elems):
1715 expected[idx[i]] += src[i]
1716
1717 self.assertEqual(res, expected, atol=0, rtol=0)
1718
Mike Ruberrye0d829a2022-01-24 01:28:07 -08001719 # FIXME: move to test_scatter_gather_ops
Emilio Castillo8dfff8b2022-01-07 09:16:42 -08001720 @onlyNativeDeviceTypes
1721 def test_scatter_zero_size_index(self, device) -> None:
1722 null_index = torch.zeros((0, 4), dtype=torch.int64)
1723 null_arr = torch.zeros((0, 4))
1724 original = torch.arange(4, dtype=torch.float32)
1725 result = original.scatter(0, null_index, null_arr)
1726 self.assertEqual(result, original, atol=0, rtol=0)
1727
Natalia Gimelsheind7836172021-07-30 09:10:47 -07001728 @onlyCUDA
1729 def test_sync_warning(self, device):
Shen Li10224432021-08-12 11:39:31 -07001730
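        # CudaSyncGuard(level) scopes the CUDA sync debug mode (presumably via
        # torch.cuda.set_sync_debug_mode): level 1 should warn on calls that force
        # a host/device synchronization and level 2 should raise. Ops that need a
        # round trip to the host (truth value of a CUDA tensor, nonzero, indexing
        # with CPU indices, ...) land in expect_sync below; the expect_no_sync ops
        # can stay fully asynchronous.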
Natalia Gimelsheind7836172021-07-30 09:10:47 -07001731 def _sync_raises_helper(f, level):
1732 with CudaSyncGuard(level):
1733 if level == 1:
1734 with self.assertWarnsRegex(UserWarning, "called a synchronizing "):
1735 f()
1736 elif level == 2:
Shen Li10224432021-08-12 11:39:31 -07001737 with self.assertRaisesRegex(RuntimeError, "called a synchronizing "):
Natalia Gimelsheind7836172021-07-30 09:10:47 -07001738 f()
1739
1740 def _no_sync_helper(f, level):
1741 with CudaSyncGuard(level):
1742 f()
1743
1744 def _ind_put_fn(x, ind, val):
1745 x[ind] = val
1746 return x
1747
1748 def _ind_get_fn(x, ind):
1749 return x[ind]
1750
1751 def _cond_fn(x):
1752 if x: # taking boolean value of a tensor synchronizes
1753 return x
1754 else:
1755 return 2 * x
1756
1757 # prepare inputs for subsequent ops
1758 size = 4
1759 x = torch.rand(size, device=device)
1760 y = torch.rand((), device=device)
1761 ind = torch.randint(size, (3,), device=device)
1762 ind_cpu = ind.cpu()
1763 repeats = torch.full((1,), 2, device=device)
1764 mask = torch.randint(2, (size,), device=device, dtype=bool)
Shen Li10224432021-08-12 11:39:31 -07001765 expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.),
1766 lambda: _ind_put_fn(x, ind, y),
1767 lambda: _ind_get_fn(x, ind),
1768 lambda: torch.nn.functional.one_hot(ind, num_classes=size),
1769 lambda: torch.randperm(20000, device=device),
1770 lambda: torch.repeat_interleave(x, 2, output_size=2 * size),
1771 lambda: torch.repeat_interleave(x, repeats, output_size=2 * size))
1772 expect_sync = (lambda: _ind_put_fn(x, mask, y),
1773 lambda: _ind_put_fn(x, ind_cpu, y),
1774 lambda: _ind_get_fn(x, mask),
1775 lambda: _ind_get_fn(x, ind_cpu),
1776 lambda: x.nonzero(),
1777 lambda: _cond_fn(y),
1778 lambda: torch.nn.functional.one_hot(ind),
1779 lambda: torch.repeat_interleave(x, 2),
1780 lambda: torch.repeat_interleave(x, repeats))
Natalia Gimelsheind7836172021-07-30 09:10:47 -07001781 for f, level in product(expect_no_sync, (1, 2)):
1782 _no_sync_helper(f, level)
1783 for f, level in product(expect_sync, (1, 2)):
1784 _sync_raises_helper(f, level)
1785
Shen Li10224432021-08-12 11:39:31 -07001786
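    # The sampler tests below (log_normal_, geometric_, bernoulli, exponential_)
    # are mostly dtype/shape smoke tests; distributional correctness is exercised
    # by the *_kstest tests further down.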
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001787 @dtypes(*floating_types_and(torch.half, torch.bfloat16))
Kulin Sethe011a8e2022-05-13 18:28:53 +00001788 @skipIfMps
Pavel Belevichce6077d2020-04-29 08:03:04 -07001789 def test_log_normal(self, device, dtype):
1790 a = torch.tensor([10], dtype=dtype, device=device).log_normal_()
1791 self.assertEqual(a.dtype, dtype)
1792 self.assertEqual(a.size(), torch.Size([1]))
1793
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001794 @dtypes(*all_types_and(torch.half, torch.bfloat16))
Kulin Sethe011a8e2022-05-13 18:28:53 +00001795 @skipIfMps
Pavel Belevich06168bf2020-04-29 08:03:04 -07001796 def test_geometric(self, device, dtype):
1797 a = torch.tensor([10], dtype=dtype, device=device).geometric_(0.5)
1798 self.assertEqual(a.dtype, dtype)
1799 self.assertEqual(a.size(), torch.Size([1]))
1800
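    # Passing output_size= should give the same result as letting
    # repeat_interleave compute it; as exercised in test_sync_warning above,
    # providing it also lets the CUDA path skip a device synchronization.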
Kulin Sethe011a8e2022-05-13 18:28:53 +00001801 @skipIfMps
Brandon Lind806b062021-04-19 09:05:56 -07001802 def test_repeat_interleave(self, device):
1803 y = torch.tensor([[1, 2], [3, 4]], device=device)
Serhat Yilmazb4f3a982021-05-25 00:30:01 -07001804 # exercise single argument function signature
1805 temp = y.repeat_interleave(2)
1806 self.assertEqual(torch.Size([8]), temp.size())
1807
Brandon Lind806b062021-04-19 09:05:56 -07001808 for dtype in [torch.int, torch.long]:
Serhat Yilmaz4ca46402021-05-22 20:52:26 -07001809 lengths = torch.tensor([1, 2], dtype=dtype, device=device)
1810 output_size = torch.sum(lengths)
Brandon Lind806b062021-04-19 09:05:56 -07001811 a = torch.repeat_interleave(
1812 y,
Serhat Yilmaz4ca46402021-05-22 20:52:26 -07001813 lengths,
Brandon Lind806b062021-04-19 09:05:56 -07001814 dim=0,
1815 )
1816 self.assertEqual(a.dtype, y.dtype)
1817 self.assertEqual(a.size(), torch.Size([3, 2]))
1818
Serhat Yilmaz4ca46402021-05-22 20:52:26 -07001819 a_with_output = torch.repeat_interleave(
1820 y,
1821 lengths,
1822 dim=0,
1823 output_size=output_size,
1824 )
1825 self.assertEqual(a_with_output.dtype, y.dtype)
1826 self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
1827
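    # bernoulli() with p fixed at 0 or 1 must reproduce p exactly; for random p
    # the draws only need to be binary, which isBinary() checks.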
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001828 @dtypes(*floating_types())
1829 @dtypesIfCPU(*floating_types_and(torch.bfloat16))
1830 @dtypesIfCUDA(*floating_types_and(torch.half))
xueht-fnstfaf0a3b2020-06-07 07:16:27 -07001831 def test_bernoulli_p(self, device, dtype):
1832 for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]):
1833 x = torch.tensor(trivial_p, dtype=dtype, device=device)
1834 self.assertEqual(x.bernoulli().tolist(), trivial_p)
1835
1836 def isBinary(t):
1837 return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0
1838
1839 p = torch.rand(5, 5, dtype=dtype, device=device)
1840 self.assertTrue(isBinary(p.bernoulli()))
1841
1842 p = torch.rand(5, dtype=dtype, device=device).expand(5, 5)
1843 self.assertTrue(isBinary(p.bernoulli()))
1844
1845 p = torch.rand(5, 5, dtype=dtype, device=device)
1846 torch.bernoulli(torch.rand_like(p), out=p)
1847 self.assertTrue(isBinary(p))
1848
xueht-fnstfaf0a3b2020-06-07 07:16:27 -07001849 # RngUniform is not implemented for integral types in the XLA test
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001850 @dtypes(*floating_types())
1851 @dtypesIfCPU(*all_types_and(torch.bool))
1852 @dtypesIfCUDA(*all_types_and(torch.bool, torch.half))
xueht-fnstfaf0a3b2020-06-07 07:16:27 -07001853 def test_bernoulli_self(self, device, dtype):
Shen Li10224432021-08-12 11:39:31 -07001854
xueht-fnstfaf0a3b2020-06-07 07:16:27 -07001855 def isBinary(t):
1856 return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0
1857
1858 t = torch.empty(10, 10, dtype=dtype, device=device)
1859
1860 t.fill_(2)
1861 t.bernoulli_(0.5)
1862 self.assertTrue(isBinary(t))
1863
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001864 for p_dtype in floating_types_and(*[torch.half] if device.startswith('cuda') else []):
Xiong Wei51e341d2020-06-15 14:09:57 -07001865 p = torch.rand(10, dtype=p_dtype, device=device).expand(10, 10)
1866 t.fill_(2)
1867 t.bernoulli_(p)
1868 self.assertTrue(isBinary(t))
xueht-fnstfaf0a3b2020-06-07 07:16:27 -07001869
Xiong Wei51e341d2020-06-15 14:09:57 -07001870 t.fill_(2)
1871 torch.bernoulli(torch.rand_like(t, dtype=p_dtype), out=t)
1872 self.assertTrue(isBinary(t))
xueht-fnstfaf0a3b2020-06-07 07:16:27 -07001873
Xiong Wei51e341d2020-06-15 14:09:57 -07001874 t.fill_(2)
1875 t.bernoulli_(torch.rand_like(t, dtype=p_dtype))
1876 self.assertTrue(isBinary(t))
xueht-fnstfaf0a3b2020-06-07 07:16:27 -07001877
1878 @slowTest
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001879 @dtypes(*floating_types())
1880 @dtypesIfCUDA(*floating_types_and(torch.half))
xueht-fnstfaf0a3b2020-06-07 07:16:27 -07001881 def test_bernoulli_edge_cases(self, device, dtype):
1882 # Need to draw a lot of samples to cover every random floating point number.
Shen Li10224432021-08-12 11:39:31 -07001883 a = torch.zeros(10000, 10000, dtype=dtype, device=device) # probability of drawing "1" is 0
xueht-fnstfaf0a3b2020-06-07 07:16:27 -07001884 num_ones = (torch.bernoulli(a) == 1).sum()
1885 self.assertEqual(num_ones, 0)
1886
Shen Li10224432021-08-12 11:39:31 -07001887 b = torch.ones(10000, 10000, dtype=dtype, device=device) # probability of drawing "1" is 1
xueht-fnstfaf0a3b2020-06-07 07:16:27 -07001888 num_zeros = (torch.bernoulli(b) == 0).sum()
1889 self.assertEqual(num_zeros, 0)
1890
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001891 @dtypes(*floating_types_and(torch.half, torch.bfloat16))
Kulin Sethe011a8e2022-05-13 18:28:53 +00001892 @skipIfMps
Pavel Belevichec8517b2020-04-29 08:03:04 -07001893 def test_exponential(self, device, dtype):
1894 a = torch.tensor([10], dtype=dtype, device=device).exponential_(0.5)
1895 self.assertEqual(a.dtype, dtype)
1896 self.assertEqual(a.size(), torch.Size([1]))
1897
1898 # Tests extremal behavior
Shen Li10224432021-08-12 11:39:31 -07001899 tests = ((-0, float('inf')), (0, float('inf')), (float('inf'), 0))
Pavel Belevichec8517b2020-04-29 08:03:04 -07001900 for test in tests:
1901 t = torch.empty((1,), device=device, dtype=dtype).exponential_(test[0])
1902 self.assertTrue(t.item() == test[1])
1903
1904 # Tests that negative lambda fails
1905 with self.assertRaises(RuntimeError):
1906 torch.empty((1,), device=device, dtype=dtype).exponential_(-0.5)
1907
Natalia Gimelshein6aa51482021-03-10 00:33:06 -08001908 @onlyCUDA
kshitij12345d5d20962021-11-18 08:25:47 -08001909 @dtypes(torch.half, torch.float)
Natalia Gimelshein6aa51482021-03-10 00:33:06 -08001910 def test_exponential_no_zero(self, device, dtype):
1911 # naively, 0 in exponential can be generated with probability 2^-24
1912 # so we need more samples to check if it's not generated
1913 # instead of doing one
1914 # don't test CPU, that would be a long test
1915 x = torch.empty(50000000, device=device, dtype=dtype).exponential_()
1916 self.assertTrue(x.min() > 0)
1917
Heitor Schuerofff32f85e2021-06-30 12:29:55 -07001918 def _generate_correlation_tensors(self, device, dtype):
Philip Meier0973c5a2022-02-24 21:47:38 -08001919 yield make_tensor((0, 0), dtype=dtype, device=device)
1920 yield make_tensor((1, 0), dtype=dtype, device=device)
1921 yield make_tensor((0, 1), dtype=dtype, device=device)
1922 yield make_tensor((2,), dtype=dtype, device=device)
1923 yield make_tensor((2, 1), dtype=dtype, device=device)
1924 yield make_tensor((2, 2), dtype=dtype, device=device)
1925 yield make_tensor((2, 3), dtype=dtype, device=device)
1926 yield make_tensor((5, 10), dtype=dtype, device=device)
1927 yield make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
Heitor Schuerofff32f85e2021-06-30 12:29:55 -07001928 if dtype != torch.int:
1929 yield torch.tensor([0, -2, nan, 10.2, inf], dtype=dtype, device=device)
1930
kshitij12345885a8e52021-11-01 09:21:20 -07001931 @onlyNativeDeviceTypes
Heitor Schuerofff32f85e2021-06-30 12:29:55 -07001932 @dtypes(torch.int, torch.float, torch.cfloat)
1933 def test_corrcoef(self, device, dtype):
1934 for x in self._generate_correlation_tensors(device, dtype):
1935 res = torch.corrcoef(x)
1936 ref = np.corrcoef(x.cpu().numpy())
1937 self.assertEqual(res, ref, exact_dtype=False)
1938
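    # torch.cov is compared against np.cov with matching correction (ddof),
    # fweights and aweights; exact_dtype=False since NumPy computes in double
    # (or complex double) precision.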
1939 @dtypes(torch.int, torch.float, torch.cfloat)
Heitor Schueroffec9c03c2021-06-29 13:59:46 -07001940 def test_cov(self, device, dtype):
1941 def check(t, correction=1, fweights=None, aweights=None):
Shen Li10224432021-08-12 11:39:31 -07001942 res = torch.cov(t, correction=correction, fweights=fweights, aweights=aweights)
Heitor Schueroffec9c03c2021-06-29 13:59:46 -07001943 t = t.cpu().numpy()
1944 fweights = fweights.cpu().numpy() if fweights is not None else None
1945 aweights = aweights.cpu().numpy() if aweights is not None else None
Heitor Schuerofff32f85e2021-06-30 12:29:55 -07001946 ref = np.cov(t, ddof=correction, fweights=fweights, aweights=aweights)
1947 self.assertEqual(res, ref, atol=1e-05, rtol=1e-05, exact_dtype=False)
Heitor Schueroffec9c03c2021-06-29 13:59:46 -07001948
Heitor Schuerofff32f85e2021-06-30 12:29:55 -07001949 for x in self._generate_correlation_tensors(device, dtype):
1950 check(x)
1951 num_observations = x.numel() if x.ndim < 2 else x.size(1)
Heitor Schueroffec9c03c2021-06-29 13:59:46 -07001952 if num_observations > 0:
1953 fweights = torch.randint(1, 10, (num_observations,), device=device)
Philip Meier0973c5a2022-02-24 21:47:38 -08001954 aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=1)
Shen Li10224432021-08-12 11:39:31 -07001955 for correction, fw, aw in product([0, 1, 2], [None, fweights], [None, aweights]):
Heitor Schuerofff32f85e2021-06-30 12:29:55 -07001956 check(x, correction, fweights, aweights)
Heitor Schueroffec9c03c2021-06-29 13:59:46 -07001957
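    # The *_kstest tests below draw 1000 samples and run a SciPy
    # Kolmogorov-Smirnov (or, for the discrete geometric case, chi-square)
    # goodness-of-fit check, using a loose threshold on the statistic as a cheap
    # sanity check rather than a strict significance test.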
Pavel Belevich35beff02020-05-19 10:18:34 -07001958 @skipIfNoSciPy
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001959 @dtypes(*floating_types_and(torch.half, torch.bfloat16))
Pavel Belevich35beff02020-05-19 10:18:34 -07001960 def test_uniform_kstest(self, device, dtype):
Pavel Belevich35beff02020-05-19 10:18:34 -07001961 from scipy import stats
1962 size = 1000
1963 for from_ in [-42, 0, 4.2]:
1964 for to_ in [-4.2, 0, 42]:
1965 if to_ > from_:
Shen Li10224432021-08-12 11:39:31 -07001966 t = torch.empty(size, dtype=dtype, device=device).uniform_(from_, to_)
1967 res = stats.kstest(t.cpu().to(torch.double), 'uniform', args=(from_, (to_ - from_)))
Pavel Belevich35beff02020-05-19 10:18:34 -07001968 self.assertTrue(res.statistic < 0.1)
1969
1970 @skipIfNoSciPy
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001971 @dtypes(*floating_types_and(torch.half))
1972 @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
Pavel Belevich35beff02020-05-19 10:18:34 -07001973 def test_normal_kstest(self, device, dtype):
1974 from scipy import stats
1975 size = 1000
1976 for mean in [-10, 0, 50]:
1977 for std in [1, 5, 10]:
Shen Li10224432021-08-12 11:39:31 -07001978 t = torch.empty(size, dtype=dtype, device=device).normal_(mean=mean, std=std)
1979 res = stats.kstest(t.cpu().to(torch.double), 'norm', args=(mean, std))
Pavel Belevich35beff02020-05-19 10:18:34 -07001980 self.assertTrue(res.statistic < 0.1)
1981
Kulin Sethe011a8e2022-05-13 18:28:53 +00001982 @skipIfMps
Pavel Belevich35beff02020-05-19 10:18:34 -07001983 @skipIfNoSciPy
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001984 @dtypes(*floating_types_and(torch.half, torch.bfloat16))
Pavel Belevich35beff02020-05-19 10:18:34 -07001985 def test_lognormal_kstest(self, device, dtype):
1986 from scipy import stats
1987 size = 1000
1988 for mean in [-3, 0, 7]:
1989 for std in [1, 5, 7]:
Shen Li10224432021-08-12 11:39:31 -07001990 t = torch.empty(size, dtype=dtype, device=device).log_normal_(mean=mean, std=std)
1991 res = stats.kstest(t.cpu().to(torch.double), 'lognorm', args=(std, 0, math.exp(mean)))
Pavel Belevich35beff02020-05-19 10:18:34 -07001992 if dtype == torch.half:
1993 self.assertTrue(res.statistic < 0.3)
1994 else:
1995 self.assertTrue(res.statistic < 0.1)
1996
Kulin Sethe011a8e2022-05-13 18:28:53 +00001997 @skipIfMps
Pavel Belevich35beff02020-05-19 10:18:34 -07001998 @skipIfNoSciPy
Nikita Shulgabfac65d2022-03-30 14:13:21 -07001999 @dtypes(*floating_types_and(torch.half, torch.bfloat16))
Pavel Belevich35beff02020-05-19 10:18:34 -07002000 def test_exponential_kstest(self, device, dtype):
2001 from scipy import stats
2002 size = 1000
2003 for lambd in [0.5, 1.0, 5.0]:
2004 t = torch.empty(size, dtype=dtype, device=device).exponential_(lambd=lambd)
Shen Li10224432021-08-12 11:39:31 -07002005 res = stats.kstest(t.cpu().to(torch.double), 'expon', args=(0, 1 / lambd,))
Pavel Belevich35beff02020-05-19 10:18:34 -07002006 self.assertTrue(res.statistic < 0.1)
2007
Kulin Sethe011a8e2022-05-13 18:28:53 +00002008 @skipIfMps
Pavel Belevich35beff02020-05-19 10:18:34 -07002009 @skipIfNoSciPy
Nikita Shulgabfac65d2022-03-30 14:13:21 -07002010 @dtypes(*floating_types_and(torch.half, torch.bfloat16))
Pavel Belevich35beff02020-05-19 10:18:34 -07002011 def test_cauchy_kstest(self, device, dtype):
2012 from scipy import stats
2013 size = 1000
2014 for median in [-10, 0, 50]:
2015 for sigma in [0.5, 1.0, 10.0]:
Shen Li10224432021-08-12 11:39:31 -07002016 t = torch.empty(size, dtype=dtype, device=device).cauchy_(median=median, sigma=sigma)
2017 res = stats.kstest(t.cpu().to(torch.double), 'cauchy', args=(median, sigma))
Pavel Belevich35beff02020-05-19 10:18:34 -07002018 self.assertTrue(res.statistic < 0.1)
2019
kshitij12345956faea2021-06-28 12:45:58 -07002020 @slowTest
2021 @onlyCUDA
2022 @dtypes(torch.bfloat16, torch.float32)
2023 def test_cauchy_no_inf(self, device, dtype):
2024 # torch.float16 will have `inf` because of its smaller range.
Shen Li10224432021-08-12 11:39:31 -07002025 for _ in range((2**16) * 2):
2026 x = torch.empty((2**16), dtype=dtype, device=device)
kshitij12345956faea2021-06-28 12:45:58 -07002027 x.cauchy_()
2028 self.assertFalse(x.isinf().sum())
2029
Kulin Sethe011a8e2022-05-13 18:28:53 +00002030 @skipIfMps
Pavel Belevich35beff02020-05-19 10:18:34 -07002031 @skipIfNoSciPy
Nikita Shulgabfac65d2022-03-30 14:13:21 -07002032 @dtypes(*all_types_and(torch.half, torch.bfloat16))
Pavel Belevich35beff02020-05-19 10:18:34 -07002033 def test_geometric_kstest(self, device, dtype):
2034 from scipy import stats
2035 size = 1000
2036 for p in [0.2, 0.5, 0.8]:
2037 t = torch.empty(size, dtype=dtype, device=device).geometric_(p=p)
2038 actual = np.histogram(t.cpu().to(torch.double), np.arange(1, 100))[0]
2039 expected = stats.geom(p).pmf(np.arange(1, 99)) * size
2040 res = stats.chisquare(actual, expected)
Mike Ruberry13120bf2020-05-27 06:28:05 -07002041 self.assertEqual(res.pvalue, 1.0, atol=0.1, rtol=0)
Pavel Belevich35beff02020-05-19 10:18:34 -07002042
Mike Ruberrye0d829a2022-01-24 01:28:07 -08002043 # FIXME: find test suite for pdist and cdist
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002044 def test_pairwise_distance_empty(self, device):
2045 shape = (2, 0)
2046 x = torch.randn(shape, device=device)
2047 y = torch.randn(shape, device=device)
2048
2049 self.assertEqual(torch.zeros(2, device=device), torch.pairwise_distance(x, y))
Shen Li10224432021-08-12 11:39:31 -07002050 self.assertEqual(torch.zeros((2, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002051
2052 shape = (0, 2)
2053 x = torch.randn(shape, device=device)
2054 y = torch.randn(shape, device=device)
2055 self.assertEqual(torch.zeros(0, device=device), torch.pairwise_distance(x, y))
Shen Li10224432021-08-12 11:39:31 -07002056 self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002057
2058 def test_pdist_empty(self, device):
2059 shape = (0, 2)
2060 x = torch.randn(shape, device=device)
2061 self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
2062
2063 shape = (1, 2)
2064 x = torch.randn(shape, device=device)
2065 self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
2066
2067 shape = (3, 0)
2068 x = torch.randn(shape, device=device)
2069 self.assertEqual(torch.zeros(3, device=device), torch.pdist(x))
2070
2071 def test_cdist_empty(self, device):
2072 x = torch.randn((0, 5), device=device)
2073 y = torch.randn((4, 5), device=device)
2074 self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))
2075
2076 x = torch.randn((2, 5), device=device)
2077 y = torch.randn((0, 5), device=device)
2078 self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
2079
2080 x = torch.randn((2, 0), device=device)
2081 y = torch.randn((3, 0), device=device)
2082 self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y))
2083
2084 x = torch.randn((2, 0), device=device)
2085 y = torch.randn((0, 0), device=device)
2086 self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
2087
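    # Reference implementation for torch.cdist: broadcast the two row dimensions
    # and take the p-norm of the pairwise differences along the last dimension.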
ptrblck1e3664b2020-02-19 10:25:19 -08002088 def _brute_cdist(self, x, y, p=2):
2089 r1 = x.shape[-2]
2090 r2 = y.shape[-2]
2091 if r1 == 0 or r2 == 0:
2092 return torch.empty(r1, r2, device=x.device)
2093 return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
2094
Kulin Sethe011a8e2022-05-13 18:28:53 +00002095 @skipIfMps
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002096 def test_cdist_norm(self, device):
2097 for r1 in [3, 4, 5, 6]:
2098 for m in [2, 3, 4, 10]:
2099 for r2 in [4, 6, 7, 8]:
Shen Li10224432021-08-12 11:39:31 -07002100 for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002101 x = torch.randn(r1, m, device=device)
2102 y = torch.randn(r2, m, device=device)
Igor Fedan12dde7f2019-10-17 14:54:50 -07002103 if p == 2:
Shen Li10224432021-08-12 11:39:31 -07002104 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
Igor Fedan12dde7f2019-10-17 14:54:50 -07002105 actual = torch.cdist(x, y, p=2, compute_mode=cm)
ptrblck1e3664b2020-02-19 10:25:19 -08002106 expected = self._brute_cdist(x, y, p=2)
JackCaoG46447042020-06-03 12:34:46 -07002107 self.assertEqual(expected, actual, rtol=0, atol=0.02)
Igor Fedan12dde7f2019-10-17 14:54:50 -07002108 else:
2109 actual = torch.cdist(x, y, p=p)
ptrblck1e3664b2020-02-19 10:25:19 -08002110 expected = self._brute_cdist(x, y, p=p)
JackCaoG46447042020-06-03 12:34:46 -07002111 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002112
Kulin Sethe011a8e2022-05-13 18:28:53 +00002113 @skipIfMps
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002114 def test_cdist_norm_batch(self, device):
2115 for r1 in [3, 4, 5, 6]:
2116 for m in [2, 3, 4, 10]:
2117 for r2 in [4, 6, 7, 8]:
Shen Li10224432021-08-12 11:39:31 -07002118 for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002119 x = torch.randn(2, 3, 6, r1, m, device=device)
2120 y = torch.randn(2, 3, 6, r2, m, device=device)
Igor Fedan12dde7f2019-10-17 14:54:50 -07002121 if p == 2:
Shen Li10224432021-08-12 11:39:31 -07002122 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
Igor Fedan12dde7f2019-10-17 14:54:50 -07002123 actual = torch.cdist(x, y, p=2, compute_mode=cm)
ptrblck1e3664b2020-02-19 10:25:19 -08002124 expected = self._brute_cdist(x, y, p=2)
JackCaoG46447042020-06-03 12:34:46 -07002125 self.assertEqual(expected, actual, rtol=0, atol=0.02)
Igor Fedan12dde7f2019-10-17 14:54:50 -07002126 else:
2127 actual = torch.cdist(x, y, p=p)
ptrblck1e3664b2020-02-19 10:25:19 -08002128 expected = self._brute_cdist(x, y, p=p)
JackCaoG46447042020-06-03 12:34:46 -07002129 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002130
wanyu2018umac444203c2021-02-02 20:35:21 -08002131 @onlyCUDA
2132 def test_cdist_cuda_backward(self, device):
2133 for l1 in [1, 511, 513]:
2134 for l2 in [1, 511, 513]:
Shen Li10224432021-08-12 11:39:31 -07002135 for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
wanyu2018umac444203c2021-02-02 20:35:21 -08002136 x1 = torch.randn(4, l1, 32, device=device, requires_grad=True)
2137 x2 = x1.clone().detach_().requires_grad_()
2138 y1 = torch.randn(4, l2, 32, device=device, requires_grad=True)
2139 y2 = y1.clone().detach_().requires_grad_()
2140 if p == 2:
Shen Li10224432021-08-12 11:39:31 -07002141 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
wanyu2018umac444203c2021-02-02 20:35:21 -08002142 z1 = torch.cdist(x1, y1, p=2, compute_mode=cm).mean()
2143 z2 = self._brute_cdist(x2, y2, p=2).mean()
2144 z1.backward()
2145 z2.backward()
2146 self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
2147 self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
2148 else:
2149 z1 = torch.cdist(x1, y1, p=p).mean()
2150 z2 = self._brute_cdist(x2, y2, p=p).mean()
# backprop both versions before comparing gradients
z1.backward()
z2.backward()
2151 self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
2152 self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
2153
Xiang Gao23174ca2020-07-15 20:58:31 -07002154 @tf32_on_and_off(0.005)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002155 def test_cdist_large(self, device):
Shen Li10224432021-08-12 11:39:31 -07002156 for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
Igor Fedan12dde7f2019-10-17 14:54:50 -07002157 x = torch.randn(1000, 10, device=device)
2158 y = torch.randn(1000, 10, device=device)
2159 actual = torch.cdist(x, y, p=2, compute_mode=cm)
ptrblck1e3664b2020-02-19 10:25:19 -08002160 expected = self._brute_cdist(x, y, p=2)
JackCaoG46447042020-06-03 12:34:46 -07002161 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002162
Edward Yang74a06632020-03-03 11:44:37 -08002163 @slowTest
Xiang Gao23174ca2020-07-15 20:58:31 -07002164 @tf32_on_and_off(0.01)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002165 def test_cdist_large_batch(self, device):
Shen Li10224432021-08-12 11:39:31 -07002166 for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
Igor Fedan12dde7f2019-10-17 14:54:50 -07002167 x = torch.randn(4, 3, 1000, 10, device=device)
2168 y = torch.randn(4, 3, 1000, 10, device=device)
2169 actual = torch.cdist(x, y, p=2, compute_mode=cm)
ptrblck1e3664b2020-02-19 10:25:19 -08002170 expected = self._brute_cdist(x, y, p=2)
JackCaoG46447042020-06-03 12:34:46 -07002171 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002172
Xiang Gao23174ca2020-07-15 20:58:31 -07002173 @tf32_on_and_off(0.005)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002174 def test_cdist_non_contiguous(self, device):
Shen Li10224432021-08-12 11:39:31 -07002175 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
lezcano09742152021-10-18 13:00:48 -07002176 x = torch.randn(5, 7, device=device).mT
2177 y = torch.randn(5, 3, device=device).mT
Igor Fedan12dde7f2019-10-17 14:54:50 -07002178 actual = torch.cdist(x, y, p=2, compute_mode=cm)
ptrblck1e3664b2020-02-19 10:25:19 -08002179 expected = self._brute_cdist(x, y, p=2)
Igor Fedan12dde7f2019-10-17 14:54:50 -07002180 self.assertFalse(x.is_contiguous())
2181 self.assertFalse(y.is_contiguous())
JackCaoG46447042020-06-03 12:34:46 -07002182 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002183
Igor Fedan12dde7f2019-10-17 14:54:50 -07002184 x = torch.randn(7, 5, device=device)
2185 y = torch.randn(5, 3, device=device).t()
2186 actual = torch.cdist(x, y, p=2, compute_mode=cm)
ptrblck1e3664b2020-02-19 10:25:19 -08002187 expected = self._brute_cdist(x, y, p=2)
Igor Fedan12dde7f2019-10-17 14:54:50 -07002188 self.assertTrue(x.is_contiguous())
2189 self.assertFalse(y.is_contiguous())
JackCaoG46447042020-06-03 12:34:46 -07002190 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002191
Igor Fedan12dde7f2019-10-17 14:54:50 -07002192 x = torch.randn(5, 7, device=device).t()
2193 y = torch.randn(3, 5, device=device)
2194 actual = torch.cdist(x, y, p=2, compute_mode=cm)
ptrblck1e3664b2020-02-19 10:25:19 -08002195 expected = self._brute_cdist(x, y, p=2)
Igor Fedan12dde7f2019-10-17 14:54:50 -07002196 self.assertFalse(x.is_contiguous())
2197 self.assertTrue(y.is_contiguous())
JackCaoG46447042020-06-03 12:34:46 -07002198 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002199
Xiang Gao23174ca2020-07-15 20:58:31 -07002200 @tf32_on_and_off()
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002201 def test_cdist_non_contiguous_batch(self, device):
Shen Li10224432021-08-12 11:39:31 -07002202 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
lezcano09742152021-10-18 13:00:48 -07002203 x = torch.randn(4, 3, 2, 5, 7, device=device).mT
2204 y = torch.randn(4, 3, 2, 5, 3, device=device).mT
Igor Fedan12dde7f2019-10-17 14:54:50 -07002205 actual = torch.cdist(x, y, p=2, compute_mode=cm)
ptrblck1e3664b2020-02-19 10:25:19 -08002206 expected = self._brute_cdist(x, y, p=2)
Igor Fedan12dde7f2019-10-17 14:54:50 -07002207 self.assertFalse(x.is_contiguous())
2208 self.assertFalse(y.is_contiguous())
JackCaoG46447042020-06-03 12:34:46 -07002209 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002210
Igor Fedan12dde7f2019-10-17 14:54:50 -07002211 x = torch.randn(7, 2, 7, 5, device=device)
lezcano09742152021-10-18 13:00:48 -07002212 y = torch.randn(7, 2, 5, 3, device=device).mT
Igor Fedan12dde7f2019-10-17 14:54:50 -07002213 actual = torch.cdist(x, y, p=2, compute_mode=cm)
ptrblck1e3664b2020-02-19 10:25:19 -08002214 expected = self._brute_cdist(x, y, p=2)
Igor Fedan12dde7f2019-10-17 14:54:50 -07002215 self.assertTrue(x.is_contiguous())
2216 self.assertFalse(y.is_contiguous())
JackCaoG46447042020-06-03 12:34:46 -07002217 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002218
lezcano09742152021-10-18 13:00:48 -07002219 x = torch.randn(4, 5, 7, device=device).mT
Igor Fedan12dde7f2019-10-17 14:54:50 -07002220 y = torch.randn(4, 3, 5, device=device)
2221 actual = torch.cdist(x, y, p=2, compute_mode=cm)
ptrblck1e3664b2020-02-19 10:25:19 -08002222 expected = self._brute_cdist(x, y, p=2)
Igor Fedan12dde7f2019-10-17 14:54:50 -07002223 self.assertFalse(x.is_contiguous())
2224 self.assertTrue(y.is_contiguous())
JackCaoG46447042020-06-03 12:34:46 -07002225 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002226
soulitzer83e86122021-11-03 15:24:10 -07002227 # Maybe merge into OpInfo?
2228 def test_cdist_euclidean_large(self, device):
2229 def _test_euclidean_large_cdist(sizex, sizey=None):
2230 if sizey is None:
2231 sizey = sizex
2232 x = torch.randn(sizex, device=device, dtype=torch.float)
2233 y = torch.randn(sizey, device=device, dtype=torch.float)
2234 eps = 1e-6
2235            # shift x so each entry differs from the corresponding entry of y by at least eps (avoids the extremum at zero difference)

2236 x = x - (((x - y) < eps).float() * 2 * eps)
2237 x.requires_grad = True
2238 y.requires_grad = True
2239 dist = torch.cdist(x, y, p=2)
2240 # Do a backward pass to check that it is valid for large
2241 # matrices
2242 loss = dist.sum()
2243 loss.backward()
2244
2245 _test_euclidean_large_cdist((2000, 5))
2246
2247 # Ensure that cdist backward with p<1 does not produce NaNs
Kulin Sethe011a8e2022-05-13 18:28:53 +00002248 @skipIfMps
soulitzer83e86122021-11-03 15:24:10 -07002249 def test_cdist_grad_p_lt_1_no_nan(self, device):
2250 for p in [0.99, 0.7, 0.5, 0.1, 0.01]:
2251 x = torch.randn(1, 2, device=device)
2252 y = x.clone().detach() + torch.tensor([[1., 0.]], device=device)
2253 x.requires_grad = True
2254 y.requires_grad = True
2255 result = torch.cdist(x, y, p=p)
2256 result.backward(torch.ones_like(result))
2257 self.assertFalse(torch.isnan(x.grad).any())
2258 self.assertFalse(torch.isnan(y.grad).any())
2259
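    # Note on test_cdist_grad_p_lt_1_no_nan above (illustrative reasoning, not part
    # of the original test): the per-coordinate gradient of |x_i - y_i|**p involves
    # |x_i - y_i|**(p - 1), which diverges at zero when p < 1; with the inputs
    # chosen there one coordinate difference is exactly zero, so a naive backward
    # implementation would produce 0 * inf = NaN.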
2260 def test_cdist_same_inputs(self, device):
2261        # Test to detect issues in the cdist gradient calculation
2262        # when the distances are 0.
2263 sizex = (1, 27, 32)
2264 for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
2265 x = torch.randn(sizex, device=device, dtype=torch.float)
2266 dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
2267 y = x.clone()
2268 eps = 1e-6
2269 x.requires_grad = True
2270 d = torch.cdist(x, y)
2271 d.backward(dist_grad)
2272            # Check that the backward pass does not contain invalid
2273 # values such as nan or inf
2274 assert torch.isfinite(x.grad).all()
2275
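    # Note on test_cdist_same_inputs above (illustrative reasoning, not part of the
    # original test): when x == y every pairwise distance is zero, and the Euclidean
    # gradient contains a diff / dist term that is 0 / 0 there; a finite subgradient
    # is expected instead of NaN, which is what the isfinite check verifies.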
Kulin Sethe011a8e2022-05-13 18:28:53 +00002276 @skipIfMps
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002277 def test_cumsum(self, device):
2278 x = torch.rand(100, 100, device=device)
2279 res1 = torch.cumsum(x, 1)
Yukio Siraichi93bf0ae2021-04-11 15:43:54 -07002280 res2 = torch.tensor([]).to(device)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002281 torch.cumsum(x, 1, out=res2)
2282 self.assertEqual(res1, res2)
kiyosora008f8402020-11-19 11:17:25 -08002283 x.cumsum_(1)
2284 self.assertEqual(res1, x)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002285
Shen Li10224432021-08-12 11:39:31 -07002286 a = torch.tensor([[True, False, True],
2287 [False, False, False],
2288 [True, True, True]], device=device)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002289 b = a.byte()
2290 aRes = torch.cumsum(a, 0)
2291 bRes = torch.cumsum(b, 0)
2292 self.assertEqual(aRes, bRes)
Shen Li10224432021-08-12 11:39:31 -07002293 self.assertEqual(aRes, torch.tensor([[1, 0, 1],
2294 [1, 0, 1],
2295 [2, 1, 2]]))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002296
2297 aRes = torch.cumsum(a, 1)
2298 bRes = torch.cumsum(b, 1)
2299 self.assertEqual(aRes, bRes)
Shen Li10224432021-08-12 11:39:31 -07002300 self.assertEqual(aRes, torch.tensor([[1, 1, 2],
2301 [0, 0, 0],
2302 [1, 2, 3]]))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002303
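        # For illustration (minimal sketch): cumsum on a bool input promotes to an
        # integer result, which is why comparing against the byte() variant above is
        # meaningful, e.g.
        #   >>> torch.cumsum(torch.tensor([True, False, True]), dim=0)
        #   tensor([1, 1, 2])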
leetanenbaum0b9cd412020-01-03 10:14:16 -08002304        # Check that cumulative sum over a zero-length dimension doesn't crash on backprop.
2305        # Also check that cumsum over other dimensions in a tensor with a zero-length
2306        # dimension also works.
2307        # Also include a basic suite of similar tests for other base cases.
2308 shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]]
2309 for shape in shapes:
2310 for dim in range(len(shape)):
2311 raw_tensor = torch.zeros(*shape, requires_grad=True)
2312 integrated = raw_tensor.cumsum(dim=dim)
2313 # Check that backward does not crash
2314 integrated.sum().backward()
2315 # Check that output maintained correct shape
2316 self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2317
2318 # Check a scalar example
Shen Li10224432021-08-12 11:39:31 -07002319 raw_tensor = torch.tensor(3., requires_grad=True)
leetanenbaum0b9cd412020-01-03 10:14:16 -08002320 integrated = raw_tensor.cumsum(dim=-1)
xiaobing.zhang4d203c62020-02-25 13:01:13 -08002321 self.assertEqual(raw_tensor, integrated)
leetanenbaum0b9cd412020-01-03 10:14:16 -08002322 # Check that backward does not crash
2323 integrated.sum().backward()
2324 # Check that output maintained correct shape
2325 self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2326
Kulin Sethe011a8e2022-05-13 18:28:53 +00002327 @skipIfMps
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002328 def test_cumprod(self, device):
2329 x = torch.rand(100, 100, device=device)
2330 res1 = torch.cumprod(x, 1)
Yukio Siraichi93bf0ae2021-04-11 15:43:54 -07002331 res2 = torch.tensor([]).to(device)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002332 torch.cumprod(x, 1, out=res2)
2333 self.assertEqual(res1, res2)
kiyosora008f8402020-11-19 11:17:25 -08002334 x.cumprod_(1)
2335 self.assertEqual(res1, x)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002336
Shen Li10224432021-08-12 11:39:31 -07002337 a = torch.tensor([[True, False, True],
2338 [False, False, False],
2339 [True, True, True]], dtype=torch.bool, device=device)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002340 b = a.byte()
2341 aRes = torch.cumprod(a, 0)
2342 bRes = torch.cumprod(b, 0)
2343 self.assertEqual(aRes, bRes)
Shen Li10224432021-08-12 11:39:31 -07002344 self.assertEqual(aRes, torch.tensor([[1, 0, 1],
2345 [0, 0, 0],
2346 [0, 0, 0]]))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002347
2348 aRes = torch.cumprod(a, 1)
2349 bRes = torch.cumprod(b, 1)
2350 self.assertEqual(aRes, bRes)
Shen Li10224432021-08-12 11:39:31 -07002351 self.assertEqual(aRes, torch.tensor([[1, 0, 0],
2352 [0, 0, 0],
2353 [1, 1, 1]]))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002354
leetanenbaum5988d362020-01-13 09:46:25 -08002355        # Check that cumulative prod over a zero-length dimension doesn't crash on backprop.
2356        # Also check that cumprod over other dimensions in a tensor with a zero-length
2357        # dimension also works.
2358        # Also include a basic suite of similar tests for other base cases.
2359 shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]]
2360 for shape in shapes:
2361 for dim in range(len(shape)):
2362 raw_tensor = torch.zeros(*shape, requires_grad=True)
2363 integrated = raw_tensor.cumprod(dim=dim)
2364 # Check that backward does not crash
2365 integrated.sum().backward()
2366 # Check that output maintained correct shape
2367 self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2368
2369 # Check a scalar example
Shen Li10224432021-08-12 11:39:31 -07002370 raw_tensor = torch.tensor(3., requires_grad=True)
leetanenbaum5988d362020-01-13 09:46:25 -08002371 integrated = raw_tensor.cumprod(dim=-1)
xiaobing.zhang4d203c62020-02-25 13:01:13 -08002372 self.assertEqual(raw_tensor, integrated)
leetanenbaum5988d362020-01-13 09:46:25 -08002373 # Check that backward does not crash
2374 integrated.sum().backward()
2375 # Check that output maintained correct shape
2376 self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2377
Kulin Sethe011a8e2022-05-13 18:28:53 +00002378 @skipIfMps
anjali4115b815d92020-01-17 10:45:36 -08002379 def test_cummax_cummin(self, device):
anjali411da015c72020-02-18 14:10:04 -08002380 def test_ops(op, string_of_function_name, expected_output1, expected_output2):
anjali4115b815d92020-01-17 10:45:36 -08002381 x = torch.rand(100, 100, device=device)
2382 out1 = op(x, 1)
anjali411da015c72020-02-18 14:10:04 -08002383 res2 = torch.empty(0, device=device)
2384 indices2 = torch.empty(0, dtype=torch.int64, device=device)
anjali4115b815d92020-01-17 10:45:36 -08002385 op(x, 1, out=(res2, indices2))
2386 self.assertEqual(out1[0], res2)
2387 self.assertEqual(out1[1], indices2)
anjali4118dc67a02020-01-14 16:36:56 -08002388
Shen Li10224432021-08-12 11:39:31 -07002389 a = torch.tensor([[True, False, True],
2390 [False, False, False],
2391 [True, True, True]], dtype=torch.bool, device=device)
anjali4115b815d92020-01-17 10:45:36 -08002392 b = a.byte()
2393 aRes = op(a, 0)
2394 bRes = op(b, 0)
Edward Yangba1bd412020-03-03 14:33:40 -08002395 self.assertEqual(aRes[0], bRes[0].bool())
2396 self.assertEqual(aRes[0], expected_output1.bool())
anjali411da015c72020-02-18 14:10:04 -08002397
2398 # test inf and nan input
2399 x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1])
2400 xRes = op(x, 0)[0]
Mike Ruberry9cfc10d2020-05-15 16:22:03 -07002401 self.assertEqual(xRes, expected_output2)
anjali4118dc67a02020-01-14 16:36:56 -08002402
anjali4115b815d92020-01-17 10:45:36 -08002403 # op shouldn't support values, indices with a dtype, device type or layout
2404 # different from that of input tensor
2405 t = torch.randn(10)
anjali411da015c72020-02-18 14:10:04 -08002406 values = torch.empty(0, dtype=torch.int16)
2407 indices = torch.empty(0, dtype=torch.int64)
anjali4115b815d92020-01-17 10:45:36 -08002408 with self.assertRaisesRegex(
Shen Li10224432021-08-12 11:39:31 -07002409 RuntimeError,
2410 'expected scalar_type Float but found Short'):
anjali4115b815d92020-01-17 10:45:36 -08002411 op(t, 0, out=(values, indices))
anjali4118dc67a02020-01-14 16:36:56 -08002412
anjali4115b815d92020-01-17 10:45:36 -08002413            # Check that op over a zero-length dimension doesn't crash on backprop.
2414            # Also check that op over other dimensions in a tensor with a zero-length
2415            # dimension also works.
2416            # Also include a basic suite of similar tests for other base cases.
2417 shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]]
2418 for shape in shapes:
2419 for dim in range(len(shape)):
2420 raw_tensor = torch.zeros(*shape, requires_grad=True)
2421 integrated = getattr(raw_tensor, string_of_function_name)(dim=dim)
2422 # Check that backward does not crash
2423 integrated[0].sum().backward()
2424 # Check that output maintained correct shape
2425 self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
anjali4118dc67a02020-01-14 16:36:56 -08002426
anjali4115b815d92020-01-17 10:45:36 -08002427 # Check a scalar example
Shen Li10224432021-08-12 11:39:31 -07002428 raw_tensor = torch.tensor(3., requires_grad=True)
anjali4115b815d92020-01-17 10:45:36 -08002429 integrated = getattr(raw_tensor, string_of_function_name)(dim=-1)
2430 # Check that backward does not crash
2431 integrated[0].sum().backward()
2432 # Check that output maintained correct shape
2433 self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)
2434
anjali411da015c72020-02-18 14:10:04 -08002435 expected_out = torch.tensor([4, inf, inf, inf, inf, nan, nan])
Shen Li10224432021-08-12 11:39:31 -07002436 test_ops(torch.cummax, "cummax", torch.tensor([[1, 0, 1],
2437 [1, 0, 1],
2438 [1, 1, 1]]), expected_out)
anjali411da015c72020-02-18 14:10:04 -08002439
2440 expected_out = torch.tensor([4, 4, 1.5, -inf, -inf, nan, nan])
Shen Li10224432021-08-12 11:39:31 -07002441 test_ops(torch.cummin, "cummin", torch.tensor([[1, 0, 1],
2442 [0, 0, 0],
2443 [0, 0, 0]]), expected_out)
anjali4118dc67a02020-01-14 16:36:56 -08002444
Kulin Sethe011a8e2022-05-13 18:28:53 +00002445 @skipIfMps
kshitij1234534877442020-05-21 09:09:41 -07002446 def test_logcumsumexp(self, device):
2447 def logcumsumexp(a, axis):
2448 return torch.cumsum(a.exp(), axis=axis).log_()
2449
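        # The reference above is the naive formulation log(cumsum(exp(a))), e.g.
        #   >>> torch.log(torch.cumsum(torch.exp(torch.tensor([0., 0., 0.])), dim=0))
        #   tensor([0.0000, 0.6931, 1.0986])
        # i.e. log(1), log(2), log(3); torch.logcumsumexp should match it while
        # staying numerically stable for inputs with large magnitude.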
Natalia Gimelsheine5e54ad2021-03-02 10:36:37 -08002450 axis = -1
kshitij1234534877442020-05-21 09:09:41 -07002451 a = torch.randn(100, 100, device=device)
2452
Natalia Gimelsheine5e54ad2021-03-02 10:36:37 -08002453 actual = a.logcumsumexp(axis)
kshitij1234534877442020-05-21 09:09:41 -07002454 expected = logcumsumexp(a, axis)
2455 self.assertEqual(a.dtype, actual.dtype)
2456 self.assertEqual(expected.shape, actual.shape)
JackCaoG46447042020-06-03 12:34:46 -07002457 self.assertEqual(expected, actual)
kshitij1234534877442020-05-21 09:09:41 -07002458
Natalia Gimelsheine5e54ad2021-03-02 10:36:37 -08002459 # check -inf and nan handling
Shen Li10224432021-08-12 11:39:31 -07002460 x = torch.tensor([-float('inf'), -float('inf'), 1.0, 1.0, float('inf'),
2461 float('inf'), float('nan'), 1.0, 1.0], device=device)
Natalia Gimelsheine5e54ad2021-03-02 10:36:37 -08002462 x2d = x.unsqueeze(0).expand(2, -1)
2463
2464 for inp in (x, x2d):
2465 actual = inp.logcumsumexp(axis)
2466 expected = logcumsumexp(inp, axis)
2467 self.assertEqual(expected, actual)
2468
kshitij1234534877442020-05-21 09:09:41 -07002469        # Check that the out= argument is actually written in place
2470 b = torch.randn(5, 2, device=device)
2471 inplace_out = torch.zeros(5, 2, device=device)
2472
2473 expected = logcumsumexp(b, axis)
2474 torch.logcumsumexp(b, axis=axis, out=inplace_out)
2475
JackCaoG46447042020-06-03 12:34:46 -07002476 self.assertEqual(inplace_out, expected)
kshitij1234534877442020-05-21 09:09:41 -07002477
2478 # Check input and inplace_output type mismatch
2479 b = torch.randn(5, 2, device=device, dtype=torch.float64)
2480 inplace_out = torch.zeros(5, 2, device=device, dtype=torch.float32)
2481 with self.assertRaisesRegex(
Shen Li10224432021-08-12 11:39:31 -07002482 RuntimeError,
2483 'expected scalar_type Double but found Float'):
kshitij1234534877442020-05-21 09:09:41 -07002484 torch.logcumsumexp(b, axis, out=inplace_out)
2485
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002486 def _test_diff_numpy(self, t, dims=None):
2487 # Helper for test_diff to compare with NumPy reference implementation
2488 def to_np(t):
2489 if t.dtype == torch.bfloat16:
2490 return t.to(dtype=torch.float, device="cpu").numpy()
2491 else:
2492 return t.cpu().numpy()
2493
2494 for dim in dims if dims else range(t.dim()):
2495 prepend = t.narrow(dim, 0, 1)
2496 append = t.narrow(dim, 0, 1)
2497 np_t = to_np(t)
2498
Mikayla Gawareckicac3cd12021-11-17 09:10:37 -08002499            # test when neither prepend nor append is given
2500 for n in range(t.size(dim)):
2501 actual = torch.diff(t, dim=dim, n=n)
2502 expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n))
2503 self.assertEqual(actual, expected.to(t.dtype))
2504
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002505 # test when prepend and append's size along dim is 1
Mikayla Gawareckicac3cd12021-11-17 09:10:37 -08002506 for n in range(1, t.size(dim) + 4):
2507 actual = torch.diff(t, dim=dim, n=n, prepend=prepend, append=append)
2508 expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n, prepend=to_np(prepend), append=to_np(append)))
2509 self.assertEqual(actual, expected.to(t.dtype))
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002510
2511 # test when prepend and append's size along dim != 1
Mikayla Gawareckicac3cd12021-11-17 09:10:37 -08002512 for n in range(1, t.size(dim) * 3):
2513 actual = torch.diff(t, dim=dim, n=n, prepend=t, append=t)
2514 expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n, prepend=np_t, append=np_t))
2515 self.assertEqual(actual, expected.to(t.dtype))
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002516
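    # For illustration of the helper above (minimal sketch): torch.diff computes
    # adjacent differences along a dimension, matching np.diff, e.g.
    #   >>> torch.diff(torch.tensor([1, 3, 6, 10]))
    #   tensor([2, 3, 4])
    # and n > 1 applies the operation recursively n times.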
2517 # All tensors appear contiguous on XLA
kshitij12345885a8e52021-11-01 09:21:20 -07002518 @onlyNativeDeviceTypes
Nikita Shulgabfac65d2022-03-30 14:13:21 -07002519 @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002520 def test_diff_noncontig(self, device, dtype):
Shen Li10224432021-08-12 11:39:31 -07002521 shapes = (
2522 (1,),
2523 (1, 5),
2524 (3, 5),
2525 (1, 5, 1),
2526 (2, 3, 5))
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002527
2528 for shape in shapes:
Philip Meier0973c5a2022-02-24 21:47:38 -08002529 contig = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002530
2531 non_contig = torch.empty(shape + (2, 2), device=device, dtype=dtype)[..., 0]
2532 non_contig = non_contig.select(-1, -1)
2533 non_contig.copy_(contig)
2534 self.assertTrue(not non_contig.is_contiguous() or shape == (1,))
2535
2536 self._test_diff_numpy(non_contig)
2537
2538 # RngNormal not implemented for type f16 for XLA
Nikita Shulgabfac65d2022-03-30 14:13:21 -07002539 @dtypes(*all_types_and_complex_and(torch.bool))
2540 @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool))
2541 @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool))
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002542 def test_diff(self, device, dtype):
Shen Li10224432021-08-12 11:39:31 -07002543 shapes = (
2544 (1,),
2545 (1, 5),
2546 (3, 5),
2547 (1, 5, 1),
2548 (2, 3, 5))
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002549
2550 for shape in shapes:
Philip Meier0973c5a2022-02-24 21:47:38 -08002551 contig = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002552 self._test_diff_numpy(contig)
2553
2554 t = torch.ones(2, 3)
2555
2556 with self.assertRaisesRegex(
Shen Li10224432021-08-12 11:39:31 -07002557 RuntimeError, 'diff expects prepend or append to be the same dimension as input'):
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002558 invalid_prepend = torch.tensor([1, 2, 3], device=device, dtype=dtype)
2559 t.diff(dim=0, prepend=invalid_prepend)
2560
2561 with self.assertRaisesRegex(
Shen Li10224432021-08-12 11:39:31 -07002562 RuntimeError, 'diff expects the shape of tensor to prepend or append to match that of input'):
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002563 invalid_prepend = torch.tensor([[0, 1]], device=device, dtype=dtype)
2564 t.diff(dim=0, prepend=invalid_prepend)
2565
Shen Li10224432021-08-12 11:39:31 -07002566 with self.assertRaisesRegex(
Shen Li10224432021-08-12 11:39:31 -07002567 RuntimeError, 'diff expects input to be at least one-dimensional'):
Jeffrey Wanb18eeaa2021-02-02 20:20:15 -08002568 scalar = torch.tensor(2, device=device, dtype=dtype)
2569 torch.diff(scalar)
2570
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002571    # if the given input arg is not a list, return it wrapped in a single-element list: [arg]
2572 def _wrap_to_list(self, input_array):
2573 return input_array if isinstance(input_array, list) else [input_array]
2574
2575    # Preprocesses inf, -inf, and nan values so they do not cause divergence between NumPy and PyTorch.
2576    # There are two possible sources of divergence:
2577    # 1. When a and b are both real numbers with very small absolute values (i.e. very close to 0.0),
2578    #    the result of a/b can be inf, -inf, or nan, and the two libraries may disagree.
2579    # 2. When dividing complex numbers by zero. For example, for a = torch.tensor(3+5j),
2580    #    a/0 is nan + nan*j in PyTorch and inf + inf*j in NumPy.
2581 def _inf_nan_preprocess(self, actual, expected):
2582 for i in range(len(expected)):
2583 expected[i] = np.nan_to_num(expected[i], nan=nan, posinf=nan, neginf=nan)
2584 # nan_to_num is not defined for complex tensors in PyTorch.
Shen Li10224432021-08-12 11:39:31 -07002585            if actual[i].dtype == torch.complex64:
2586 actual[i].real = torch.nan_to_num(actual[i].real, nan=nan, posinf=nan, neginf=nan)
2587 actual[i].imag = torch.nan_to_num(actual[i].imag, nan=nan, posinf=nan, neginf=nan)
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002588 else:
2589 actual[i] = torch.nan_to_num(actual[i], nan=nan, posinf=nan, neginf=nan)
2590
2591 return actual, expected
2592
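    # For illustration (minimal sketch): _inf_nan_preprocess above maps inf/-inf/nan
    # to nan on both sides before comparison, e.g.
    #   >>> torch.nan_to_num(torch.tensor([float('inf'), -float('inf'), 1.0]), nan=nan, posinf=nan, neginf=nan)
    #   tensor([nan, nan, 1.])
    # so a position where one library yields inf and the other nan still compares
    # equal under assertEqual(..., equal_nan=True).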
kshitij12345885a8e52021-11-01 09:21:20 -07002593 @onlyNativeDeviceTypes
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002594 @dtypes(torch.long, torch.float32, torch.complex64)
2595 def test_gradient_all(self, device, dtype):
2596 def create_scalar(shape):
Shen Li10224432021-08-12 11:39:31 -07002597 return make_tensor((1,), device='cpu', dtype=dtype, low=1.).item()
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002598
2599 def create_list(shape):
Shen Li10224432021-08-12 11:39:31 -07002600 return make_tensor((len(shape),), device='cpu', dtype=dtype, low=1.).tolist()
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002601
2602 def create_coordinate_tensors(shape):
2603 tensor_list = []
2604 for i in range(len(shape)):
2605 tensor_list.append(make_tensor((shape[i],), device=device, dtype=dtype))
2606 return tensor_list
2607
2608 def filter_shape(shape, dim):
2609 filtered_shape = []
2610 for i in range(len(dim)):
2611 filtered_shape.append(shape[dim[i]])
2612 return filtered_shape
2613
2614 # shape, dims format
2615 test_cases = (
2616 ((5,), (0,)),
2617 ((4, 4), (0, 1)),
2618 ((3, 3, 3), (-1, 0)),
2619 ((4, 4, 4), (2,)),
2620 ((4, 4, 4), (0, 1)),
2621 ((4, 4, 4, 3), (0, 2, 3)),
Ilqar Ramazanli90cd57e2021-06-23 03:32:04 -07002622 ((4, 5, 3, 4, 3), (1, 2)),
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002623 ((4, 3, 6, 5, 3), (2, 4)),
Ilqar Ramazanli90cd57e2021-06-23 03:32:04 -07002624 ((4, 3, 3, 5, 3), (0, 1, 2, 3, 4)),
Jonathan Colen33403f42022-01-25 12:45:58 -08002625 ((1, 3, 3), (1, 2)),
2626 ((1, 5), (1,)),
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002627 )
2628
Shen Li10224432021-08-12 11:39:31 -07002629 for case, contig, edge_order, space_fn in product(test_cases, [True, False], [1, 2],
2630 (create_scalar, create_list, create_coordinate_tensors)):
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002631 shape, dims = case
2632 # filter shape by dims before passing filtered shape to create_* functions
2633 filtered_shape = filter_shape(shape, dims)
2634
2635 spacing = space_fn(filtered_shape)
2636 t = make_tensor(shape, device=device, dtype=dtype, noncontiguous=not contig)
2637 t_np = t.cpu().numpy()
2638
Ilqar Ramazanli90cd57e2021-06-23 03:32:04 -07002639 actual = torch.gradient(t, spacing=spacing, dim=dims, edge_order=edge_order)
Shen Li10224432021-08-12 11:39:31 -07002640 if space_fn == create_coordinate_tensors and spacing[0].device != 'cpu':
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002641 spacing = [space.cpu().detach().numpy() for space in spacing]
Shen Li10224432021-08-12 11:39:31 -07002642 expected = np.gradient(t_np, *self._wrap_to_list(spacing), axis=dims, edge_order=edge_order)
2643 actual, expected = self._inf_nan_preprocess(list(actual), self._wrap_to_list(expected))
Philip Meier401bbb22021-08-30 12:28:39 -07002644 self.assertEqual(actual, expected, equal_nan=True, atol=1e-4, rtol=0, exact_dtype=False)
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002645
kshitij12345885a8e52021-11-01 09:21:20 -07002646 @onlyNativeDeviceTypes
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002647 @dtypes(torch.long, torch.float32, torch.complex64)
2648 def test_gradient_extreme_cases(self, device, dtype):
2649 # Test behaviour for inf and nan values
Shen Li10224432021-08-12 11:39:31 -07002650 actual = torch.gradient(torch.tensor([2, -2, inf, inf, -inf, -inf, inf, 3, -inf, 2, nan, nan, 3, inf, nan]))
2651 expected = np.gradient(np.array([2, -2, inf, inf, -inf, -inf, inf, 3, -inf, 2, nan, nan, 3, inf, nan]))
Philip Meier0c916c82021-06-22 13:05:24 -07002652 self.assertEqual(actual, self._wrap_to_list(expected), exact_dtype=False)
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002653
2654 # Test behaviour in very big tensors
2655 large_size = 100000
Philip Meier0973c5a2022-02-24 21:47:38 -08002656 t = make_tensor((large_size,), dtype=dtype, device=device)
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002657 t_np = t.cpu().numpy()
2658 coordinates_np = list(np.random.randn(large_size))
2659 coordinates = [torch.tensor(coordinates_np, device=device)]
2660 actual = torch.gradient(t, spacing=coordinates, dim=0, edge_order=1)
2661 expected = [np.gradient(t_np, coordinates_np, axis=0, edge_order=1)]
Philip Meier0c916c82021-06-22 13:05:24 -07002662 self.assertEqual(actual, expected, exact_dtype=False)
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002663
Ilqar Ramazanli90cd57e2021-06-23 03:32:04 -07002664 actual = torch.gradient(t, spacing=coordinates, dim=0, edge_order=2)
2665 expected = [np.gradient(t_np, coordinates_np, axis=0, edge_order=2)]
2666 self.assertEqual(actual, expected, exact_dtype=False)
2667
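    # For illustration (minimal sketch): with the default unit spacing,
    # torch.gradient uses central differences in the interior and one-sided
    # differences at the boundaries, returning one tensor per requested dimension, e.g.
    #   >>> torch.gradient(torch.tensor([1., 2., 4., 7.]))
    #   (tensor([1.0000, 1.5000, 2.5000, 3.0000]),)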
kshitij12345885a8e52021-11-01 09:21:20 -07002668 @onlyNativeDeviceTypes
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002669 def test_gradient_type_promotion(self, device):
2670 inputs = (
2671 make_tensor((4, 4), device=device, dtype=torch.float32),
2672 make_tensor((4, 4), device=device, dtype=torch.complex64),
2673 make_tensor((4, 4), device=device, dtype=torch.int64),
2674 )
2675
2676 spacing = (
Shen Li10224432021-08-12 11:39:31 -07002677 make_tensor((1,), device='cpu', dtype=torch.float32).item(),
2678 make_tensor((1,), device='cpu', dtype=torch.int64).item(),
2679 make_tensor((1,), device='cpu', dtype=torch.complex64).item(),
2680 make_tensor((2,), device='cpu', dtype=torch.float32, low=0.1).tolist(),
2681 make_tensor((2,), device='cpu', dtype=torch.int64, low=1).tolist(),
2682 make_tensor((2,), device='cpu', dtype=torch.complex64).tolist(),
2683 [make_tensor((4,), device=device, dtype=torch.float32),
2684 make_tensor((4,), device=device, dtype=torch.float32)],
2685 [make_tensor((4,), device=device, dtype=torch.int64),
2686 make_tensor((4,), device=device, dtype=torch.int64)],
2687 [make_tensor((4,), device=device, dtype=torch.complex64),
2688 make_tensor((4,), device=device, dtype=torch.complex64)],
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002689 )
2690
Ilqar Ramazanli90cd57e2021-06-23 03:32:04 -07002691 for input, spacing_or_coord, edge_order in product(inputs, spacing, [1, 2]):
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002692 input_np = input.cpu().numpy()
Shen Li10224432021-08-12 11:39:31 -07002694 actual = torch.gradient(input, spacing=spacing_or_coord, dim=(0, 1), edge_order=edge_order)
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002695 spacing_or_coord_wrapped = self._wrap_to_list(spacing_or_coord)
2696 spacing_or_coord_np = []
Shen Li10224432021-08-12 11:39:31 -07002697 if torch.is_tensor(spacing_or_coord_wrapped[0]) and torch.device(spacing_or_coord_wrapped[0].device).type != 'cpu':
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002698 for i in range(len(spacing_or_coord_wrapped)):
Shen Li10224432021-08-12 11:39:31 -07002699 spacing_or_coord_np.append(spacing_or_coord_wrapped[i].detach().clone().cpu().numpy())
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002700 else:
2701 spacing_or_coord_np = spacing_or_coord_wrapped
Shen Li10224432021-08-12 11:39:31 -07002702 expected = np.gradient(input_np, *spacing_or_coord_np, axis=(0, 1), edge_order=edge_order)
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002703 if actual[0].dtype == torch.complex64 and input.dtype != torch.complex64:
2704 for i in range(len(actual)):
Shen Li10224432021-08-12 11:39:31 -07002705 self.assertEqual(actual[i].real, expected[i].real, exact_dtype=False)
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002706 # Type promotion fails on Numpy when spacing is given as complex number and input is given as real.
2707 # Result is given just as real number and all the imaginary parts to be equal to zero.
Shen Li10224432021-08-12 11:39:31 -07002708 self.assertEqual(expected[i].imag, torch.zeros(actual[i].shape), exact_dtype=False)
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002709 else:
2710 actual, expected = self._inf_nan_preprocess(list(actual), expected)
Philip Meier401bbb22021-08-30 12:28:39 -07002711 self.assertEqual(actual, expected, equal_nan=True, exact_dtype=False)
Ilqar Ramazanli8b816e92021-05-11 18:51:07 -07002712
Natalia Gimelshein6ca54212020-07-27 15:39:50 -07002713 def _test_large_cum_fn_helper(self, x, fn):
2714 x_cpu = x.cpu().float()
2715 expected = fn(x_cpu)
2716 actual = fn(x).cpu().float()
2717 self.assertEqual(expected, actual.cpu().float())
2718
Shen Li10224432021-08-12 11:39:31 -07002719 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration")
Natalia Gimelshein6ca54212020-07-27 15:39:50 -07002720 @onlyCUDA
kshitij12345d5d20962021-11-18 08:25:47 -08002721 @dtypes(torch.half) # only small dtype not to get oom
Natalia Gimelshein6ca54212020-07-27 15:39:50 -07002722 def test_large_cumsum(self, device, dtype):
2723 # initialization to avoid overflow and half caveats
Shen Li10224432021-08-12 11:39:31 -07002724 x = torch.empty(2**30 + 200, device=device, dtype=dtype)
Natalia Gimelshein6ca54212020-07-27 15:39:50 -07002725 x[::3] = -3
2726 x[1::3] = 2
2727 x[2::3] = 1
2728 self._test_large_cum_fn_helper(x, lambda x: torch.cumsum(x, 0))
2729
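    # Note on test_large_cumsum above and test_large_cumprod below: the repeating
    # (-3, 2, 1) pattern sums to zero and the (8, .25, .5) pattern multiplies to one
    # over every period of three elements, so the running result stays representable
    # in float16 even though the tensor has more than 2**30 entries.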
2730 @onlyCUDA
kshitij12345d5d20962021-11-18 08:25:47 -08002731 @dtypes(torch.half) # only small dtype not to get oom
Natalia Gimelshein6ca54212020-07-27 15:39:50 -07002732 def test_large_cumprod(self, device, dtype):
2733 # initialization to avoid overflow and half caveats
Shen Li10224432021-08-12 11:39:31 -07002734 x = torch.empty(2**30 + 200, device=device, dtype=dtype)
Natalia Gimelshein6ca54212020-07-27 15:39:50 -07002735 x[::3] = 8
Shen Li10224432021-08-12 11:39:31 -07002736 x[1::3] = .25
2737 x[2::3] = .5
Natalia Gimelshein6ca54212020-07-27 15:39:50 -07002738 self._test_large_cum_fn_helper(x, lambda x: torch.cumprod(x, 0))
2739
Animesh Jain1d90d6e2022-07-07 18:57:31 +00002740 @skipIfTorchDynamo("Torchdynamo fails with unknown reason")
Kulin Sethe011a8e2022-05-13 18:28:53 +00002741 @skipIfMps
Natalia Gimelshein6ca54212020-07-27 15:39:50 -07002742 def test_discontiguous_out_cumsum(self, device):
2743 x = torch.randn(4, 8, device=device)
2744 y = torch.empty(4, 16, device=device)[:, ::2]
2745 out = torch.cumsum(x, 0)
2746 torch.cumsum(x, 0, out=y)
2747 self.assertFalse(y.is_contiguous())
Shen Li10224432021-08-12 11:39:31 -07002748 self.assertEqual(out, y, atol=0., rtol=0.)
Natalia Gimelshein6ca54212020-07-27 15:39:50 -07002749
Natalia Gimelsheinec898b12020-08-04 10:07:43 -07002750 def _test_cumminmax_helper(self, x, fn, expected_val, expected_ind):
2751 val, ind = fn(x, -1)
2752 self.assertEqual(val, expected_val, atol=0, rtol=0)
2753 self.assertEqual(ind, expected_ind, atol=0, rtol=0)
2754 out_val = torch.empty_like(val).t().contiguous().t()
2755 out_ind = torch.empty_like(ind).t().contiguous().t()
2756 fn(x, -1, out=(out_val, out_ind))
2757 self.assertFalse(out_val.is_contiguous())
2758 self.assertFalse(out_ind.is_contiguous())
2759 self.assertEqual(out_val, expected_val, atol=0, rtol=0)
2760 self.assertEqual(out_ind, expected_ind, atol=0, rtol=0)
2761
Kulin Sethe011a8e2022-05-13 18:28:53 +00002762 @skipIfMps
Natalia Gimelsheinec898b12020-08-04 10:07:43 -07002763 def test_cummax_discontiguous(self, device):
Shen Li10224432021-08-12 11:39:31 -07002764 x = torch.tensor([[0, 1, 2, 3, 2, 1], [4, 5, 6, 5, 6, 7]], device=device, dtype=torch.float).t().contiguous().t()
2765 expected_val = torch.tensor([[0, 1, 2, 3, 3, 3], [4, 5, 6, 6, 6, 7]], device=device, dtype=torch.float)
2766 expected_ind = torch.tensor([[0, 1, 2, 3, 3, 3], [0, 1, 2, 2, 4, 5]], device=device, dtype=torch.long)
Natalia Gimelsheinec898b12020-08-04 10:07:43 -07002767 self._test_cumminmax_helper(x, torch.cummax, expected_val, expected_ind)
2768
Kulin Sethe011a8e2022-05-13 18:28:53 +00002769 @skipIfMps
Natalia Gimelsheinec898b12020-08-04 10:07:43 -07002770 def test_cummin_discontiguous(self, device):
Shen Li10224432021-08-12 11:39:31 -07002771 x = torch.tensor([[3, 2, 1, 0, 1, 2], [7, 6, 5, 4, 5, 2]], device=device, dtype=torch.float).t().contiguous().t()
2772 expected_val = torch.tensor([[3, 2, 1, 0, 0, 0], [7, 6, 5, 4, 4, 2]], device=device, dtype=torch.float)
2773 expected_ind = torch.tensor([[0, 1, 2, 3, 3, 3], [0, 1, 2, 3, 3, 5]], device=device, dtype=torch.long)
Natalia Gimelsheinec898b12020-08-04 10:07:43 -07002774 self._test_cumminmax_helper(x, torch.cummin, expected_val, expected_ind)
2775
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002776 def test_bool_tensor_value_change(self, device):
2777 x = torch.tensor([True, False], dtype=torch.bool, device=device)
2778 x[0] = False
2779 x[1] = True
Shen Li10224432021-08-12 11:39:31 -07002780 self.assertEqual(x, torch.tensor([False, True], dtype=torch.bool, device=device))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002781
Mike Ruberrye0d829a2022-01-24 01:28:07 -08002782 # FIXME: move to shape ops test suite
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002783 def test_unfold_all_devices_and_dtypes(self, device):
Nikita Shulgabfac65d2022-03-30 14:13:21 -07002784 for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002785
Gao, Xiange255a4e2020-09-18 15:52:32 -07002786 if dt == torch.bool:
anjali4119e016f72020-04-16 08:21:49 -07002787 x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
Pavel Belevich35b6d292020-03-06 07:14:31 -08002788 self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002789 else:
anjali4119e016f72020-04-16 08:21:49 -07002790 x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002791 self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
2792
Mike Ruberrye0d829a2022-01-24 01:28:07 -08002793 # FIXME: move to shape ops test suite
Gregory Chanan2793d412019-10-25 07:10:32 -07002794 def test_unfold_scalars(self, device):
2795 x = torch.tensor(0.5, device=device)
2796        # unfold on a 0-dimensional tensor should always return a 1-dimensional
2797 # tensor of shape [size] (i.e., the second parameter to unfold)
2798
2799 self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 1))
2800 self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 2))
2801 self.assertEqual(torch.tensor([0.5], device=device), x.unfold(0, 1, 1))
2802
Mike Ruberrye0d829a2022-01-24 01:28:07 -08002803 # FIXME: move to data movement test suite
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002804 def test_copy_all_dtypes_and_devices(self, device):
2805 from copy import copy
Nikita Shulgabfac65d2022-03-30 14:13:21 -07002806 for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002807 x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device)
2808 x_clone = x.clone()
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002809 y = copy(x)
2810 y.fill_(1)
2811 # copy is a shallow copy, only copies the tensor view,
2812 # not the data
2813 self.assertEqual(x, y)
2814
Mike Ruberrye0d829a2022-01-24 01:28:07 -08002815 # FIXME: move to data movement test suite
Peter Bell917d56a2022-01-05 17:57:35 -08002816 @onlyNativeDeviceTypes
2817 def test_copy_math_view(self, device):
2818 for dst_dtype, src_dtype in [
2819 (torch.float32, torch.float32),
2820 (torch.float64, torch.float32),
2821 (torch.int64, torch.int32),
2822 (torch.complex128, torch.complex64),
2823 ]:
2824 src = make_tensor((100,), dtype=src_dtype, device=device)
2825 dst = torch.empty(100, dtype=dst_dtype, device=device)
2826
2827 dst.copy_(src)
2828 self.assertEqual(dst, src, exact_dtype=False)
2829
2830 dst.copy_(src._neg_view())
2831 self.assertEqual(dst, src.neg(), exact_dtype=False)
2832
2833 dst._neg_view().copy_(torch._neg_view(src))
2834 self.assertEqual(dst, src, exact_dtype=False)
2835
2836 dst._neg_view().copy_(src)
2837 self.assertEqual(dst, src.neg(), exact_dtype=False)
2838
2839 for dst_dtype, src_dtype in [
2840 (torch.complex64, torch.complex64),
2841 (torch.complex128, torch.complex64),
2842 ]:
2843 src = make_tensor((100,), dtype=src_dtype, device=device)
2844 dst = torch.empty(100, dtype=dst_dtype, device=device)
2845
2846 dst.conj().copy_(src)
2847 self.assertEqual(dst, src.conj_physical(), exact_dtype=False)
2848
2849 dst.conj().copy_(src._neg_view())
2850 self.assertEqual(dst, src.neg().conj_physical(), exact_dtype=False)
2851
Mike Ruberrye0d829a2022-01-24 01:28:07 -08002852 # FIXME: move to data movement test suite
Peter Bell17bb6862022-01-14 10:12:02 -08002853 @onlyNativeDeviceTypes
2854 @dtypes(torch.int64, torch.float32, torch.complex64)
2855 def test_copy_transpose_math_view(self, device, dtype):
2856 src = make_tensor((100, 100), dtype=dtype, device=device).transpose(0, 1)
2857 dst = torch.empty((100, 100), dtype=dtype, device=device)
2858
2859 dst._neg_view().copy_(src)
2860 self.assertEqual(dst, -src)
2861 dst._neg_view().copy_(src._neg_view())
2862 self.assertEqual(dst, src)
2863 dst.copy_(src._neg_view())
2864 self.assertEqual(dst, -src)
2865
2866 if dtype.is_complex:
2867 dst.conj().copy_(src)
2868 self.assertEqual(dst, src.conj_physical())
2869 dst.conj().copy_(src.conj())
2870 self.assertEqual(dst, src)
2871 dst.copy_(src.conj())
2872 self.assertEqual(dst, src.conj_physical())
2873
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002874 def test_clone_all_dtypes_and_devices(self, device):
Nikita Shulgabfac65d2022-03-30 14:13:21 -07002875 for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002876 x = torch.tensor((1, 1), dtype=dt, device=device)
2877 y = x.clone()
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002878 self.assertEqual(x, y)
2879
Peter Bell2af64ba2020-03-05 15:54:00 -08002880 def test_clone_zero_stride_dim(self, device):
2881 # stride zero, size 1 axis, not contiguous
2882 x = torch.randn(10)
2883 y = x.as_strided([2, 1, 5], [1, 0, 2])
2884 self.assertEqual(y, y.clone())
2885
anjali4115d80a482021-09-01 16:11:38 -07002886 def test_clone_not_memory_dense(self):
2887 # github issue: https://github.com/pytorch/pytorch/issues/64176
2888 x = torch.randn(10, 8).t()[::2, ::2]
2889 y = x.clone()
2890 # should retain permutation after densification
2891 self.assertTrue(y.stride() == (1, 4))
2892
Mike Ruberrye0d829a2022-01-24 01:28:07 -08002893 # FIXME: move to elementwise ternary test suite
Philip Meier26b7ff52021-09-07 08:57:43 -07002894 @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
2895 @dtypes(*set(get_all_math_dtypes('cpu')))
anjali411db1f2172020-11-14 21:25:52 -08002896 def test_addcmul(self, device, dtype):
Mike Ruberryde40c8e2021-06-06 14:51:26 -07002897 # Returns floating or integral scalar corresponding to dtype
2898 def _number(floating, integer, dtype):
2899 if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]:
2900 return floating
2901 elif dtype in [torch.cfloat, torch.cdouble]:
2902 return floating * (1 + 1j)
2903 else:
2904 return integer
2905
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002906 def rand_tensor(size, dtype, device):
anjali4114f3946a2020-04-24 15:03:38 -07002907 if dtype.is_floating_point or dtype.is_complex:
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002908 return torch.rand(size=size, dtype=dtype, device=device)
2909 if dtype == torch.uint8:
2910 return torch.randint(1, 5, size=size, dtype=dtype, device=device)
2911 else:
2912 return torch.randint(-5, 5, size=size, dtype=dtype, device=device)
2913
anjali411db1f2172020-11-14 21:25:52 -08002914 a = rand_tensor((2, 2), dtype=dtype, device=device)
2915 b = rand_tensor((2, 2), dtype=dtype, device=device)
2916 c = rand_tensor((2, 2), dtype=dtype, device=device)
anjali4114f3946a2020-04-24 15:03:38 -07002917
anjali411db1f2172020-11-14 21:25:52 -08002918 alpha = _number(0.5, 3, dtype)
anjali4114f3946a2020-04-24 15:03:38 -07002919
anjali411db1f2172020-11-14 21:25:52 -08002920 actual = torch.addcmul(a, b, c, value=alpha)
2921 expected = a + alpha * b * c
anjali4114f3946a2020-04-24 15:03:38 -07002922
anjali411db1f2172020-11-14 21:25:52 -08002923 self.assertEqual(expected, actual)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002924
mattip54a24982021-03-08 03:30:11 -08002925 with self.assertWarnsOnceRegex(
Shen Li10224432021-08-12 11:39:31 -07002926 UserWarning, "This overload of addcmul is deprecated"):
anjali411db1f2172020-11-14 21:25:52 -08002927 self.assertEqual(actual, torch.addcmul(a, alpha, b, c))
Peter Bellb0ac4252020-01-14 11:32:04 -08002928
Shen Li10224432021-08-12 11:39:31 -07002929 if self.device_type == 'cuda' and dtype == torch.half:
Masaki Kozukia404cc92021-06-25 10:20:10 -07002930 a = torch.tensor([60000.0], device=device, dtype=dtype)
2931 b = torch.tensor([60000.0], device=device, dtype=dtype)
2932 c = torch.tensor([2.0], device=device, dtype=dtype)
2933 out = torch.addcmul(a, b, c, value=-1)
2934 self.assertTrue(not (out.isnan() or out.isinf()))
2935
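    # For illustration of the addcmul test above (minimal sketch): addcmul computes
    # input + value * t1 * t2 elementwise, e.g.
    #   >>> torch.addcmul(torch.tensor([1.]), torch.tensor([2.]), torch.tensor([3.]), value=10)
    #   tensor([61.])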
Mike Ruberrye0d829a2022-01-24 01:28:07 -08002936 # FIXME: move to shape ops test suite
Mike Ruberryb4b8f532019-09-14 17:09:04 -07002937 def test_narrow_empty(self, device):
2938 x = torch.randn(2, 3, 4, device=device)
2939 for d in range(x.dim()):
2940 y = x.narrow(d, x.size(d), 0)
2941 sz = list(x.size())
2942 sz[d] = 0
2943 self.assertEqual(sz, y.size())
2944
Mikayla Gawarecki676a4a32022-04-27 22:00:47 +00002945 # FIXME: move to indexing test suite
Mikayla Gawarecki676a4a32022-04-27 22:00:47 +00002946 @parametrize("reduce", ['prod', 'amin', 'amax', 'mean'])
PyTorch MergeBotf668b7e2022-06-30 10:32:34 +00002947 @dtypes(*all_types_and(torch.half, torch.bfloat16))
Mikayla Gawarecki676a4a32022-04-27 22:00:47 +00002948 def test_index_reduce(self, device, dtype, reduce):
2949 size = (3, 4, 5)
2950 index_dtypes = [torch.int, torch.long]
2951 include_selfs = [True, False]
PyTorch MergeBotf668b7e2022-06-30 10:32:34 +00002952 amin_init = float('inf') if dtype.is_floating_point else torch.iinfo(dtype).max
2953 amax_init = -float('inf') if dtype.is_floating_point else torch.iinfo(dtype).min
2954 reduction_init = {'prod': 1, 'mean': 0, 'amin': amin_init, 'amax': amax_init}
Mikayla Gawarecki676a4a32022-04-27 22:00:47 +00002955
PyTorch MergeBotf668b7e2022-06-30 10:32:34 +00002956 for dest_noncontig, src_noncontig, index_noncontig in product([True, False], repeat=3):
Mikayla Gawarecki676a4a32022-04-27 22:00:47 +00002957 for idx_dtype, include_self in product(index_dtypes, include_selfs):
2958 for dim in range(len(size)):
2959 num_src = np.random.randint(10)
2960 num_dest = size[dim]
PyTorch MergeBotf668b7e2022-06-30 10:32:34 +00002961 dest = make_tensor(size, device=device, dtype=dtype, noncontiguous=dest_noncontig)
2962 src_size = size[:dim] + (num_src,) + size[dim + 1:]
2963 src = make_tensor(src_size, device=device, dtype=dtype, noncontiguous=src_noncontig)
Mikayla Gawarecki676a4a32022-04-27 22:00:47 +00002964 idx = torch.randint(num_dest, (num_src,), dtype=idx_dtype, device=device)
PyTorch MergeBotf668b7e2022-06-30 10:32:34 +00002965 if index_noncontig:
Mikayla Gawarecki1141b452022-05-12 21:57:01 +00002966 # noncontiguous_like fails with RuntimeError: XLA tensors do not have storage
2967 idx = torch.testing.make_non_contiguous(idx)
Mikayla Gawarecki676a4a32022-04-27 22:00:47 +00002968 expected = dest.clone()
Mikayla Gawarecki841c65f2022-05-12 22:19:52 +00002969 dest.index_reduce_(dim, idx, src, reduce, include_self=include_self)
Mikayla Gawarecki676a4a32022-04-27 22:00:47 +00002970 # fill rows in idx with reduction inits if include_self=False
2971 if (not include_self):
2972 expected.index_fill_(dim, idx.long(), reduction_init[reduce])
2973 expected = expected.transpose(0, dim)
2974 src = src.transpose(0, dim)
2975 for i in range(num_src):
2976 if reduce == 'prod':
2977 expected[idx[i]] *= src[i]
2978 elif reduce == 'amin':
2979 torch.minimum(expected[idx[i]], src[i], out=expected[idx[i]])
2980 elif reduce == 'amax':
2981 torch.maximum(expected[idx[i]], src[i], out=expected[idx[i]])
2982 else:
2983 expected[idx[i]] += src[i]
2984 if reduce == 'mean':
2985 counts = torch.ones_like(expected) if include_self else torch.zeros_like(expected)
2986 counts.index_add_(0, idx, torch.ones_like(src))
2987 counts.masked_fill_(counts == 0, 1)
PyTorch MergeBotf668b7e2022-06-30 10:32:34 +00002988 if (dtype.is_floating_point):
2989 expected.div_(counts)
2990 else:
2991 expected.div_(counts, rounding_mode="floor")
Mikayla Gawarecki676a4a32022-04-27 22:00:47 +00002992 expected = expected.transpose(0, dim)
2993
2994 self.assertEqual(dest, expected)
2995
Mike Ruberrye0d829a2022-01-24 01:28:07 -08002996 # FIXME: move to test indexing
Nikita Shulgabfac65d2022-03-30 14:13:21 -07002997 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
lezcano58703462021-03-22 22:32:36 -07002998 def test_index_copy(self, device, dtype):
2999 # We just test for num_copy <= num_dest, as otherwise there are repeated indices
3000 # and the behavior is undefined
3001 num_copy, num_dest = 3, 5
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003002
lezcano58703462021-03-22 22:32:36 -07003003 def make_arg(batch_sizes, n, dim, contig):
3004 size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
Philip Meier0973c5a2022-02-24 21:47:38 -08003005 return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003006
lezcano58703462021-03-22 22:32:36 -07003007 def ref_index_copy(tgt, dim, idx, src):
3008 for i in range(idx.size(0)):
3009 idx_dest = dim * (slice(None),) + (idx[i],)
3010 idx_src = dim * (slice(None),) + (i,)
3011 tgt[idx_dest] = src[idx_src]
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003012
lezcano58703462021-03-22 22:32:36 -07003013 # More thorough testing as in index_add
3014 for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
3015 for other_sizes in ((), (4, 5)):
3016 for dim in range(len(other_sizes)):
3017 dest = make_arg(other_sizes, num_dest, dim, dest_contig)
3018 src = make_arg(other_sizes, num_copy, dim, src_contig)
Shen Li10224432021-08-12 11:39:31 -07003019 idx = torch.randperm(num_dest, dtype=torch.int64, device=device)[:num_copy]
lezcano58703462021-03-22 22:32:36 -07003020 if not index_contig:
3021 idx = torch.repeat_interleave(idx, 2, dim=-1)
3022 idx = idx[..., ::2]
3023 dest2 = dest.clone()
3024 dest.index_copy_(dim, idx, src)
3025 ref_index_copy(dest2, dim, idx, src)
3026 self.assertEqual(dest, dest2)
3027
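    # For illustration (minimal sketch): index_copy_ writes slices of `src` into the
    # destination at the positions given by `idx` along `dim`, e.g.
    #   >>> torch.zeros(3).index_copy_(0, torch.tensor([2, 0]), torch.tensor([1., 2.]))
    #   tensor([2., 0., 1.])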
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003028 # FIXME: move to test indexing
kshitij12345885a8e52021-11-01 09:21:20 -07003029 # onlyNativeDeviceTypes due to an XLA error:
lezcano58703462021-03-22 22:32:36 -07003030 # https://github.com/pytorch/pytorch/issues/53256
kshitij12345885a8e52021-11-01 09:21:20 -07003031 @onlyNativeDeviceTypes
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003032 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
lezcano58703462021-03-22 22:32:36 -07003033 def test_index_copy_scalars(self, device, dtype):
3034 # Create the 8 possible combinations of scalar sizes for target / index / source
Shen Li10224432021-08-12 11:39:31 -07003035 scalars = ((make_tensor(size_t, dtype=dtype, device=device, low=None, high=None),
3036 make_tensor(size_i, dtype=torch.int64, device=device, low=0, high=1),
3037 make_tensor(size_s, dtype=dtype, device=device, low=None, high=None))
3038 for size_t, size_i, size_s in product([(), (1,)], repeat=3))
lezcano58703462021-03-22 22:32:36 -07003039 for target, idx, source in scalars:
3040 target.index_copy_(0, idx, source)
3041 self.assertEqual(target.item(), source.item())
3042
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003043 # FIXME: move to test indexing
lezcano58703462021-03-22 22:32:36 -07003044 @onlyCPU
3045 def test_errors_index_copy(self, device):
3046 # We do not test the GPU as the CUDA_ASSERT would break the CUDA context
3047 idx_dim = 8
3048 tgt_dim = 5
3049 batch_dim = 3
3050
3051 # Too large of an index
3052 a = torch.randn(batch_dim, tgt_dim, device=device)
3053 idx = torch.full((idx_dim,), tgt_dim, device=device)
3054 c = torch.zeros(batch_dim, idx_dim, device=device)
3055 with self.assertRaises(IndexError):
3056 a.index_copy_(1, idx, c)
3057
3058 # Too small (negative indices)
3059 idx = torch.full((idx_dim,), -1, device=device)
3060 with self.assertRaises(IndexError):
3061 a.index_copy_(1, idx, c)
3062
3063 # Too small (very negative indices) - they should be unsupported even
3064 # when support for negative indices is implemented for index_copy_
3065 idx = torch.full((idx_dim,), -tgt_dim - 1, device=device)
3066 with self.assertRaises(IndexError):
3067 a.index_copy_(1, idx, c)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003068
Yu Guo8a450062021-05-12 16:23:48 -07003069 def _prepare_data_for_index_copy_and_add_deterministic(
3070 self, dim: int, device: torch.device
3071 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Shen Li10224432021-08-12 11:39:31 -07003072 assert (dim >= 0 and dim < 3)
Yu Guo8a450062021-05-12 16:23:48 -07003073 a = [5, 4, 3]
3074 a[dim] = 2000
3075 x = torch.zeros(a, device=device)
3076 b = a.copy()
3077 elems = a[dim] * 20
3078 b[dim] = elems
3079 src = torch.rand(b, device=device)
3080 index = torch.randint(a[dim], (elems,), device=device)
3081 return (x, index, src)
Kurt Mohler2cb92042020-12-03 10:55:52 -08003082
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003083 # FIXME: move to test indexing
kshitij12345885a8e52021-11-01 09:21:20 -07003084 @onlyNativeDeviceTypes
Yu Guo8a450062021-05-12 16:23:48 -07003085 def test_index_copy_deterministic(self, device: torch.device) -> None:
Yu Guo72c3ee02021-04-26 12:13:53 -07003086 for dim in range(3):
Shen Li10224432021-08-12 11:39:31 -07003087 x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device)
Yu Guo8a450062021-05-12 16:23:48 -07003088 with DeterministicGuard(True):
3089 y0 = torch.index_copy(x, dim, index, src)
Mike Ruberryc911c302021-05-12 03:30:58 -07003090
Yu Guo8a450062021-05-12 16:23:48 -07003091 x0 = x.clone().detach()
3092 index_list = index.tolist()
3093 for i in range(len(index_list)):
3094 if dim == 0:
3095 x0[index_list[i], :, :] = src[i, :, :]
3096 elif dim == 1:
3097 x0[:, index_list[i], :] = src[:, i, :]
3098 elif dim == 2:
3099 x0[:, :, index_list[i]] = src[:, :, i]
3100
3101 self.assertEqual(x0, y0, atol=0, rtol=0)
3102
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003103 # FIXME: move to test indexing
kshitij12345885a8e52021-11-01 09:21:20 -07003104 @onlyNativeDeviceTypes
Yu Guo8a450062021-05-12 16:23:48 -07003105 def test_index_add_deterministic(self, device: torch.device) -> None:
3106 for dim in range(3):
Shen Li10224432021-08-12 11:39:31 -07003107 x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device)
Yu Guo8a450062021-05-12 16:23:48 -07003108 alpha = random.random() + 1
Yu Guo72c3ee02021-04-26 12:13:53 -07003109 # on CPU it should be deterministic regardless of the deterministic mode
3110 with DeterministicGuard(True):
3111 y0 = torch.index_add(x, dim, index, src, alpha=alpha)
3112 for _ in range(3):
3113 y = torch.index_add(x, dim, index, src, alpha=alpha)
3114 self.assertEqual(y, y0, atol=0, rtol=0)
3115
3116 with DeterministicGuard(False):
3117 for _ in range(3):
3118 y_nd = torch.index_add(x, dim, index, src, alpha=alpha)
3119 self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5)
3120
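    # Note on test_index_add_deterministic above (illustrative reasoning, not part
    # of the original test): on CUDA, index_add relies on atomic additions whose
    # ordering is not fixed, and floating-point addition is not associative, so
    # results may differ slightly between runs unless the deterministic path is
    # requested; hence the exact comparison under DeterministicGuard(True) and the
    # tolerance-based one under DeterministicGuard(False).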
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003121 # FIXME: find a test suite for the put operator
kshitij12345885a8e52021-11-01 09:21:20 -07003122 @onlyNativeDeviceTypes
Yu Guoa07a0192021-05-12 00:29:58 -07003123 def test_index_put_non_accumulate_deterministic(self, device) -> None:
3124 with DeterministicGuard(True):
3125 for i in range(3):
3126 m = random.randint(10, 20)
3127 elems = random.randint(20000, 30000)
3128 values = torch.rand(elems, device=device)
3129 indices = torch.randint(m, (elems,), device=device)
3130 input = torch.rand(m, device=device)
3131 output = input.index_put((indices,), values, accumulate=False)
3132
3133 input_list = input.tolist()
3134 indices_list = indices.tolist()
3135 values_list = values.tolist()
3136 for i, v in zip(indices_list, values_list):
3137 input_list[i] = v
3138
3139 self.assertEqual(output, input_list)
3140
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003141 # FIXME: move to test indexing
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003142 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
Kulin Sethe011a8e2022-05-13 18:28:53 +00003143 @skipIfMps
Nikita Vedeneevb198cf42021-02-01 15:54:13 -08003144 def test_index_fill(self, device, dtype):
3145 x = torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device)
3146 index = torch.tensor([0], device=device)
3147 x.index_fill_(1, index, 0)
3148 self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device))
Philip Meier1f74e082022-02-16 18:25:35 -08003149 if not x.is_complex() and not device == "meta":
Nikita Vedeneevb198cf42021-02-01 15:54:13 -08003150 with self.assertRaisesRegex(RuntimeError, r"Scalar"):
3151 x.index_fill_(1, index, 1 + 1j)
Nikita Vedeneev0048d972021-02-25 00:33:07 -08003152 # Make sure that the result stays 0-dim while applied to
3153 # a 0-dim input
3154 x = torch.tensor(1, dtype=dtype, device=device)
3155 self.assertEqual(0, x.index_fill(0, index, -1).dim())
3156 self.assertEqual(0, x.index_fill_(0, index, -1).dim())
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003157
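    # For illustration (minimal sketch): index_fill_ writes a scalar at the given
    # indices along `dim`, e.g.
    #   >>> torch.ones(2, 3).index_fill_(1, torch.tensor([0, 2]), 5.)
    #   tensor([[5., 1., 5.],
    #           [5., 1., 5.]])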
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003158 # FIXME: move to test indexing
lezcano9d9986f2021-03-19 20:31:51 -07003159 # The test fails for zero-dimensional tensors on XLA
kshitij12345885a8e52021-11-01 09:21:20 -07003160 @onlyNativeDeviceTypes
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003161 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
lezcano9d9986f2021-03-19 20:31:51 -07003162 def test_index_select(self, device, dtype):
3163 num_src, num_out = 3, 5
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003164
lezcano9d9986f2021-03-19 20:31:51 -07003165 def make_arg(batch_sizes, n, dim, contig):
3166 size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
Philip Meier0973c5a2022-02-24 21:47:38 -08003167 return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003168
lezcano9d9986f2021-03-19 20:31:51 -07003169 def ref_index_select(src, dim, idx):
3170 # bfloat16 is just used on GPU, so it's not supported on numpy
3171 if dtype == torch.bfloat16:
3172 src = src.float()
Shen Li10224432021-08-12 11:39:31 -07003173 out = torch.from_numpy(np.take(src.cpu().numpy(), idx.cpu().numpy(), axis=dim))
lezcano9d9986f2021-03-19 20:31:51 -07003174 if dtype == torch.bfloat16:
3175 out = out.to(device=device, dtype=dtype)
3176 return out
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003177
lezcano9d9986f2021-03-19 20:31:51 -07003178 for src_contig, idx_contig in product([True, False], repeat=2):
3179 for other_sizes in ((), (4, 5)):
3180 for dim in range(len(other_sizes)):
3181 src = make_arg(other_sizes, num_src, dim, src_contig)
Philip Meier0973c5a2022-02-24 21:47:38 -08003182 idx = make_tensor(
3183 (num_out,), dtype=torch.int64, device=device, low=0, high=num_src, noncontiguous=not idx_contig
3184 )
lezcano9d9986f2021-03-19 20:31:51 -07003185 out = torch.index_select(src, dim, idx)
3186 out2 = ref_index_select(src, dim, idx)
3187 self.assertEqual(out, out2)
3188
3189 for idx_type in (torch.int32, torch.int64):
3190 other_sizes = (3, 2)
3191 dim = 1
3192 src = make_arg(other_sizes, num_src, dim, True)
Philip Meier0973c5a2022-02-24 21:47:38 -08003193 idx = make_tensor((num_out,), dtype=idx_type, device=device, low=0, high=num_src, noncontiguous=False)
lezcano9d9986f2021-03-19 20:31:51 -07003194 out = torch.index_select(src, dim, idx)
3195 out2 = ref_index_select(src, dim, idx)
3196 self.assertEqual(out, out2)
3197
3198 # Create the 4 possible combinations of scalar sizes for index / source
Philip Meier0973c5a2022-02-24 21:47:38 -08003199 scalars = ((make_tensor(size_s, dtype=dtype, device=device),
Shen Li10224432021-08-12 11:39:31 -07003200 torch.zeros(size_i, dtype=torch.int64, device=device))
3201 for size_s, size_i in product([(), (1,)], repeat=2))
lezcano9d9986f2021-03-19 20:31:51 -07003202 for source, idx in scalars:
3203 out = source.index_select(0, idx)
3204 self.assertEqual(out.item(), source.item())
anjali4119e016f72020-04-16 08:21:49 -07003205
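    # For illustration (minimal sketch): index_select gathers whole slices along
    # `dim` in the order given by the index tensor, e.g.
    #   >>> torch.index_select(torch.arange(6).reshape(3, 2), 0, torch.tensor([2, 0]))
    #   tensor([[4, 5],
    #           [0, 1]])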
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003206 # FIXME: find a test suite for the take operator
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003207 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
lezcanofd02fc52021-04-05 18:03:59 -07003208 def test_take(self, device, dtype):
3209 idx_size = (4,)
3210
3211 make_arg = partial(make_tensor, device=device, dtype=dtype)
3212 make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64)
3213
3214 def ref_take(src, idx):
3215 if dtype == torch.bfloat16:
3216 src = src.half()
3217 src = src.cpu().numpy()
3218 idx = idx.cpu().numpy()
3219 out = torch.from_numpy(np.take(src, idx)).to(device=device, dtype=dtype)
3220 return out
3221
3222 for src_contig, idx_contig, idx_reshape in product([True, False], repeat=3):
3223 for src_size in ((5,), (4, 5)):
Mike Ruberry399b66c2021-04-11 20:37:46 -07003224 src = make_arg(src_size, noncontiguous=not src_contig)
3225 idx = make_idx(idx_size, high=src.numel(), noncontiguous=not idx_contig)
lezcanofd02fc52021-04-05 18:03:59 -07003226 if idx_reshape:
3227 idx = idx.reshape(2, 2)
3228 out = torch.take(src, idx)
3229 out2 = ref_take(src, idx)
3230 self.assertEqual(out, out2)
3231
3232 # Create the 4 possible combinations of scalar sizes for source / index
3233 for size_s, size_i in product([(), (1,)], repeat=2):
3234 source = make_arg(size_s)
3235 idx = make_idx(size_i, high=1)
3236 out = source.take(idx)
3237 self.assertEqual(out.item(), source.item())
3238
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003239 # FIXME: find a test suite for the put operator
lezcanofd02fc52021-04-05 18:03:59 -07003240    # The bool dtype is not supported on GPU. See
3241 # https://github.com/pytorch/pytorch/issues/54317
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003242 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
lezcanofd02fc52021-04-05 18:03:59 -07003243 def test_put(self, device, dtype):
3244 src_size = (4,)
3245
3246 make_arg = partial(make_tensor, device=device, dtype=dtype)
3247 make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64)
3248
3249 def ref_put(dst, idx, src, accumulate):
3250 new_dst = dst.clone(memory_format=torch.contiguous_format).view(-1)
3251 new_idx = idx.contiguous().view(-1)
3252 new_src = src.contiguous().view(-1)
3253 method = new_dst.index_add_ if accumulate else new_dst.index_copy_
3254 return method(0, new_idx, new_src).view_as(dst)
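        # put_ treats `dst` as flattened for indexing, which is why ref_put above can
        # model it with index_copy_/index_add_ on 1-D views; a minimal sketch:
        #   t = torch.zeros(2, 3)
        #   t.put_(torch.tensor([1, 4]), torch.tensor([9., 10.]))  # sets t[0, 1] and t[1, 1]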
3255
Shen Li10224432021-08-12 11:39:31 -07003256 for dst_contig, src_contig, idx_contig, idx_reshape, accumulate in product([True, False], repeat=5):
lezcanofd02fc52021-04-05 18:03:59 -07003257 for dst_size in ((5,), (4, 5)):
Mike Ruberry399b66c2021-04-11 20:37:46 -07003258 dst = make_arg(dst_size, noncontiguous=not dst_contig)
3259 src = make_arg(src_size, noncontiguous=not src_contig)
lezcanofd02fc52021-04-05 18:03:59 -07003260
3261                # With accumulate=True, `put_` should be deterministic on CPU regardless of the inputs
3262                # On CUDA it may not be, but the test has enough tolerance to account for this
3263 if accumulate:
3264 idx = make_idx(src_size, high=dst.numel())
3265 else:
Shen Li10224432021-08-12 11:39:31 -07003266 idx = torch.randperm(dst.numel(), dtype=torch.int64, device=device)[:src_size[0]]
lezcanofd02fc52021-04-05 18:03:59 -07003267 if not idx_contig:
3268 idx = torch.repeat_interleave(idx, 2, dim=-1)[..., ::2]
3269 if idx_reshape:
3270 idx = idx.reshape(2, 2)
3271 out = torch.put(dst, idx, src, accumulate)
3272                # out-of-place
3273 reference = ref_put(dst, idx, src, accumulate)
3274 self.assertEqual(out, reference)
3275
3276 # in-place
3277 dst.put_(idx, src, accumulate)
3278 self.assertEqual(dst, reference)
3279
Shen Li10224432021-08-12 11:39:31 -07003280
lezcanofd02fc52021-04-05 18:03:59 -07003281 # Create the 8 possible combinations of scalar sizes for target / index / source
Shen Li10224432021-08-12 11:39:31 -07003282 scalars = ((make_arg(size_t),
3283 make_idx(size_i, high=1),
3284 make_arg(size_s))
3285 for size_t, size_i, size_s in product([(), (1,)], repeat=3))
lezcanofd02fc52021-04-05 18:03:59 -07003286 for (dest, idx, source), accumulate in product(scalars, [True, False]):
3287 dest_init = dest.clone()
3288            # out-of-place
3289 out = torch.put(dest, idx, source, accumulate=accumulate)
3290 # in-place
3291 dest1 = dest.clone()
3292 dest1.put_(idx, source, accumulate=accumulate)
3293 for d in [out, dest1]:
3294 if accumulate:
3295 self.assertEqual(d.item(), (dest_init + source).item())
3296 else:
3297 self.assertEqual(d.item(), source.item())
3298
3299 # Empty case
3300 dest = make_arg((3, 2))
3301 reference = dest.clone()
3302 idx = make_idx((0,), high=1)
3303 source = make_arg((0,))
3304 for accumulate in [True, False]:
3305 out = torch.put(dest, idx, source, accumulate=accumulate)
3306 self.assertEqual(out, reference)
3307 dest.put_(idx, source, accumulate=accumulate)
3308 self.assertEqual(dest, reference)
3309
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003310 # FIXME: find a test suite for the put operator
lezcanofd02fc52021-04-05 18:03:59 -07003311    # The bool dtype is not supported on GPU. See
3312 # https://github.com/pytorch/pytorch/issues/54317
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003313 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
lezcanofd02fc52021-04-05 18:03:59 -07003314 def test_put_accumulate(self, device, dtype):
3315 # Test for parallel adds with accumulate == True
3316 low_precision = dtype == torch.half or dtype == torch.bfloat16
3317        # Use fewer elements to avoid overflow with low-precision dtypes
3318        # The grain size is 3000, so the larger size lets the CPU for-loop run in parallel
3319        sizes = ((100,),) if low_precision else ((200,), (3002,))
3320        # bfloat16 accuracy is particularly poor here
3321 # This operation is nondeterministic on GPU, so we are generous with the rtol
3322 rtol, atol = (1e-1, 1e-2) if low_precision else (1e-3, 1e-4)
3323
3324 make_arg = partial(make_tensor, low=-2, high=3, device=device, dtype=dtype)
3325 # Dump everything into the 0-th position
3326 make_idx = partial(torch.zeros, device=device, dtype=torch.int64)
3327 args = ((make_idx(size), make_arg(size)) for size in sizes)
3328
3329 for idx, source in args:
3330 orig = make_arg((1,))
3331 out = orig.put(idx, source, accumulate=True)
3332 self.assertEqual(out, orig + source.sum(), rtol=rtol, atol=atol)
3333
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003334 # FIXME: find a test suite for the take operator
Kulin Sethe011a8e2022-05-13 18:28:53 +00003335 @skipIfMps
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003336 def test_take_empty(self, device):
3337 for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
3338 for indices_shape in [(0,), (0, 1, 2, 0)]:
3339 input = torch.empty(input_shape, device=device)
3340 indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
Mike Ruberry13120bf2020-05-27 06:28:05 -07003341 self.assertEqual(indices, torch.take(input, indices), exact_dtype=False)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003342
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003343 # FIXME: find a test suite for the put operator
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003344 def test_put_empty(self, device):
3345 for dst_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
3346 for indices_shape in [(0,), (0, 1, 2, 0)]:
3347 for accumulate in [False, True]:
3348 dst = torch.randn(dst_shape, device=device)
Shen Li10224432021-08-12 11:39:31 -07003349 indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003350 src = torch.randn(indices_shape, device=device)
3351 self.assertEqual(dst, dst.put_(indices, src, accumulate=accumulate))
3352
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003353 # FIXME: port to test_scatter_gather_ops.py
Yukio Siraichi84061da2021-06-08 13:35:56 -07003354 def scatter_allow_reduce(self, device, dtype, reduceop):
3355 device_type = torch.device(device).type
Shen Li10224432021-08-12 11:39:31 -07003356 return device_type != 'cuda' or (reduceop == 'multiply' and dtype.is_floating_point)
Yukio Siraichi84061da2021-06-08 13:35:56 -07003357
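    # Reminder of the scatter-reduce semantics exercised by the tests below (dim=0):
    #   input[index[i][j]][j] op= src[i][j]
    # where op is += for reduce="add" and *= for reduce="multiply"; only the columns
    # covered by `index` are touched.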
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003358 @dtypes(*floating_and_complex_types())
3359 @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3360 @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
Sameer Deshmukhe18a2212020-09-16 23:21:38 -07003361 def test_scatter_reduce_operations_to_large_input(self, device, dtype):
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003362 index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3363 test_data = [
Shen Li10224432021-08-12 11:39:31 -07003364 (torch.zeros(4, 4, device=device, dtype=dtype),
3365 torch.ones(2, 2, device=device, dtype=dtype),
3366 torch.tensor([[0, 0, 0, 0],
3367 [1, 0, 0, 0],
3368 [1, 0, 0, 0],
3369 [0, 0, 0, 0]],
3370 device=device, dtype=dtype), "add"),
3371 (torch.tensor([2], device=device, dtype=dtype).repeat(4, 4),
3372 torch.tensor([6], device=device, dtype=dtype).repeat(2, 2),
3373 torch.tensor([[2, 2, 2, 2],
3374 [12, 2, 2, 2],
3375 [12, 2, 2, 2],
3376 [2, 2, 2, 2]], device=device, dtype=dtype), "multiply"),
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003377 ]
3378
3379 for input, src, result, operation in test_data:
Yukio Siraichi84061da2021-06-08 13:35:56 -07003380 if not self.scatter_allow_reduce(device, dtype, operation):
Sameer Deshmukhe18a2212020-09-16 23:21:38 -07003381 continue
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003382 input.scatter_(0, index, src, reduce=operation)
3383 self.assertEqual(input, result)
3384
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003385 @dtypes(*floating_and_complex_types())
3386 @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3387 @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
Sameer Deshmukhe18a2212020-09-16 23:21:38 -07003388 def test_scatter_reduce_scalar(self, device, dtype):
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003389 index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3390 test_data = [
Shen Li10224432021-08-12 11:39:31 -07003391 (torch.zeros(4, 4, device=device, dtype=dtype), 1,
3392 torch.tensor([[0, 0, 0, 0],
3393 [1, 0, 0, 0],
3394 [1, 0, 0, 0],
3395 [0, 0, 0, 0]],
3396 device=device, dtype=dtype), "add"),
3397 (torch.tensor([2], device=device, dtype=dtype).repeat(4, 4), 2,
3398 torch.tensor([[2, 2, 2, 2],
3399 [4, 2, 2, 2],
3400 [4, 2, 2, 2],
3401 [2, 2, 2, 2]], device=device, dtype=dtype), "multiply"),
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003402 ]
3403
3404 for input, src, result, operation in test_data:
Yukio Siraichi84061da2021-06-08 13:35:56 -07003405 if not self.scatter_allow_reduce(device, dtype, operation):
Sameer Deshmukhe18a2212020-09-16 23:21:38 -07003406 continue
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003407 input.scatter_(0, index, src, reduce=operation)
3408 self.assertEqual(input, result)
3409
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003410 # FIXME: port to test_scatter_gather_ops.py
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003411 # TODO: remove this after scatter_add_ is deprecated.
3412 def test_scatter_add_non_unique_index(self, device):
3413 height = 2
3414 width = 65536
3415 input = torch.ones(height, width, device=device)
3416 index = torch.zeros(height, width, dtype=torch.long, device=device)
3417 src = torch.ones(height, width, device=device)
3418 input.scatter_add_(0, index, src)
3419
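        # Every index is 0, so in each column row 0 accumulates its initial 1 plus the
        # two scattered 1s (one per row of src) -> 3, while row 1 keeps its initial 1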
Shen Li10224432021-08-12 11:39:31 -07003420 self.assertEqual(input,
3421 torch.tensor([[3], [1]], device=device,
3422 dtype=torch.float32).repeat(1, width))
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003423
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003424 @dtypes(*floating_and_complex_types())
3425 @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3426 @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
Sameer Deshmukhe18a2212020-09-16 23:21:38 -07003427 def test_scatter_reduce_non_unique_index(self, device, dtype):
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003428 height = 2
3429 width = 2
3430 index = torch.zeros(height, width, dtype=torch.long, device=device)
3431 test_data = [
Shen Li10224432021-08-12 11:39:31 -07003432 (torch.ones(height, width, device=device, dtype=dtype),
3433 torch.ones(height, width, device=device, dtype=dtype),
3434 torch.tensor([[3], [1]], device=device, dtype=dtype).repeat(1, width), "add"),
3435 (torch.tensor([2], device=device, dtype=dtype).repeat(height, width),
3436 torch.tensor([2], device=device, dtype=dtype).repeat(height, width),
3437 torch.tensor([[8], [2]], device=device,
3438 dtype=dtype).repeat(1, width), "multiply"),
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003439 ]
3440
3441 for input, src, result, operation in test_data:
Yukio Siraichi84061da2021-06-08 13:35:56 -07003442 if not self.scatter_allow_reduce(device, dtype, operation):
Sameer Deshmukhe18a2212020-09-16 23:21:38 -07003443 continue
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003444 input.scatter_(0, index, src, reduce=operation)
Shen Li10224432021-08-12 11:39:31 -07003445 self.assertEqual(input, result, msg=f"result: {result} input: {input} method: {str(operation)}")
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003446
Yukio Siraichi84061da2021-06-08 13:35:56 -07003447 @onlyCUDA
PyTorch MergeBotd7847ed2022-06-29 18:06:01 +00003448 @dtypes(*complex_types())
Sameer Deshmukhe18a2212020-09-16 23:21:38 -07003449 def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype):
3450 height = 2
3451 width = 2
3452 index = torch.zeros(height, width, dtype=torch.long, device=device)
3453 input = torch.ones(height, width, device=device, dtype=dtype)
3454 src = torch.ones(height, width, device=device, dtype=dtype)
3455 with self.assertRaises(RuntimeError):
3456 input.scatter_(0, index, src, reduce="multiply")
Sameer Deshmukh9ca4a462020-06-29 15:50:03 -07003457
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003458 # FIXME: port to test_scatter_gather_ops.py
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003459 def test_scatter_to_large_input(self, device):
3460 input = torch.zeros(4, 4, device=device)
3461 src = torch.ones(2, 2, device=device)
3462 index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3463 input.scatter_(0, index, src)
Shen Li10224432021-08-12 11:39:31 -07003464 self.assertEqual(input, torch.tensor([[0, 0, 0, 0],
3465 [1, 0, 0, 0],
3466 [1, 0, 0, 0],
3467 [0, 0, 0, 0]], device=device, dtype=torch.float32))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003468
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003469 # FIXME: port to test_scatter_gather_ops.py
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003470 def test_scatter_add_to_large_input(self, device):
3471 input = torch.zeros(4, 4, device=device)
3472 src = torch.ones(2, 2, device=device)
3473 index = torch.tensor([[1], [2]], device=device, dtype=torch.long)
3474 input.scatter_add_(0, index, src)
Shen Li10224432021-08-12 11:39:31 -07003475 self.assertEqual(input, torch.tensor([[0, 0, 0, 0],
3476 [1, 0, 0, 0],
3477 [1, 0, 0, 0],
3478 [0, 0, 0, 0]], device=device, dtype=torch.float32))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003479
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003480 # FIXME: port to test_scatter_gather_ops.py
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003481 def test_scatter_bool(self, device):
3482 x = torch.tensor([[True, True, True], [True, True, True]], device=device)
3483 res = torch.zeros(3, 3, dtype=torch.bool, device=device)
3484 res = res.scatter_(0, torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), x)
Shen Li10224432021-08-12 11:39:31 -07003485 self.assertEqual(res, torch.tensor([[True, False, False],
3486 [False, True, False],
3487 [False, False, True]], device=device))
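        # With dim=0, res[index[i][j]][j] = x[i][j], so column j of x is routed to row j
        # and both rows of x land on the diagonal of the 3x3 result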
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003488
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003489 # FIXME: port to test_scatter_gather_ops.py
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003490 def test_scatter_add_bool(self, device):
Shen Li10224432021-08-12 11:39:31 -07003491 x = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]], device=device)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003492 res = torch.zeros(3, 5, dtype=torch.bool, device=device)
Shen Li10224432021-08-12 11:39:31 -07003493 res = res.scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], device=device), x)
3494 self.assertEqual(res, torch.tensor([[True, True, True, True, True],
3495 [False, True, False, True, False],
3496 [True, False, True, False, True]], device=device))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003497
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003498 # FIXME: find a test suite for the masked scatter operator
kshitij12345885a8e52021-11-01 09:21:20 -07003499 @onlyNativeDeviceTypes
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003500 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
kshitij12345eaf5ca02021-01-27 13:58:43 -08003501 def test_masked_scatter(self, device, dtype):
3502 dt = dtype
3503 with warnings.catch_warnings(record=True) as w:
3504 warnings.simplefilter("always")
3505 for maskType in [torch.uint8, torch.bool]:
3506 num_copy, num_dest = 3, 10
Shen Li10224432021-08-12 11:39:31 -07003507 dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt, device=device)
kshitij12345eaf5ca02021-01-27 13:58:43 -08003508 dest2 = dest.clone()
3509 dest_ones = dest.clone()
3510 dest_ones_expected = dest.clone()
Shen Li10224432021-08-12 11:39:31 -07003511 src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt, device=device)
3512 src_ones = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt, device=device)
3513 mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=maskType, device=device)
kshitij12345eaf5ca02021-01-27 13:58:43 -08003514
3515 if dt == torch.bool:
3516 # torch.bool is a special case and is being tested
3517 # in a separate test
3518 return
3519
kshitij12345eaf5ca02021-01-27 13:58:43 -08003520 dest.masked_scatter_(mask, src)
3521 j = 0
3522 for i in range(num_dest):
3523 if mask[i]:
3524 dest2[i] = src[j]
3525 dest_ones_expected[i] = src_ones[j]
3526 j += 1
3527 self.assertEqual(dest, dest2, atol=0, rtol=0)
3528
3529 dest_ones.masked_scatter_(mask, src_ones)
3530 self.assertEqual(dest_ones, dest_ones_expected, atol=0, rtol=0)
3531
Xiang Gaoc8833342021-05-25 10:59:43 -07003532            # Bounds checking on CUDA is done inside the kernel
3533            # in order to avoid synchronization, but this means
3534            # we cannot clear the failures, so there is no way
3535            # to trigger the failure and then recover in this test.
Shen Li10224432021-08-12 11:39:31 -07003536 if self.device_type != 'cuda':
Xiang Gaoc8833342021-05-25 10:59:43 -07003537 # make src smaller. this should fail
3538 src = torch.zeros(num_copy - 1, dtype=dt, device=device)
3539 with self.assertRaises(RuntimeError):
3540 dest.masked_scatter_(mask, src)
kshitij12345eaf5ca02021-01-27 13:58:43 -08003541
Xiang Gaoc8833342021-05-25 10:59:43 -07003542 # empty tensor
3543 dest = torch.empty((5, 0, 5), dtype=dt, device=device)
3544 mask = torch.ones_like(dest, dtype=maskType, device=device)
3545 src = torch.empty((0,), dtype=dt, device=device)
3546 dest.masked_scatter_(mask, src)
3547
3548 dest = torch.empty((5, 0, 5), dtype=dt, device=device)
3549 mask = torch.ones((5, 1, 5), dtype=maskType, device=device)
3550 src = torch.empty((0,), dtype=dt, device=device)
3551 dest.masked_scatter_(mask, src)
3552
Shen Li10224432021-08-12 11:39:31 -07003553 if self.device_type != 'cuda':
Xiang Gaoc8833342021-05-25 10:59:43 -07003554 self.assertEqual(len(w), 5)
3555 else:
3556 self.assertEqual(len(w), 4)
kshitij12345eaf5ca02021-01-27 13:58:43 -08003557
Shen Li10224432021-08-12 11:39:31 -07003558 warn = 'masked_scatter_ received a mask with dtype torch.uint8,'
kshitij12345eaf5ca02021-01-27 13:58:43 -08003559 for wi in w:
3560 self.assertEqual(str(wi.message)[0:55], str(warn))
3561
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003562 # FIXME: find a test suite for the masked scatter operator
Kulin Sethe011a8e2022-05-13 18:28:53 +00003563 @skipIfMps
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003564 def test_masked_scatter_bool_tensor(self, device):
3565 src = torch.tensor([True, True, True], device=device)
3566 dst = torch.tensor([False, False, False], device=device)
3567 mask = torch.tensor([False, True, False], device=device)
3568
3569 dst.masked_scatter_(mask, src)
3570 self.assertEqual(dst, torch.tensor([False, True, False], device=device))
3571
3572 mask = torch.tensor([True, False, True], device=device)
3573 dst = dst.masked_scatter(mask, src)
3574 self.assertEqual(dst, torch.tensor([True, True, True], device=device))
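        # masked_scatter consumes source elements in order: positions 0 and 2 are True
        # here, so they receive src[0] and src[1]; position 1 keeps the True written by
        # the first call, leaving every element True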
3575
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003576 # FIXME: find a test suite for the masked scatter operator
3577 # test_scatter_gather_ops or test_masked_ops?
Xiang Gaoc8833342021-05-25 10:59:43 -07003578 @onlyCUDA
Shen Li10224432021-08-12 11:39:31 -07003579 @largeTensorTest('30GB')
Xiang Gaoc8833342021-05-25 10:59:43 -07003580 def test_masked_scatter_large_tensor(self, device):
Shen Li10224432021-08-12 11:39:31 -07003581 t_cpu = torch.empty(2**31 + 1, dtype=torch.bool).random_()
Xiang Gaoc8833342021-05-25 10:59:43 -07003582 t = t_cpu.to(device)
3583 result_cpu = t_cpu.masked_scatter(t_cpu, t_cpu)
3584 result = t.masked_scatter(t, t)
3585 self.assertEqual(result, result_cpu)
3586
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003587 # FIXME: find a test suite for the masked select operator
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003588 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003589 def test_masked_select(self, device, dtype):
Shen Li10224432021-08-12 11:39:31 -07003590 if device == 'cpu':
3591 warn = 'masked_select received a mask with dtype torch.uint8,'
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003592 else:
Shen Li10224432021-08-12 11:39:31 -07003593 warn = 'indexing with dtype torch.uint8 is now deprecated, pl'
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003594 for maskType in [torch.uint8, torch.bool]:
3595 num_src = 10
Shen Li10224432021-08-12 11:39:31 -07003596 src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype, device=device)
Natalia Gimelshein7a570882020-08-03 18:41:57 -07003597 mask = torch.randint(2, (num_src,), device=device, dtype=maskType)
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003598
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003599 with warnings.catch_warnings(record=True) as w:
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003600 dst = src.masked_select(mask)
3601 if maskType is torch.uint8:
3602 self.assertEqual(len(w), 1)
3603 self.assertEqual(str(w[0].message)[0:53], str(warn))
3604 dst2 = []
3605 for i in range(num_src):
3606 if mask[i]:
3607 dst2 += [src[i]]
Mike Ruberry13120bf2020-05-27 06:28:05 -07003608 self.assertEqual(dst, torch.tensor(dst2), atol=0, rtol=0)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003609
Mike Ruberry2f840b12020-07-30 22:17:34 -07003610 dst3 = torch.empty(0, device=device, dtype=dtype)
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003611 torch.masked_select(src, mask, out=dst3)
Mike Ruberry13120bf2020-05-27 06:28:05 -07003612 self.assertEqual(dst3, torch.tensor(dst2, dtype=dst3.dtype), atol=0, rtol=0)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003613
Natalia Gimelshein7a570882020-08-03 18:41:57 -07003614            # half is not supported on CPU, so skip the remaining test cases
Shen Li10224432021-08-12 11:39:31 -07003615 if dtype == torch.half and torch.device(device).type == 'cpu':
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003616 return
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003617
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003618 # Ensure that masks are expanded to match tensor properly
Natalia Gimelshein7a570882020-08-03 18:41:57 -07003619 a = torch.rand(100, 100, device=device).mul(100).to(dtype)
3620 mask_first_el_each_row = torch.zeros(100, device=device, dtype=torch.bool)
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003621 mask_first_el_each_row[0] = True
3622 a_masked = a.masked_select(mask_first_el_each_row)
3623 self.assertEqual(a_masked, a[:, 0])
3624
Natalia Gimelshein7a570882020-08-03 18:41:57 -07003625 mask_first_row = torch.zeros(100, 1, device=device, dtype=torch.bool)
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003626 mask_first_row[0][0] = True
3627 a_masked = a.masked_select(mask_first_row)
3628 self.assertEqual(a_masked, a[0, :])
3629
3630 # Ensure that tensor is expanded to match mask properly
Natalia Gimelshein7a570882020-08-03 18:41:57 -07003631 a = torch.rand(100, device=device).mul(100).to(dtype)
Shen Li10224432021-08-12 11:39:31 -07003632 mask_copy_3_times = torch.tensor([[True], [True], [False], [True]], device=device)
Kurt Mohlerce3555a2020-04-14 13:59:13 -07003633 a_masked = a.masked_select(mask_copy_3_times)
3634 self.assertEqual(a_masked, a.unsqueeze(0).expand(3, 100).flatten())
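        # The (100,) tensor broadcasts against the (4, 1) mask to shape (4, 100); three
        # of the four mask rows are True, so the selection is `a` repeated three times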
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003635
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003636 # FIXME: find a test suite for the masked select operator
Natalia Gimelshein7a570882020-08-03 18:41:57 -07003637 def test_masked_select_discontiguous(self, device):
3638 for size in (10, 200):
3639 vals = torch.rand(size, size, device=device)
3640 mask = torch.full((size, size), False, dtype=torch.bool, device=device)
3641 mask[:, ::2] = True
3642 vals_list = (vals, vals.t())
3643 mask_list = (mask, mask.t())
3644 out_dc = torch.empty(size * size, device=device)[::2]
3645 for v, m in product(vals_list, mask_list):
3646 if m.is_contiguous():
anjali4115d80a482021-09-01 16:11:38 -07003647 expected = v[:, ::2].clone().reshape((-1, ))
Natalia Gimelshein7a570882020-08-03 18:41:57 -07003648 else:
anjali4115d80a482021-09-01 16:11:38 -07003649 expected = v[::2].clone().reshape((-1, ))
Natalia Gimelshein7a570882020-08-03 18:41:57 -07003650 out = torch.masked_select(v, m)
3651 self.assertEqual(out, expected, atol=0, rtol=0)
3652 torch.masked_select(v, m, out=out_dc)
3653 self.assertEqual(out_dc, expected, atol=0, rtol=0)
3654
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003655 # FIXME: find a test suite for the masked fill operator
Nikita Shulgabfac65d2022-03-30 14:13:21 -07003656 @dtypes(*product(all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16), (torch.uint8, torch.bool)))
kshitij1234576866292021-02-09 22:55:00 -08003657 def test_masked_fill(self, device, dtypes):
3658 dtype = dtypes[0]
3659 mask_dtype = dtypes[1]
3660 with warnings.catch_warnings(record=True) as w:
3661 warnings.simplefilter("always")
3662
3663 num_dest = 10
3664 dst = torch.zeros(num_dest, dtype=dtype)
3665 mask = torch.randint(2, (num_dest,), dtype=mask_dtype)
3666 val = random.random()
3667 dst2 = dst.clone()
3668
3669 dst.masked_fill_(mask, val)
3670 for i in range(num_dest):
3671 if mask[i]:
3672 dst2[i] = val
3673 self.assertEqual(dst, dst2, atol=0, rtol=0)
3674
3675 # test non-contiguous case
Shen Li10224432021-08-12 11:39:31 -07003676 dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1))
kshitij1234576866292021-02-09 22:55:00 -08003677 dst2 = dst.contiguous()
3678 if dtype.is_complex:
3679 mask = dst.abs() > 0
3680 else:
3681 mask = dst > 0
3682 self.assertTrue(not dst.is_contiguous())
3683 self.assertTrue(dst2.is_contiguous())
3684 dst.masked_fill_(mask.to(mask_dtype), val)
3685 dst2.masked_fill_(mask.to(mask_dtype), val)
3686 self.assertEqual(dst, dst2, atol=0, rtol=0)
3687
3688 if mask_dtype == torch.uint8:
3689 self.assertEqual(len(w), 3)
3690
Shen Li10224432021-08-12 11:39:31 -07003691 warn = 'masked_fill_ received a mask with dtype torch.uint8,'
kshitij1234576866292021-02-09 22:55:00 -08003692 for wi in w:
3693 self.assertEqual(str(wi.message)[0:52], str(warn))
3694 else:
3695 self.assertEqual(len(w), 0)
Natalia Gimelshein7a570882020-08-03 18:41:57 -07003696
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003697 # FIXME: find a test suite for the masked fill operator
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003698 def test_masked_fill_bool_tensor(self, device):
3699 dst = torch.tensor([True, False, True], device=device)
3700 mask = torch.tensor([False, True, False], device=device)
3701
3702 dst.masked_fill_(mask, True)
3703 self.assertEqual(dst, torch.tensor([True, True, True], device=device))
3704
3705 dst = dst.masked_fill(mask, False)
3706 self.assertEqual(dst, torch.tensor([True, False, True], device=device))
3707
3708 def test_tensor_shape_empty(self, device):
3709 x = torch.randn((0, 1, 3, 0), device=device)
3710 # flatten
3711 self.assertEqual((0,), torch.flatten(x, 0, 3).shape)
3712 self.assertEqual((0, 0), torch.flatten(x, 0, 2).shape)
3713 self.assertEqual((0, 3, 0), torch.flatten(x, 1, 2).shape)
3714
3715 # squeeze, unsqueeze
3716 self.assertEqual((0, 1, 1, 3, 0), torch.unsqueeze(x, 1).shape)
3717 self.assertEqual((0, 3, 0), torch.squeeze(x, 1).shape)
3718 self.assertEqual((0, 3, 0), torch.squeeze(x).shape)
3719
3720 # transpose, t
3721 self.assertEqual((0, 0, 3, 1), torch.transpose(x, 1, 3).shape)
3722 y = torch.randn((5, 0), device=device)
3723 self.assertEqual((0, 5), y.t().shape)
3724
3725 # select
3726 self.assertEqual((0, 1, 0), torch.select(x, 2, 2).shape)
3727
3728 # repeat, permute
3729 self.assertEqual((9, 0, 5, 6, 0), x.repeat(9, 7, 5, 2, 3).shape)
3730 self.assertEqual((3, 0, 0, 1), x.permute(2, 3, 0, 1).shape)
3731
3732 # diagonal, diagflat
3733 self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device)).shape)
3734 self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device)).shape)
3735        # offsets that go past the end are valid
Shen Li10224432021-08-12 11:39:31 -07003736 self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device), offset=1).shape)
3737 self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device), offset=1).shape)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003738        # check offsets off the end for non-zero-sized tensors
Shen Li10224432021-08-12 11:39:31 -07003739 self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=45252).shape)
3740 self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=-45252).shape)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003741
3742 self.assertEqual((0, 0), torch.diagflat(torch.tensor([], device=device)).shape)
Shen Li10224432021-08-12 11:39:31 -07003743 self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([], device=device), offset=1))
3744 self.assertEqual((0, 0), torch.diagflat(torch.tensor([[]], device=device)).shape)
3745 self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([[]], device=device), offset=1))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003746
3747 # stack, split, chunk
3748 self.assertEqual((4, 0, 1, 3, 0), torch.stack((x, x, x, x)).shape)
Shen Li10224432021-08-12 11:39:31 -07003749 self.assertEqual([(0, 1, 3, 0)],
3750 [z.shape for z in torch.chunk(x, 1, dim=0)])
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003751
Shen Li10224432021-08-12 11:39:31 -07003752 self.assertEqual([(0, 1, 3, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=0)])
3753 self.assertEqual([(0, 1, 1, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=2)])
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003754
3755        # NOTE: split_with_sizes behaves differently from NumPy in that it
3756 # takes sizes rather than offsets
Shen Li10224432021-08-12 11:39:31 -07003757 self.assertEqual([(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)],
3758 [z.shape for z in torch.split(x, (0, 1, 2), dim=2)])
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003759
3760 self.assertRaises(RuntimeError, lambda: torch.split(x, 0, dim=1))
3761 # This is strange because the split size is larger than the dim size, but consistent with
3762 # how split handles that case generally (when no 0s are involved).
3763 self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)])
3764 self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)])
3765
3766 # functions that operate over a dimension but don't reduce.
3767 def test_dim_function_empty(self, device):
3768 shape = (0, 1, 2, 0)
3769 x = torch.randn(shape, device=device)
3770
3771 # size stride
3772 self.assertEqual(0, x.size(3))
3773 self.assertEqual(2, x.size(2))
3774 self.assertEqual(2, x.stride(0))
3775 self.assertEqual(1, x.stride(2))
3776
3777 self.assertEqual(x, torch.nn.functional.glu(x, 0))
3778 self.assertEqual((0, 1, 1, 0), torch.nn.functional.glu(x, 2).shape)
3779
3780 # softmax, logsoftmax
3781 self.assertEqual(x, torch.nn.functional.softmax(x, 0))
3782 self.assertEqual(x, torch.nn.functional.softmax(x, 2))
3783 self.assertEqual(x, torch.nn.functional.softmax(x, 3))
3784
3785 self.assertEqual(x, torch.nn.functional.log_softmax(x, 0))
3786 self.assertEqual(x, torch.nn.functional.log_softmax(x, 2))
3787 self.assertEqual(x, torch.nn.functional.log_softmax(x, 3))
3788
anjali4115b815d92020-01-17 10:45:36 -08003789 # cumsum, cumprod, cummax, cummin
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003790 self.assertEqual(shape, torch.cumsum(x, 0).shape)
3791 self.assertEqual(shape, torch.cumsum(x, 2).shape)
3792 self.assertEqual(shape, torch.cumprod(x, 0).shape)
3793 self.assertEqual(shape, torch.cumprod(x, 2).shape)
anjali4118dc67a02020-01-14 16:36:56 -08003794 self.assertEqual(shape, torch.cummax(x, 0)[0].shape)
3795 self.assertEqual(shape, torch.cummax(x, 2)[0].shape)
anjali4115b815d92020-01-17 10:45:36 -08003796 self.assertEqual(shape, torch.cummin(x, 0)[0].shape)
3797 self.assertEqual(shape, torch.cummin(x, 2)[0].shape)
kshitij1234534877442020-05-21 09:09:41 -07003798 self.assertEqual(shape, torch.logcumsumexp(x, 0).shape)
3799 self.assertEqual(shape, torch.logcumsumexp(x, 2).shape)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003800
3801 # flip
3802 self.assertEqual(x, x.flip(0))
3803 self.assertEqual(x, x.flip(2))
3804
3805 # roll
3806 self.assertEqual(x, x.roll(0, 1).roll(0, -1))
3807 self.assertEqual(x, x.roll(1, x.size(1)))
3808 self.assertEqual(x, x.roll(1))
3809 self.assertEqual(x, x.roll((1, 1), (3, 1)))
3810
3811 # unbind
3812 self.assertEqual((), x.unbind(0))
Shen Li10224432021-08-12 11:39:31 -07003813 self.assertEqual((torch.empty((0, 1, 0), device=device), torch.empty((0, 1, 0), device=device)),
3814 x.unbind(2))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003815
3816 # cross
3817 y = torch.randn((0, 1, 3, 0), device=device)
3818 self.assertEqual(y.shape, torch.cross(y, y).shape)
3819
3820 # renorm
3821 self.assertEqual(shape, torch.renorm(x, 1, 0, 5).shape)
3822 self.assertEqual(shape, torch.renorm(x, 1, 2, 5).shape)
3823
3824 # sort
3825 self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=0)])
3826 self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=2)])
3827
3828 # topk
3829 self.assertEqual([shape, shape], [z.shape for z in torch.topk(x, 0, dim=0)])
Shen Li10224432021-08-12 11:39:31 -07003830 self.assertEqual([(0, 1, 1, 0), (0, 1, 1, 0)], [z.shape for z in torch.topk(x, 1, dim=2)])
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003831
3832 y = torch.randn((2, 3, 4), device=device)
3833 self.assertEqual([(2, 3, 0), (2, 3, 0)], [z.shape for z in torch.topk(y, 0)])
3834
3835 # gather
Shen Li10224432021-08-12 11:39:31 -07003836 self.assertEqual(shape, torch.gather(x, 0, torch.empty(shape, dtype=torch.int64, device=device)).shape)
3837 self.assertEqual(shape, torch.gather(x, 2, torch.empty(shape, dtype=torch.int64, device=device)).shape)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003838 larger_shape = torch.empty((0, 1, 3, 0), dtype=torch.int64, device=device)
3839 self.assertEqual(larger_shape.shape, torch.gather(x, 2, larger_shape).shape)
3840 smaller_shape = torch.empty((0, 1, 0, 0), dtype=torch.int64, device=device)
3841 self.assertEqual(smaller_shape.shape, torch.gather(x, 2, smaller_shape).shape)
3842 y = torch.randn((2, 3, 4), device=device)
Shen Li10224432021-08-12 11:39:31 -07003843 self.assertEqual((0, 3, 4),
3844 torch.gather(y, 0, torch.empty((0, 3, 4), dtype=torch.int64, device=device)).shape)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003845
3846 # scatter, scatter_add
3847 for dim in [0, 2]:
3848 y = torch.randn(shape, device=device)
3849 y_src = torch.randn(shape, device=device)
3850 ind = torch.empty(shape, dtype=torch.int64, device=device)
3851 self.assertEqual(shape, y.scatter_(dim, ind, y_src).shape)
3852 self.assertEqual(shape, y.scatter_add_(dim, ind, y_src).shape)
3853
3854 z = torch.randn((2, 3, 4), device=device)
3855 z_src = torch.randn((2, 3, 4), device=device)
Shen Li10224432021-08-12 11:39:31 -07003856 self.assertEqual(z, z.scatter_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src))
3857 self.assertEqual(z, z.scatter_add_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003858
3859 # index_fill, index_copy, index_add
3860 c = x.clone()
3861 c_clone = c.clone()
3862 ind_empty = torch.tensor([], dtype=torch.int64, device=device)
3863 ind_01 = torch.tensor([0, 1], dtype=torch.int64, device=device)
3864 self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
3865 self.assertEqual(c_clone, c.index_fill_(2, ind_empty, -1))
Yu Guof69c9902022-05-24 05:48:32 +00003866 self.assertEqual(c_clone, c.index_fill_(2, ind_01, -1))
Shen Li10224432021-08-12 11:39:31 -07003867 self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device)))
3868 self.assertEqual(c_clone, c.index_copy_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device)))
3869 self.assertEqual(c_clone, c.index_copy_(2, ind_01, torch.empty((0, 1, 2, 0), device=device)))
3870 self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device)))
3871 self.assertEqual(c_clone, c.index_add_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device)))
3872 self.assertEqual(c_clone, c.index_add_(2, ind_01, torch.empty((0, 1, 2, 0), device=device)))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003873
3874 c = torch.randn((0, 1, 2), device=device)
3875 c_clone = c.clone()
3876 self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
Shen Li10224432021-08-12 11:39:31 -07003877 self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
3878 self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003879 self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
Shen Li10224432021-08-12 11:39:31 -07003880 self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
3881 self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003882
3883 # index fill/copy/add non-empty
3884 z = torch.randn((2, 3, 4), device=device)
3885 self.assertEqual(z, z.index_fill_(0, ind_empty, -1))
3886 z = torch.randn((2, 3, 4), device=device)
Shen Li10224432021-08-12 11:39:31 -07003887 self.assertEqual(z, z.index_copy_(0, ind_empty, torch.empty((0, 3, 4), device=device)))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003888 z = torch.randn((2, 3, 4), device=device)
Shen Li10224432021-08-12 11:39:31 -07003889 self.assertEqual(z, z.index_add_(0, ind_empty, torch.empty((0, 3, 4), device=device)))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003890
3891 # index_select
3892 self.assertEqual(x, x.index_select(0, ind_empty))
3893 self.assertEqual((0, 1, 0, 0), x.index_select(2, ind_empty).shape)
3894 self.assertEqual(x, x.index_select(2, ind_01))
3895 z = torch.randn((2, 3, 4), device=device) # non-empty
3896 self.assertEqual((0, 3, 4), z.index_select(0, ind_empty).shape)
3897 c = torch.randn((0, 1, 2), device=device)
3898 self.assertEqual(c, c.index_select(0, ind_empty))
3899 c = torch.randn((0, 1, 2), device=device)
3900 self.assertEqual(c, c.index_select(0, ind_empty))
Yu Guof69c9902022-05-24 05:48:32 +00003901 w = torch.randn((0, 3), device=device)
3902 self.assertEqual((0, 2), w.index_select(1, ind_01).shape)
3903 w = torch.randn((3, 0), device=device)
3904 self.assertEqual((2, 0), w.index_select(0, ind_01).shape)
3905 ind_01_int32 = torch.tensor([0, 1], dtype=torch.int32, device=device)
3906 self.assertEqual((2, 0), w.index_select(0, ind_01_int32).shape)
3907 if device == 'cpu':
3908 w = torch.randn((0, 3), device=device)
3909 with self.assertRaisesRegex(RuntimeError, "self indexing axis dim should be positive"):
3910 torch.index_select(w, 0, ind_01)
3911 ind_05 = torch.tensor([0, 5], dtype=torch.int64, device=device)
3912 with self.assertRaisesRegex(RuntimeError, "INDICES element is out of DATA bounds"):
3913 torch.index_select(w, 1, ind_05)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003914
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003915 # FIXME: find a test suite for the pdist operator
Shen Li10224432021-08-12 11:39:31 -07003916 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration")
ptrblcka64d0ff2020-02-11 11:58:19 -08003917 @skipIfRocm
Jane Xu3f9115d2022-03-24 10:07:08 -07003918 @onlyCUDA
3919 @largeTensorTest('10GB', device='cpu')
3920 @largeTensorTest('5GB', device='cuda')
ptrblcka64d0ff2020-02-11 11:58:19 -08003921 def test_pdist_norm_large(self, device):
3922 # use dim0>=46342 for forward, see:
3923 # https://github.com/pytorch/pytorch/issues/30583
lezcanof54e7b42022-06-22 20:29:22 +00003924        # Compare the GPU output against the CPU implementation
Jane Xu3f9115d2022-03-24 10:07:08 -07003925 x = torch.randn(50000, 1, dtype=torch.float32) # 50k * 4 bytes = 200 KB
3926 # Will require 1249975000 float32s
3927 expected_cpu = torch.pdist(x, p=2) # ~1250M * 4 bytes = 5 GB on CPU
3928 actual_gpu = torch.pdist(x.to(device), p=2) # 5 GB on GPU
3929 self.assertEqual(expected_cpu, actual_gpu.cpu()) # Another 5 GB on CPU
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003930
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003931 # FIXME: move to elementwise ternary test suite
kshitij12345885a8e52021-11-01 09:21:20 -07003932 @onlyNativeDeviceTypes
Philip Meier26b7ff52021-09-07 08:57:43 -07003933 @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
3934 @dtypes(*set(get_all_math_dtypes('cpu')))
anjali411db1f2172020-11-14 21:25:52 -08003935 def test_addcdiv(self, device, dtype):
Mike Ruberryde40c8e2021-06-06 14:51:26 -07003936        # Returns a floating-point, complex, or integral scalar corresponding to dtype
3937 def _number(floating, integer, dtype):
3938 if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]:
3939 return floating
3940 elif dtype in [torch.cfloat, torch.cdouble]:
3941 return floating * (1 + 1j)
3942 else:
3943 return integer
3944
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003945 def non_zero_rand(size, dtype, device):
anjali4114f3946a2020-04-24 15:03:38 -07003946 if dtype.is_floating_point or dtype.is_complex:
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003947 a = torch.rand(size=size, dtype=dtype, device=device)
3948 elif dtype == torch.uint8:
3949 a = torch.randint(1, 5, size=size, dtype=dtype, device=device)
3950 else:
3951 a = torch.randint(-5, 5, size=size, dtype=dtype, device=device)
Ailing Zhang7c13a072020-05-12 13:32:26 -07003952 return a + (a == 0).to(dtype)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003953
anjali411db1f2172020-11-14 21:25:52 -08003954 def _test_addcdiv():
3955 a = non_zero_rand((2, 2), dtype=dtype, device=device)
3956 b = non_zero_rand((2, 2), dtype=dtype, device=device)
3957 c = non_zero_rand((2, 2), dtype=dtype, device=device)
3958 alpha = _number(0.5, 3, dtype)
Mike Ruberry64584572020-05-19 19:25:35 -07003959
anjali411db1f2172020-11-14 21:25:52 -08003960 expected = a + (alpha * b) / c
3961 actual = torch.addcdiv(a, b, c, value=alpha)
3962 self.assertEqual(expected, actual)
Mike Ruberry42394162020-05-27 14:37:35 -07003963
mattip54a24982021-03-08 03:30:11 -08003964 with self.assertWarnsOnceRegex(
Shen Li10224432021-08-12 11:39:31 -07003965 UserWarning, "This overload of addcdiv is deprecated"):
anjali411db1f2172020-11-14 21:25:52 -08003966 self.assertEqual(actual, torch.addcdiv(a, alpha, b, c))
3967
3968 if not (dtype.is_floating_point or dtype.is_complex):
3969 # Integer division with addcdiv is prohibited
3970 with self.assertRaises(RuntimeError):
3971 _test_addcdiv()
3972 else:
3973 _test_addcdiv()
Mike Ruberryb4b8f532019-09-14 17:09:04 -07003974
Shen Li10224432021-08-12 11:39:31 -07003975 if self.device_type == 'cuda' and dtype == torch.half:
Masaki Kozukia404cc92021-06-25 10:20:10 -07003976 a = torch.tensor([60000.0], device=device, dtype=dtype)
3977 b = torch.tensor([60000.0], device=device, dtype=dtype)
3978 c = torch.tensor([1.0], device=device, dtype=dtype)
3979 out = torch.addcmul(a, b, c, value=-2)
3980 self.assertTrue(not (out.isnan() or out.isinf()))
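        # The half/CUDA case above guards against intermediate overflow: value * b * c
        # is -120000, which does not fit in fp16 (max ~65504), while the final result
        # (60000 - 120000 = -60000) does; presumably the kernel accumulates in a wider
        # dtype, and the assert would catch a naive fp16 intermediate.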
3981
Peter Bellc177d252020-08-28 08:36:02 -07003982 def test_nullary_op_mem_overlap(self, device):
3983 ops = (
3984 ("random_", ()),
3985 ("uniform_", ()),
3986 ("cauchy_", ()),
3987 ("log_normal_", ()),
3988 ("exponential_", ()),
3989 ("geometric_", (0.5,)),
3990 ("normal_", ()),
3991 )
3992
3993 x = torch.rand((1, 3)).expand((3, 3))
3994 for op, args in ops:
Shen Li10224432021-08-12 11:39:31 -07003995 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bellc177d252020-08-28 08:36:02 -07003996 getattr(x, op)(*args)
3997
Mike Ruberrye0d829a2022-01-24 01:28:07 -08003998 # FIXME: move to an elementwise ternary test suite and make this an OpInfo test
Mike Ruberry7f183a92019-10-08 09:50:28 -07003999 @dtypes(torch.double)
4000 def test_ternary_op_mem_overlap(self, device, dtype):
Mike Ruberryb4b8f532019-09-14 17:09:04 -07004001 ops = [
Shen Li10224432021-08-12 11:39:31 -07004002 ("addcmul", True, True, 'cpu'),
4003 ("addcmul", True, True, 'cuda'),
4004 ("addcdiv", True, True, 'cpu'),
4005 ("addcdiv", True, True, 'cuda'),
4006 ("lerp", True, True, 'cpu'),
4007 ("lerp", True, True, 'cuda')
Mike Ruberryb4b8f532019-09-14 17:09:04 -07004008 ]
4009
Shen Li10224432021-08-12 11:39:31 -07004010 for (fn, has_input_output_mem_overlap_check,
4011 has_internal_mem_overlap_check, dev) in ops:
Mike Ruberryb4b8f532019-09-14 17:09:04 -07004012 if dev != device:
4013 continue
4014 out_op = getattr(torch, fn)
Shen Li10224432021-08-12 11:39:31 -07004015 inplace_op = getattr(torch.Tensor, fn + '_')
Mike Ruberryb4b8f532019-09-14 17:09:04 -07004016 self.check_internal_mem_overlap(
Shen Li10224432021-08-12 11:39:31 -07004017 inplace_op, 3, dtype, device,
4018 expected_failure=not has_internal_mem_overlap_check)
4019 self.ternary_check_input_output_mem_overlap(out_op, dev,
4020 expected_failure=not has_input_output_mem_overlap_check)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07004021
kshitij12345885a8e52021-11-01 09:21:20 -07004022 @expectedFailureMeta # RuntimeError not raised
Mike Ruberry7f183a92019-10-08 09:50:28 -07004023 @dtypes(torch.double)
kshitij12345885a8e52021-11-01 09:21:20 -07004024 @onlyNativeDeviceTypes
Mike Ruberry7f183a92019-10-08 09:50:28 -07004025 def test_copy_mem_overlap(self, device, dtype):
Mike Ruberryb4b8f532019-09-14 17:09:04 -07004026 self.check_internal_mem_overlap(
Shen Li10224432021-08-12 11:39:31 -07004027 torch.Tensor.copy_, num_inputs=2, dtype=dtype, device=device)
Peter Belle8e33942021-06-18 16:28:00 -07004028 sz = 9
Mike Ruberry7f183a92019-10-08 09:50:28 -07004029 doubles = torch.randn(2 * sz, dtype=dtype, device=device)
Mike Ruberryb4b8f532019-09-14 17:09:04 -07004030 self.unary_check_input_output_mem_overlap(
Shen Li10224432021-08-12 11:39:31 -07004031 doubles, sz, lambda input, out: out.copy_(input))
Mike Ruberryb4b8f532019-09-14 17:09:04 -07004032
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004033 # FIXME: convert to ErrorInputs
kshitij12345885a8e52021-11-01 09:21:20 -07004034 @onlyNativeDeviceTypes
Peter Bell065ebdb2020-08-28 08:36:02 -07004035 def test_index_add_mem_overlap(self, device):
4036 x = torch.rand((1,), device=device).expand((6,))
4037 y = torch.rand((6,), device=device)
Peter Bell5765bbd2020-12-09 15:09:23 -08004038 ind = torch.tensor([2, 1, 0], device=device)
Peter Bell065ebdb2020-08-28 08:36:02 -07004039 value = torch.rand((3,), device=device)
Shen Li10224432021-08-12 11:39:31 -07004040 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell065ebdb2020-08-28 08:36:02 -07004041 x.index_add_(0, ind, value)
Shen Li10224432021-08-12 11:39:31 -07004042 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004043 y.index_add_(0, ind, y[:3])
Shen Li10224432021-08-12 11:39:31 -07004044 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004045 ind.index_add_(0, ind, ind.clone())
Shen Li10224432021-08-12 11:39:31 -07004046 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004047 ind.index_add_(0, ind.clone(), ind)
Peter Bell065ebdb2020-08-28 08:36:02 -07004048
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004049 # FIXME: convert to ErrorInputs
kshitij12345885a8e52021-11-01 09:21:20 -07004050 @onlyNativeDeviceTypes
Peter Bell5765bbd2020-12-09 15:09:23 -08004051 def test_index_copy_mem_overlap(self, device):
4052 x = torch.rand((1,), device=device).expand((6,))
4053 y = torch.rand((6,), device=device)
4054 ind = torch.tensor([2, 1, 0], device=device)
4055 value = torch.rand((3,), device=device)
Shen Li10224432021-08-12 11:39:31 -07004056 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004057 x.index_copy_(0, ind, value)
Shen Li10224432021-08-12 11:39:31 -07004058 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004059 y.index_copy_(0, ind, y[:3])
Shen Li10224432021-08-12 11:39:31 -07004060 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004061 ind.index_copy_(0, ind, ind.clone())
Shen Li10224432021-08-12 11:39:31 -07004062 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004063 ind.index_copy_(0, ind.clone(), ind)
4064
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004065 # FIXME: convert to ErrorInputs
kshitij12345885a8e52021-11-01 09:21:20 -07004066 @expectedFailureMeta # Warning not triggered
4067 @onlyNativeDeviceTypes
Peter Bell5765bbd2020-12-09 15:09:23 -08004068 def test_index_fill_mem_overlap(self, device):
4069 x = torch.rand((1,), device=device).expand((6,))
4070 y = torch.rand((6,), device=device)
4071 ind = torch.tensor([2, 1, 0], device=device)
4072 value = torch.rand((3,), device=device)
4073
4074 with self.assertWarnsRegex(UserWarning, "index_fill_ on expanded tensors"):
4075 x.index_fill_(0, ind, 1.0)
Shen Li10224432021-08-12 11:39:31 -07004076 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004077 ind.index_fill_(0, ind, 0)
4078
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004079 # FIXME: convert to ErrorInputs
kshitij12345885a8e52021-11-01 09:21:20 -07004080 @expectedFailureMeta # RuntimeError not raised
4081 @onlyNativeDeviceTypes
Peter Bell065ebdb2020-08-28 08:36:02 -07004082 def test_shift_mem_overlap(self, device):
4083 x = torch.rand(3, device=device)
Shen Li10224432021-08-12 11:39:31 -07004084 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell065ebdb2020-08-28 08:36:02 -07004085 x[:-1] <<= x[1:]
Shen Li10224432021-08-12 11:39:31 -07004086 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell065ebdb2020-08-28 08:36:02 -07004087 x[:-1] >>= x[1:]
4088
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004089 # FIXME: convert to ErrorInputs
kshitij12345885a8e52021-11-01 09:21:20 -07004090 @expectedFailureMeta # RuntimeError not raised
4091 @onlyNativeDeviceTypes
Peter Bell5807bb92020-09-02 08:44:13 -07004092 def test_bernoulli_mem_overlap(self, device):
4093 x = torch.rand((1,), device=device).expand((6,))
4094
Shen Li10224432021-08-12 11:39:31 -07004095 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5807bb92020-09-02 08:44:13 -07004096 x.bernoulli_()
Shen Li10224432021-08-12 11:39:31 -07004097 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5807bb92020-09-02 08:44:13 -07004098 x.bernoulli_(p=0.1)
4099 p = torch.rand(6, device=device)
Shen Li10224432021-08-12 11:39:31 -07004100 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5807bb92020-09-02 08:44:13 -07004101 x.bernoulli_(p=p)
Shen Li10224432021-08-12 11:39:31 -07004102 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5807bb92020-09-02 08:44:13 -07004103 torch.bernoulli(torch.rand_like(x), out=x)
4104
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004105 # FIXME: convert to ErrorInputs
kshitij12345885a8e52021-11-01 09:21:20 -07004106 @expectedFailureMeta # RuntimeError not raised
4107 @onlyNativeDeviceTypes
lezcanofd02fc52021-04-05 18:03:59 -07004108 def test_put_mem_overlap(self, device):
4109 x = torch.rand((1,), device=device).expand((6,))
4110 y = torch.rand((6,), device=device)
4111 ind = torch.tensor([2, 1, 0], device=device)
4112 value = torch.rand((3,), device=device)
Shen Li10224432021-08-12 11:39:31 -07004113 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
lezcanofd02fc52021-04-05 18:03:59 -07004114 x.put_(ind, value)
Shen Li10224432021-08-12 11:39:31 -07004115 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
lezcanofd02fc52021-04-05 18:03:59 -07004116 y.put_(ind[0], y[0])
Shen Li10224432021-08-12 11:39:31 -07004117 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
lezcanofd02fc52021-04-05 18:03:59 -07004118 ind.put_(ind, ind)
Shen Li10224432021-08-12 11:39:31 -07004119 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
lezcanofd02fc52021-04-05 18:03:59 -07004120 y.put_(ind, y[:3])
Shen Li10224432021-08-12 11:39:31 -07004121 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
lezcanofd02fc52021-04-05 18:03:59 -07004122 ind.put_(ind, ind.clone())
Shen Li10224432021-08-12 11:39:31 -07004123 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
lezcanofd02fc52021-04-05 18:03:59 -07004124 ind.put_(ind.clone(), ind)
4125
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004126 # FIXME: convert to ErrorInputs
kshitij12345885a8e52021-11-01 09:21:20 -07004127 @expectedFailureMeta # UserWarning not triggered
4128 @onlyNativeDeviceTypes
Peter Bellc88ac252020-09-02 08:44:13 -07004129 def test_index_put_mem_overlap(self, device):
4130 x = torch.rand((1,), device=device).expand((6,))
4131 y = torch.rand((6,), device=device)
Peter Bell5765bbd2020-12-09 15:09:23 -08004132 ind = torch.tensor([2, 1, 0], device=device)
Peter Bellc88ac252020-09-02 08:44:13 -07004133 value = torch.rand((3,), device=device)
Shen Li10224432021-08-12 11:39:31 -07004134 with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
Peter Bellc88ac252020-09-02 08:44:13 -07004135 x.index_put_((ind,), value)
Shen Li10224432021-08-12 11:39:31 -07004136 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bellc88ac252020-09-02 08:44:13 -07004137 y.index_put_((ind,), y[0])
Shen Li10224432021-08-12 11:39:31 -07004138 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004139 ind.index_put_((ind,), ind)
Shen Li10224432021-08-12 11:39:31 -07004140 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004141 y.index_put_((ind,), y[:3])
Shen Li10224432021-08-12 11:39:31 -07004142 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004143 ind.index_put_((ind,), ind.clone())
Shen Li10224432021-08-12 11:39:31 -07004144 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004145 ind.index_put_((ind.clone(),), ind)
Peter Bellc88ac252020-09-02 08:44:13 -07004146
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004147 # FIXME: convert to ErrorInputs
kshitij12345885a8e52021-11-01 09:21:20 -07004148 @expectedFailureMeta # UserWarning not triggered
4149 @onlyNativeDeviceTypes
Peter Bellc88ac252020-09-02 08:44:13 -07004150 def test_masked_fill_mem_overlap(self, device):
4151 x = torch.rand((1,), device=device).expand((6,))
4152 mask = torch.tensor([True, False, True, True, False, False], device=device)
Shen Li10224432021-08-12 11:39:31 -07004153 with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
4154 x.masked_fill_(mask, 0.)
Peter Bellc88ac252020-09-02 08:44:13 -07004155
Shen Li10224432021-08-12 11:39:31 -07004156 fill_val = torch.tensor(0., device=device)
4157 with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
Peter Bellc88ac252020-09-02 08:44:13 -07004158 x.masked_fill_(mask, fill_val)
4159
Shen Li10224432021-08-12 11:39:31 -07004160 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004161 mask[1:].masked_fill_(mask[:-1], False)
4162
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004163 # FIXME: convert to ErrorInputs
kshitij12345885a8e52021-11-01 09:21:20 -07004164 @expectedFailureMeta # RuntimeError not raised
4165 @onlyNativeDeviceTypes
Peter Bellc88ac252020-09-02 08:44:13 -07004166 def test_masked_scatter_mem_overlap(self, device):
4167 x = torch.rand((1,), device=device).expand((6,))
4168 src = torch.rand((3,), device=device)
4169 mask = torch.tensor([True, False, True, True, False, False], device=device)
4170
Shen Li10224432021-08-12 11:39:31 -07004171 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bellc88ac252020-09-02 08:44:13 -07004172 x.masked_scatter_(mask, src)
4173
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004174 # FIXME: convert to ErrorInputs
kshitij12345885a8e52021-11-01 09:21:20 -07004175 @onlyNativeDeviceTypes
Peter Bellc88ac252020-09-02 08:44:13 -07004176 def test_scatter_mem_overlap(self, device):
4177 x = torch.rand((1,), device=device).expand((6,))
4178 src = torch.rand((3,), device=device)
Peter Bell5765bbd2020-12-09 15:09:23 -08004179 ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64)
Peter Bellc88ac252020-09-02 08:44:13 -07004180
Shen Li10224432021-08-12 11:39:31 -07004181 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bellc88ac252020-09-02 08:44:13 -07004182 x.scatter_(0, ind, src)
Shen Li10224432021-08-12 11:39:31 -07004183 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004184 src.scatter_(0, ind, src)
Shen Li10224432021-08-12 11:39:31 -07004185 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
Peter Bell5765bbd2020-12-09 15:09:23 -08004186 ind.scatter_(0, ind, ind.clone())
Peter Bellc88ac252020-09-02 08:44:13 -07004187
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004188 # FIXME: move to test distributions
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004189 @onlyCUDA
4190 def test_multinomial_device_constrain(self, device):
4191 x = torch.empty(0, device="cpu")
4192 y = torch.empty(0, device=device)
4193 self.assertRaisesRegex(
Shen Li10224432021-08-12 11:39:31 -07004194 RuntimeError, "Expected all tensors to be on the same device",
4195 lambda: torch.multinomial(x, 2, out=y))
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004196
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004197 # FIXME: move to test distributions
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004198 @deviceCountAtLeast(2)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004199 @onlyCUDA
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004200 def test_multinomial_gpu_device_constrain(self, devices):
4201 x = torch.empty(0, device=devices[0])
4202 y = torch.empty(0, device=devices[1])
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004203 self.assertRaisesRegex(
Shen Li10224432021-08-12 11:39:31 -07004204 RuntimeError, "Expected all tensors to be on the same device",
4205 lambda: torch.multinomial(x, 2, out=y))
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004206
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004207 # FIXME: convert this to an automated OpInfo test
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004208 @deviceCountAtLeast(2)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004209 @onlyCUDA
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004210 def test_device_guard(self, devices):
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004211 # verify that all operators with `device_guard: False` behave properly with multiple devices.
4212 # TODO: if we had operator introspection we could figure out this set of operators automatically...
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004213 x = torch.randn((1, 2, 3), device=devices[1])
4214 y = torch.zeros((1, 3, 2), device=devices[1])
4215 scalar = torch.tensor(5, device=devices[1])
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004216
4217 # property ops
4218 torch.cudnn_is_acceptable(x)
4219 x.is_distributed()
4220 x.is_floating_point()
4221 x.is_complex()
4222 x.is_same_size(y)
4223 x.is_signed()
4224 x.size(0)
4225 x.stride(0)
4226 x.numel()
4227 x.is_set_to(y)
4228 x.data_ptr()
4229 scalar.is_nonzero()
4230
4231 # sparse property ops
4232 y[0][1] = 5
4233 y_sparse = y.to_sparse()
4234 y_sparse.sparse_dim()
4235 y_sparse._dimI()
4236 y_sparse.dense_dim()
4237 y_sparse._dimV()
4238 y_sparse._nnz()
4239 y_sparse.is_coalesced()
4240 y_sparse._indices()
4241 y_sparse._values()
4242 y_sparse.indices()
4243 y_sparse.values()
4244
4245 # in-place ops
4246 def inplace():
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004247 return torch.randn((1, 2, 3), device=devices[1])
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004248 inplace().as_strided_(y.size(), y.stride())
4249 inplace().resize_(y.size())
4250 inplace().squeeze_()
4251 inplace().squeeze_(0)
4252 inplace().unsqueeze_(2)
4253 inplace().transpose_(1, 2)
4254 inplace().squeeze_().t_()
4255 inplace().set_(x.storage())
4256 inplace().set_(x.storage(), x.storage_offset(), x.size(), x.stride())
4257 inplace().set_(x)
4258 inplace().set_()
4259 y_sparse._coalesced_(True)
4260
4261 # shape modification
4262 x.as_strided(y.size(), y.stride())
4263 x.expand((5, 2, 3))
4264 x.expand_as(x)
4265 x.sum_to_size((1,))
Shen Li10224432021-08-12 11:39:31 -07004266 torch.broadcast_tensors(x, x)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004267 x.reshape((1, 3, 2))
4268 x.reshape_as(y)
4269 x.squeeze()
4270 x.squeeze(0)
4271 x.squeeze().t()
4272 x.transpose(1, 2)
4273 x.unsqueeze(2)
4274 x.view((1, 3, 2))
4275 x.view_as(y)
4276
4277 # chunk, split, etc.
4278 x.chunk(2, dim=1)
4279 x.split(1, dim=2)
4280 x.split_with_sizes([1, 2], dim=2)
4281 x.unfold(dimension=2, size=1, step=1)
4282
4283 x.narrow(1, 1, 1)
4284 x.select(1, 1)
4285 torch.isnan(x)
4286
4287 torch.empty((1, 3, 2), out=y)
4288 torch.empty_like(x)
4289 torch.empty_like(x, dtype=torch.int64)
4290
4291 # to
4292 x.to(x)
4293 x.to(y)
4294 x.to(x, copy=True)
4295
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004296 def test_is_signed(self, device):
4297 self.assertEqual(torch.IntTensor(5).to(device).is_signed(), True)
4298 self.assertEqual(torch.ByteTensor(5).to(device).is_signed(), False)
4299 self.assertEqual(torch.CharTensor(5).to(device).is_signed(), True)
4300 self.assertEqual(torch.FloatTensor(5).to(device).is_signed(), True)
4301 self.assertEqual(torch.HalfTensor(10).to(device).is_signed(), True)
4302
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004303 # Note - reports a leak of 512 bytes on CUDA device 1
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004304 @deviceCountAtLeast(2)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004305 @skipCUDAMemoryLeakCheckIf(True)
4306 @onlyCUDA
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004307 def test_tensor_set_errors_multigpu(self, devices):
4308 f_cuda0 = torch.randn((2, 3), dtype=torch.float32, device=devices[0])
4309 f_cuda1 = torch.randn((2, 3), dtype=torch.float32, device=devices[1])
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004310
4311 self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1.storage()))
Shen Li10224432021-08-12 11:39:31 -07004312 self.assertRaises(RuntimeError,
4313 lambda: f_cuda0.set_(f_cuda1.storage(), 0, f_cuda1.size(), f_cuda1.stride()))
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004314 self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1))
4315
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004316 # FIXME: move to test_serialization
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004317 @onlyCUDA
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004318 @deviceCountAtLeast(1) # Note: Tests works with one but prefers more devices
4319 def test_serialization(self, devices):
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004320 def _test_serialization(filecontext_lambda):
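# Save a tuple of tensors living on different CUDA devices and check that loading
# restores each tensor to its original device.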
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004321 t0 = torch.cuda.FloatTensor(5).fill_(1)
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004322 with torch.cuda.device(devices[-1]):
4323 tn = torch.cuda.FloatTensor(3).fill_(2)
4324 torch.cuda.set_device(devices[0])
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004325 b = (t0, tn)
4326 with filecontext_lambda() as f:
4327 torch.save(b, f)
4328 f.seek(0)
4329 c = torch.load(f)
Mike Ruberry13120bf2020-05-27 06:28:05 -07004330 self.assertEqual(b, c, atol=0, rtol=0)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004331 u0, un = c
Mike Ruberry25cd3c62019-09-25 10:14:35 -07004332 self.assertEqual(str(u0.device), devices[0])
4333 self.assertEqual(str(un.device), devices[-1])
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004334
4335 _test_serialization(tempfile.NamedTemporaryFile)
4336 _test_serialization(BytesIOContext)
4337
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004338 # FIXME: move memory format tests to their own test class/suite
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004339 def test_memory_format_preserved_after_permute(self, device):
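# Applying a permutation and then its inverse should leave a channels_last
# (or channels_last_3d) tensor in its original memory format.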
Vitaly Fedyunin4bfe2f02019-10-31 13:19:31 -07004340 x = torch.randn(4, 3, 8, 8, device=device)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004341 nhwc = x.contiguous(memory_format=torch.channels_last)
4342 y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2)
4343 self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
4344
lixinyuf9f135c2020-03-06 05:59:20 -08004345 x = torch.randn(4, 3, 8, 8, 8, device=device)
4346 ndhwc = x.contiguous(memory_format=torch.channels_last_3d)
4347 y = ndhwc.permute(0, 1, 4, 3, 2).permute(0, 1, 4, 3, 2)
4348 self.assertTrue(y.is_contiguous(memory_format=torch.channels_last_3d))
4349
Natalia Gimelsheinc8bc2982020-08-20 10:48:26 -07004350 def test_memory_format_propagation_rules(self, device):
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004351
4352 contiguous = torch.rand(10, 3, 5, 5, device=device)
Shen Li10224432021-08-12 11:39:31 -07004353 cl = torch.rand(10, 3, 5, 5, device=device).contiguous(memory_format=torch.channels_last)
4354 ambiguous = torch.rand(10, 3, 1, 1, device=device).contiguous(memory_format=torch.channels_last)
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004355 self.assertTrue(ambiguous.is_contiguous(memory_format=torch.channels_last))
4356 self.assertTrue(ambiguous.is_contiguous(memory_format=torch.contiguous_format))
Shen Li10224432021-08-12 11:39:31 -07004357 bias = torch.rand(1, 1, 1, 1, device=device).contiguous(memory_format=torch.channels_last)
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004358
4359 def _test_propagation_rules(self, contiguous, cl, ambiguous, bias):
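# Each entry is (lhs, rhs, expected memory format of lhs + rhs): an unambiguous layout
# wins over an ambiguous or bias operand; when both operands are unambiguous, the lhs layout is kept.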
Shen Li10224432021-08-12 11:39:31 -07004360 options = ((ambiguous, contiguous, torch.contiguous_format),
4361 (ambiguous, cl, torch.channels_last),
4362 (contiguous, ambiguous, torch.contiguous_format),
4363 (contiguous, cl, torch.contiguous_format),
4364 (cl, ambiguous, torch.channels_last),
4365 (cl, contiguous, torch.channels_last),
4366 (bias, cl, torch.channels_last),
4367 (cl, bias, torch.channels_last),)
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004368
4369 for a, b, mf in options:
4370 result = a + b
4371 self.assertTrue(result.is_contiguous(memory_format=mf))
4372
4373 _test_propagation_rules(self, contiguous, cl, ambiguous, bias)
4374
4375 cl = cl.to(memory_format=torch.channels_last)
4376 ambiguous = ambiguous.to(memory_format=torch.channels_last)
4377 bias = bias.to(memory_format=torch.channels_last)
4378
4379 _test_propagation_rules(self, contiguous, cl, ambiguous, bias)
4380
4381 # test cases when strides matter in ambiguous tensors
4382 for mf in (torch.channels_last, torch.contiguous_format):
4383 ambiguous = torch.rand(10, 3, 1, 1, device=device).to(memory_format=mf)
4384 bias = torch.rand(3, 1, 1, device=device)
4385 result = ambiguous + bias
4386 self.assertEqual(ambiguous.stride(), result.stride())
4387 result = bias + ambiguous
4388 self.assertEqual(ambiguous.stride(), result.stride())
4389 result = ambiguous * 5
4390 self.assertEqual(ambiguous.stride(), result.stride())
4391
Kulin Sethe011a8e2022-05-13 18:28:53 +00004392 @skipIfMps
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004393 def test_memory_format_empty_like(self, device):
lixinyuf9f135c2020-03-06 05:59:20 -08004394 def test_helper(x, memory_format):
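# xc is x converted to the given memory format; empty_like with preserve_format should
# follow its input's layout, while contiguous_format always yields a dense tensor.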
4395 xc = x.contiguous(memory_format=memory_format)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004396
lixinyuf9f135c2020-03-06 05:59:20 -08004397 like = torch.empty_like(xc, memory_format=torch.preserve_format)
4398 self.assertFalse(like.is_contiguous())
4399 self.assertTrue(like.is_contiguous(memory_format=memory_format))
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004400
lixinyuf9f135c2020-03-06 05:59:20 -08004401 like_x = torch.empty_like(x, memory_format=torch.preserve_format)
4402 self.assertTrue(like_x.is_contiguous())
4403 self.assertFalse(like_x.is_contiguous(memory_format=memory_format))
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004404
lixinyuf9f135c2020-03-06 05:59:20 -08004405 like = torch.empty_like(x, memory_format=memory_format)
4406 self.assertFalse(like.is_contiguous())
4407 self.assertTrue(like.is_contiguous(memory_format=memory_format))
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004408
lixinyuf9f135c2020-03-06 05:59:20 -08004409 like = torch.empty_like(xc, memory_format=torch.contiguous_format)
4410 self.assertTrue(like.is_contiguous())
4411 self.assertFalse(like.is_contiguous(memory_format=memory_format))
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004412
lixinyuf9f135c2020-03-06 05:59:20 -08004413 like = torch.empty_like(xc)
4414 self.assertFalse(like.is_contiguous())
4415 self.assertTrue(like.is_contiguous(memory_format=memory_format))
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004416
lixinyuf9f135c2020-03-06 05:59:20 -08004417 sparse = x.to_sparse()
4418 with self.assertRaises(RuntimeError):
4419 z = torch.empty_like(sparse, memory_format=torch.preserve_format)
4420
4421 test_helper(torch.randn(4, 3, 8, 8, device=device), torch.channels_last)
4422 test_helper(torch.randn(4, 3, 8, 8, 8, device=device), torch.channels_last_3d)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004423
Vitaly Fedyuninddeeb562019-11-11 17:28:06 -08004424 def test_memory_format_consistency(self, device):
4425 x = torch.randn(10, 3, 1, 1, device=device)
4426 x_rep = x.as_strided(x.size(), x.stride())
4427 self.assertEqual(x.size(), x_rep.size())
4428 self.assertEqual(x.stride(), x_rep.stride())
4429 self.assertEqual(x.is_contiguous(), x_rep.is_contiguous())
Shen Li10224432021-08-12 11:39:31 -07004430 self.assertEqual(x.is_contiguous(memory_format=torch.channels_last), x_rep.is_contiguous(memory_format=torch.channels_last))
lixinyuf9f135c2020-03-06 05:59:20 -08004431 self.assertEqual(
Shen Li10224432021-08-12 11:39:31 -07004432 x.is_contiguous(memory_format=torch.channels_last_3d), x_rep.is_contiguous(memory_format=torch.channels_last_3d))
Vitaly Fedyuninddeeb562019-11-11 17:28:06 -08004433
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004434 # FIXME: make this an elementwise unary and elementwise binary OpInfo test
Vitaly Fedyunina7df3692019-11-18 05:32:23 -08004435 def test_memory_format_operators(self, device):
Hong Xu3894de52020-06-22 10:43:35 -07004436 def _chunk_op(x, y):
Vitaly Fedyunina7df3692019-11-18 05:32:23 -08004437 x1, x2 = x.chunk(2, dim=1)
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004438 return x1 + x2
Vitaly Fedyunina7df3692019-11-18 05:32:23 -08004439
Hong Xu3894de52020-06-22 10:43:35 -07004440 def _unsqueeze_op_add(x, y):
Vitaly Fedyunina7df3692019-11-18 05:32:23 -08004441 return x[0].unsqueeze(0) + 3
4442
Hong Xu3894de52020-06-22 10:43:35 -07004443 def _unsqueeze_op_clone(x, y):
Vitaly Fedyunina7df3692019-11-18 05:32:23 -08004444 return x[0].unsqueeze(0).clone()
4445
Hong Xu3894de52020-06-22 10:43:35 -07004446 def _test_helper(x, y, bias, memory_format):
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004447 return_contig_fns = [
4448 lambda x, y: y + x,
4449 lambda x, y: y * x,
4450 lambda x, y: y.addcdiv(x, y, value=2),
4451 lambda x, y: y.addcmul(x, y, value=2),
4452 ]
4453 bias_fns = [
4454 lambda x, b: x + b,
4455 lambda x, b: b + x,
4456 ]
lixinyuf9f135c2020-03-06 05:59:20 -08004457 fns = [
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004458 lambda x, y: x.clone(),
lixinyuf9f135c2020-03-06 05:59:20 -08004459 lambda x, y: x + 3,
4460 lambda x, y: 3 * x,
4461 lambda x, y: x + y,
lixinyuf9f135c2020-03-06 05:59:20 -08004462 lambda x, y: x * y,
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004463 lambda x, y: abs(x),
4464 lambda x, y: x.abs(),
4465 lambda x, y: x.abs_(),
4466 lambda x, y: x.acos(),
4467 lambda x, y: x.acos_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004468 lambda x, y: x.add(y, alpha=3),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004469 lambda x, y: x.add_(y, alpha=3),
lixinyuf9f135c2020-03-06 05:59:20 -08004470 lambda x, y: x.addcdiv(y, y, value=2),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004471 lambda x, y: x.addcdiv_(y, y, value=2),
lixinyuf9f135c2020-03-06 05:59:20 -08004472 lambda x, y: x.addcmul(y, y, value=2),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004473 lambda x, y: x.addcmul_(y, y, value=2),
krshrimali335e4a12020-06-04 11:38:14 -07004474 lambda x, y: x.acosh(),
4475 lambda x, y: x.acosh_(),
4476 lambda x, y: x.asinh(),
4477 lambda x, y: x.asinh_(),
4478 lambda x, y: x.atanh(),
4479 lambda x, y: x.atanh_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004480 lambda x, y: x.asin(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004481 lambda x, y: x.asin_(),
Cloud Han8ab63772020-05-12 14:19:14 -07004482 lambda x, y: x.atan(),
lixinyuf9f135c2020-03-06 05:59:20 -08004483 lambda x, y: x.atan2(y),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004484 lambda x, y: x.atan2_(y),
lixinyuf9f135c2020-03-06 05:59:20 -08004485 lambda x, y: x.ceil(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004486 lambda x, y: x.ceil_(),
ShawnZhongcb530fc2020-06-03 16:00:10 -07004487 lambda x, y: x.clamp(-1, 1),
4488 lambda x, y: x.cos(),
4489 lambda x, y: x.cosh(),
lixinyuf9f135c2020-03-06 05:59:20 -08004490 lambda x, y: x.div(0.5),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004491 lambda x, y: x.div_(0.5),
lixinyuf9f135c2020-03-06 05:59:20 -08004492 lambda x, y: x.div(y),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004493 lambda x, y: x.div_(y),
lixinyuf9f135c2020-03-06 05:59:20 -08004494 lambda x, y: x.digamma(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004495 lambda x, y: x.digamma_(),
ShawnZhongcb530fc2020-06-03 16:00:10 -07004496 lambda x, y: x.erf(),
Cloud Han8d946152020-05-13 21:16:58 -07004497 lambda x, y: x.erfc(),
lixinyuf9f135c2020-03-06 05:59:20 -08004498 lambda x, y: x.erfinv(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004499 lambda x, y: x.erfinv_(),
ShawnZhongcb530fc2020-06-03 16:00:10 -07004500 lambda x, y: x.exp(),
lixinyuf9f135c2020-03-06 05:59:20 -08004501 lambda x, y: x.expm1(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004502 lambda x, y: x.expm1_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004503 lambda x, y: x.floor(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004504 lambda x, y: x.floor_(),
Erjia Guanc98c98d2020-12-02 09:37:02 -08004505 lambda x, y: x.fmod(2),
ShawnZhongcb530fc2020-06-03 16:00:10 -07004506 lambda x, y: x.frac(),
Muthu Arivoli92885eb2020-08-12 13:14:36 -07004507 lambda x, y: x.hypot(y),
4508 lambda x, y: x.hypot_(y),
Muthu Arivoli719d29d2020-09-05 23:09:43 -07004509 lambda x, y: x.i0(),
4510 lambda x, y: x.i0_(),
Xiao Wangb5d75dd2020-09-10 21:48:43 -07004511 lambda x, y: x.lerp(y, 0.5),
lixinyuf9f135c2020-03-06 05:59:20 -08004512 lambda x, y: x.log(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004513 lambda x, y: x.log_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004514 lambda x, y: x.log10(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004515 lambda x, y: x.log10_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004516 lambda x, y: x.log1p(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004517 lambda x, y: x.log1p_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004518 lambda x, y: x.log2(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004519 lambda x, y: x.log2_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004520 lambda x, y: x.mul(3),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004521 lambda x, y: x.mul_(3),
lixinyuf9f135c2020-03-06 05:59:20 -08004522 lambda x, y: x.neg(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004523 lambda x, y: x.neg_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004524 lambda x, y: x.pow(3),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004525 lambda x, y: x.pow_(3),
Hong Xu3894de52020-06-22 10:43:35 -07004526 lambda x, y: x.pow(0.0),
4527 lambda x, y: x.pow(1.0),
4528 lambda x, y: x.reciprocal(),
ShawnZhongcb530fc2020-06-03 16:00:10 -07004529 lambda x, y: x.remainder(2),
lixinyuf9f135c2020-03-06 05:59:20 -08004530 lambda x, y: x.round(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004531 lambda x, y: x.round_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004532 lambda x, y: x.rsqrt(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004533 lambda x, y: x.rsqrt_(),
4534 lambda x, y: x.sigmoid(),
4535 lambda x, y: x.sigmoid_(),
Xiaomeng Yang80d5b372020-07-13 19:31:26 -07004536 lambda x, y: x.logit(),
4537 lambda x, y: x.logit_(),
4538 lambda x, y: x.logit(1e-6),
4539 lambda x, y: x.logit_(1e-6),
lixinyuf9f135c2020-03-06 05:59:20 -08004540 lambda x, y: x.sign(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004541 lambda x, y: x.sign_(),
anjali41158b6ab62020-09-22 08:01:16 -07004542 lambda x, y: x.sgn(),
4543 lambda x, y: x.sgn_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004544 lambda x, y: x.sin(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004545 lambda x, y: x.sin_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004546 lambda x, y: x.sinh(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004547 lambda x, y: x.sinh_(),
lixinyuf9f135c2020-03-06 05:59:20 -08004548 lambda x, y: x.sqrt(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004549 lambda x, y: x.sqrt_(),
ShawnZhongcb530fc2020-06-03 16:00:10 -07004550 lambda x, y: x.tan(),
4551 lambda x, y: x.tanh(),
lixinyuf9f135c2020-03-06 05:59:20 -08004552 lambda x, y: x.trunc(),
Vitaly Fedyunin930d2182020-03-27 12:01:44 -07004553 lambda x, y: x.trunc_(),
Hong Xu3894de52020-06-22 10:43:35 -07004554 _chunk_op,
4555 _unsqueeze_op_add,
4556 _unsqueeze_op_clone,
lixinyuf9f135c2020-03-06 05:59:20 -08004557 ]
Alexander Grund3b8589a2022-06-27 14:47:51 +00004558 x_c = x.contiguous()
4559 y_c = y.contiguous()
4560 b_c = bias.contiguous()
lixinyuf9f135c2020-03-06 05:59:20 -08004561 for fn in fns:
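# Heuristic: a lambda is treated as in-place if its source contains an in-place call
# such as `abs_(`; clone the inputs for those so each iteration starts from unmodified tensors.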
Alexander Grund3b8589a2022-06-27 14:47:51 +00004562 is_inplace = '_(' in inspect.getsource(fn)
4563 x_clone = x.clone() if is_inplace else x
4564 x_c_clone = x_c.clone() if is_inplace else x_c
4565 result_c = fn(x_c_clone, y_c)
4566 result = fn(x_clone, y)
4567 self.assertEqual(result, result_c, "Failed for '{}'".format(inspect.getsource(fn).strip()))
lixinyuf9f135c2020-03-06 05:59:20 -08004568 self.assertTrue(
4569 result.is_contiguous(memory_format=memory_format),
Shen Li10224432021-08-12 11:39:31 -07004570 "result of the '{}' is not in '{}' format".format(inspect.getsource(fn).strip(), memory_format))
lixinyuf9f135c2020-03-06 05:59:20 -08004571
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004572 for fn in bias_fns:
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004573 result_c = fn(x_c, b_c)
4574 result = fn(x, bias)
Alexander Grund3b8589a2022-06-27 14:47:51 +00004575 self.assertEqual(result, result_c, "Failed for '{}'".format(inspect.getsource(fn).strip()))
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004576 self.assertTrue(
4577 result.is_contiguous(memory_format=memory_format),
Shen Li10224432021-08-12 11:39:31 -07004578 "result of the '{}' is not in '{}' format".format(inspect.getsource(fn).strip(), memory_format))
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004579
4580 for fn in return_contig_fns:
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004581 result_c = fn(x_c, y_c)
4582 result = fn(x, y)
Alexander Grund3b8589a2022-06-27 14:47:51 +00004583 self.assertEqual(result, result_c, "Failed for '{}'".format(inspect.getsource(fn).strip()))
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004584 self.assertTrue(
4585 result.is_contiguous(memory_format=torch.contiguous_format),
Shen Li10224432021-08-12 11:39:31 -07004586 "result of the '{}' is not in '{}' format".format(inspect.getsource(fn).strip(), torch.contiguous_format))
Vitaly Fedyunina47fb572020-06-20 10:31:28 -07004587
Hong Xu3894de52020-06-22 10:43:35 -07004588 _test_helper(
Shen Li10224432021-08-12 11:39:31 -07004589 torch.randn((4, 3, 8, 8), device=device).contiguous(memory_format=torch.channels_last),
lixinyuf9f135c2020-03-06 05:59:20 -08004590 abs(torch.randn((4, 3, 8, 8), device=device)) + 1,
Shen Li10224432021-08-12 11:39:31 -07004591 torch.randn((1, 3, 1, 1), device=device).contiguous(memory_format=torch.channels_last),
4592 torch.channels_last)
Hong Xu3894de52020-06-22 10:43:35 -07004593 _test_helper(
Shen Li10224432021-08-12 11:39:31 -07004594 torch.randn((4, 3, 8, 8, 8), device=device).contiguous(memory_format=torch.channels_last_3d),
lixinyuf9f135c2020-03-06 05:59:20 -08004595 abs(torch.randn((4, 3, 8, 8, 8), device=device)) + 1,
Shen Li10224432021-08-12 11:39:31 -07004596 torch.randn((1, 3, 1, 1, 1), device=device).contiguous(memory_format=torch.channels_last_3d),
4597 torch.channels_last_3d)
Vitaly Fedyunina7df3692019-11-18 05:32:23 -08004598
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004599 # FIXME: make this an elementwise unary and elementwise binary OpInfo test
Animesh Jain1d90d6e2022-07-07 18:57:31 +00004600 @skipIfTorchDynamo("Torchdynamo fails with unknown reason")
Natalia Gimelsheinc8bc2982020-08-20 10:48:26 -07004601 def test_strides_propagation(self, device):
Natalia Gimelsheinc8bc2982020-08-20 10:48:26 -07004602 def _test_helper(x, op, unary=False):
4603 def compare_strides(s1, s2, div):
4604 sdiv = [s // div for s in s1]
4605 self.assertEqual(sdiv, s2)
4606
4607 dim = x.dim()
4608 # we produce memory dense outputs, so when input is strided on the last dimension
4609 # we need to divide by that dimension stride to compare input and result strides
4610 div = x.stride(-1)
4611 for p in permutations(range(dim)):
4612 xp = x.permute(p)
4613 if not unary:
4614 y = torch.randn(xp.size(-1), device=x.device, dtype=x.dtype)
4615 for inputs in ((xp, xp), (xp, y), (y, xp)):
4616 res = op(*inputs)
4617 compare_strides(xp.stride(), res.stride(), div)
4618 self.assertEqual(xp.size(), res.size())
4619 out = torch.empty(0, device=xp.device, dtype=res.dtype)
4620 res = op(*inputs, out=out)
4621 compare_strides(xp.stride(), res.stride(), div)
4622 self.assertEqual(xp.size(), res.size())
4623 else:
4624 res = op(xp)
4625 compare_strides(xp.stride(), res.stride(), div)
4626 self.assertEqual(xp.size(), res.size())
4627 out = torch.empty(0, device=xp.device, dtype=res.dtype)
4628 res = op(xp, out=out)
4629 compare_strides(xp.stride(), res.stride(), div)
4630 self.assertEqual(xp.size(), res.size())
4631
4632 # torch.eq by default calls TensorIterator with defined output, torch.add with undefined
4633 binary_ops = (torch.eq, torch.add)
4634 unary_ops = (torch.exp,)
4635 # memory dense, sliced and ambiguous sliced (ambiguous dense loses permutation information)
Shen Li10224432021-08-12 11:39:31 -07004636 xs = (torch.randn(2, 3, 4, device=device), torch.randn(2, 3, 8, device=device)[:, :, ::2],
4637 torch.randn(1, 1, 4, 12, device=device)[:, :, :, ::2])
Natalia Gimelsheinc8bc2982020-08-20 10:48:26 -07004638 for op in binary_ops:
4639 for x in xs:
4640 _test_helper(x, op)
4641 for op in unary_ops:
4642 for x in xs:
4643 _test_helper(x, op, unary=True)
4644
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004645 # FIXME: move dlpack tests to their own test class/suite
Edward Yang1f36ce62021-03-29 08:34:19 -07004646 @skipMeta
kshitij12345885a8e52021-11-01 09:21:20 -07004647 @onlyNativeDeviceTypes
Nikita Shulgabfac65d2022-03-30 14:13:21 -07004648 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
Emilio Castillo1cb35072021-09-12 19:45:57 -07004649 def test_dlpack_capsule_conversion(self, device, dtype):
4650 # DLPack does not explicitly support bool (xref dmlc/dlpack#75)
Philip Meier0973c5a2022-02-24 21:47:38 -08004651 x = make_tensor((5,), dtype=dtype, device=device)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004652 z = from_dlpack(to_dlpack(x))
4653 self.assertEqual(z, x)
4654
Emilio Castillo1cb35072021-09-12 19:45:57 -07004655 @skipMeta
kshitij12345885a8e52021-11-01 09:21:20 -07004656 @onlyNativeDeviceTypes
Nikita Shulgabfac65d2022-03-30 14:13:21 -07004657 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
Emilio Castillo1cb35072021-09-12 19:45:57 -07004658 def test_dlpack_protocol_conversion(self, device, dtype):
Philip Meier0973c5a2022-02-24 21:47:38 -08004659 x = make_tensor((5,), dtype=dtype, device=device)
Emilio Castillo1cb35072021-09-12 19:45:57 -07004660 z = from_dlpack(x)
4661 self.assertEqual(z, x)
4662
4663 @skipMeta
kshitij12345885a8e52021-11-01 09:21:20 -07004664 @onlyNativeDeviceTypes
Emilio Castillo1cb35072021-09-12 19:45:57 -07004665 def test_dlpack_shared_storage(self, device):
Philip Meier0973c5a2022-02-24 21:47:38 -08004666 x = make_tensor((5,), dtype=torch.float64, device=device)
Emilio Castillo1cb35072021-09-12 19:45:57 -07004667 z = from_dlpack(to_dlpack(x))
4668 z[0] = z[0] + 20.0
4669 self.assertEqual(z, x)
4670
4671 @skipMeta
4672 @onlyCUDA
Nikita Shulgabfac65d2022-03-30 14:13:21 -07004673 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
Emilio Castillo1cb35072021-09-12 19:45:57 -07004674 def test_dlpack_conversion_with_streams(self, device, dtype):
4675 # Create a stream where the tensor will reside
4676 stream = torch.cuda.Stream()
4677 with torch.cuda.stream(stream):
4678 # Do an operation in the actual stream
Philip Meier0973c5a2022-02-24 21:47:38 -08004679 x = make_tensor((5,), dtype=dtype, device=device) + 1
Emilio Castillo1cb35072021-09-12 19:45:57 -07004680 # DLPack protocol helps establish a correct stream order
4681 # (hence data dependency) at the exchange boundary.
4682 # DLPack manages this synchronization for us, so we don't need to
4683 # explicitly wait until x is populated
4684 stream = torch.cuda.Stream()
4685 with torch.cuda.stream(stream):
4686 z = from_dlpack(x)
4687 stream.synchronize()
4688 self.assertEqual(z, x)
4689
4690 @skipMeta
Kurt Mohler47c69932022-02-13 19:30:15 -08004691 @onlyNativeDeviceTypes
Nikita Shulgabfac65d2022-03-30 14:13:21 -07004692 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
Kurt Mohler47c69932022-02-13 19:30:15 -08004693 def test_from_dlpack(self, device, dtype):
Philip Meier0973c5a2022-02-24 21:47:38 -08004694 x = make_tensor((5,), dtype=dtype, device=device)
Kurt Mohler47c69932022-02-13 19:30:15 -08004695 y = torch.from_dlpack(x)
4696 self.assertEqual(x, y)
4697
4698 @skipMeta
4699 @onlyNativeDeviceTypes
Nikita Shulgabfac65d2022-03-30 14:13:21 -07004700 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
Kurt Mohler47c69932022-02-13 19:30:15 -08004701 def test_from_dlpack_noncontiguous(self, device, dtype):
Philip Meier0973c5a2022-02-24 21:47:38 -08004702 x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5)
Kurt Mohler47c69932022-02-13 19:30:15 -08004703
4704 y1 = x[0]
4705 y1_dl = torch.from_dlpack(y1)
4706 self.assertEqual(y1, y1_dl)
4707
4708 y2 = x[:, 0]
4709 y2_dl = torch.from_dlpack(y2)
4710 self.assertEqual(y2, y2_dl)
4711
4712 y3 = x[1, :]
4713 y3_dl = torch.from_dlpack(y3)
4714 self.assertEqual(y3, y3_dl)
4715
4716 y4 = x[1]
4717 y4_dl = torch.from_dlpack(y4)
4718 self.assertEqual(y4, y4_dl)
4719
4720 y5 = x.t()
4721 y5_dl = torch.from_dlpack(y5)
4722 self.assertEqual(y5, y5_dl)
4723
4724 @skipMeta
Emilio Castillo1cb35072021-09-12 19:45:57 -07004725 @onlyCUDA
Nikita Shulgabfac65d2022-03-30 14:13:21 -07004726 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
Emilio Castillo1cb35072021-09-12 19:45:57 -07004727 def test_dlpack_conversion_with_diff_streams(self, device, dtype):
Emilio Castillo1cb35072021-09-12 19:45:57 -07004728 stream_a = torch.cuda.Stream()
4729 stream_b = torch.cuda.Stream()
4730 # DLPack protocol helps establish a correct stream order
4731 # (hence data dependency) at the exchange boundary.
4732 # the `tensor.__dlpack__` method will insert a synchronization event
4733 # in the current stream to make sure that it was correctly populated.
4734 with torch.cuda.stream(stream_a):
Philip Meier0973c5a2022-02-24 21:47:38 -08004735 x = make_tensor((5,), dtype=dtype, device=device) + 1
Kurt Mohler47c69932022-02-13 19:30:15 -08004736 z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream))
Emilio Castillo1cb35072021-09-12 19:45:57 -07004737 stream_a.synchronize()
4738 stream_b.synchronize()
4739 self.assertEqual(z, x)
4740
4741 @skipMeta
Kurt Mohler47c69932022-02-13 19:30:15 -08004742 @onlyNativeDeviceTypes
Nikita Shulgabfac65d2022-03-30 14:13:21 -07004743 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
Kurt Mohler47c69932022-02-13 19:30:15 -08004744 def test_from_dlpack_dtype(self, device, dtype):
Philip Meier0973c5a2022-02-24 21:47:38 -08004745 x = make_tensor((5,), dtype=dtype, device=device)
Kurt Mohler47c69932022-02-13 19:30:15 -08004746 y = torch.from_dlpack(x)
4747 self.assertEqual(x.dtype, y.dtype)
4748
4749 @skipMeta
Emilio Castillo533e72e2021-11-18 08:34:37 -08004750 @onlyCUDA
4751 def test_dlpack_default_stream(self, device):
4752 class DLPackTensor:
4753 def __init__(self, tensor):
4754 self.tensor = tensor
4755
4756 def __dlpack_device__(self):
4757 return self.tensor.__dlpack_device__()
4758
4759 def __dlpack__(self, stream=None):
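# Per the DLPack protocol the consumer passes its stream as an int: 1 is CUDA's legacy
# default stream, while ROCm uses 0 for its default stream, which is what the asserts below check.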
4760 if torch.version.hip is None:
4761 assert stream == 1
4762 else:
4763 assert stream == 0
4764 capsule = self.tensor.__dlpack__(stream)
4766 return capsule
4767
4768 # CUDA-based tests run on non-default streams
4769 with torch.cuda.stream(torch.cuda.default_stream()):
Philip Meier0973c5a2022-02-24 21:47:38 -08004770 x = DLPackTensor(make_tensor((5,), dtype=torch.float32, device=device))
Emilio Castillo533e72e2021-11-18 08:34:37 -08004771 from_dlpack(x)
4772
4773 @skipMeta
kshitij12345885a8e52021-11-01 09:21:20 -07004774 @onlyNativeDeviceTypes
Nikita Shulgabfac65d2022-03-30 14:13:21 -07004775 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
Emilio Castillo1cb35072021-09-12 19:45:57 -07004776 def test_dlpack_tensor_invalid_stream(self, device, dtype):
4777 with self.assertRaises(TypeError):
Philip Meier0973c5a2022-02-24 21:47:38 -08004778 x = make_tensor((5,), dtype=dtype, device=device)
Emilio Castillo1cb35072021-09-12 19:45:57 -07004779 x.__dlpack__(stream=object())
4780
4781 @skipMeta
4782 def test_dlpack_error_on_bool_tensor(self):
4783 x = torch.tensor([True], dtype=torch.bool)
4784 with self.assertRaises(RuntimeError):
4785 to_dlpack(x)
4786
4787 # TODO: add more tests once NumPy supports the `__dlpack__` protocol
Emilio Castillo1cb35072021-09-12 19:45:57 -07004788 @skipMeta
4789 def test_dlpack_export_requires_grad(self):
4790 x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
4791 with self.assertRaisesRegex(RuntimeError, r"require gradient"):
4792 x.__dlpack__()
4793
4794 @skipMeta
4795 def test_dlpack_export_is_conj(self):
4796 x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
4797 y = torch.conj(x)
4798 with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
4799 y.__dlpack__()
4800
4801 @skipMeta
4802 def test_dlpack_export_non_strided(self):
4803 x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
4804 y = torch.conj(x)
4805 with self.assertRaisesRegex(RuntimeError, r"strided"):
4806 y.__dlpack__()
4807
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004808 @onlyCUDA
Shen Li10224432021-08-12 11:39:31 -07004809 @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004810 def test_pin_memory_from_constructor(self, device):
4811 def _get_like(t, **kwargs):
4812 return [
4813 torch.rand_like(t, **kwargs),
4814 torch.randn_like(t, **kwargs),
4815 torch.empty_like(t, **kwargs),
4816 torch.full_like(t, 4, **kwargs),
4817 torch.zeros_like(t, **kwargs),
4818 torch.ones_like(t, **kwargs),
4819 ]
4820
4821 def _get_tensors(**kwargs):
4822 return [
4823 torch.tensor([10, 11], **kwargs),
4824 torch.randn(3, 5, **kwargs),
4825 torch.rand(3, **kwargs),
4826 # torch.randint(3, 5, **kwargs),  # unsupported
4827 torch.zeros(3, **kwargs),
4828 torch.randperm(3, **kwargs),
4829 torch.empty(6, **kwargs),
4830 torch.ones(6, **kwargs),
4831 torch.eye(6, **kwargs),
Shen Li10224432021-08-12 11:39:31 -07004832 torch.arange(3, 5, **kwargs)]
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004833
Shen Li10224432021-08-12 11:39:31 -07004834 pinned_tensors = _get_tensors(pin_memory=True) + _get_like(torch.empty(5, dtype=torch.float64), pin_memory=True)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004835 for x in pinned_tensors:
4836 self.assertTrue(x.is_pinned())
4837
Shen Li10224432021-08-12 11:39:31 -07004838 tensors = _get_tensors() + _get_like(torch.empty(5, dtype=torch.float64, pin_memory=True))
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004839 for x in tensors:
4840 self.assertFalse(x.is_pinned())
4841
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004842 @deviceCountAtLeast(1)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004843 @onlyCUDA
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004844 def test_storage_all_devices(self, devices):
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004845 for device in devices:
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004846 t = torch.tensor((), device=device)
4847 self.assertEqual(t.dtype, t.storage().dtype)
Mike Ruberryd9ab78b2019-09-19 01:47:32 -07004848
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004849 # FIXME: move to test distributions
Kulin Sethe011a8e2022-05-13 18:28:53 +00004850 @skipIfMps
Thomas Viehmanndef8aa52021-01-05 19:36:56 -08004851 @dtypesIfCUDA(torch.float, torch.double, torch.half)
4852 @dtypes(torch.float, torch.double)
Mike Ruberryb45f1b92019-10-01 19:17:06 -07004853 def test_multinomial(self, device, dtype):
4854 def make_prob_dist(shape, is_contiguous):
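# Builds a uniform-random probability tensor of the given shape; the non-contiguous
# variant is carved out of a larger tensor via slicing and transposition.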
4855 if is_contiguous:
vishwakftwae6af8d2019-11-20 13:04:02 -08004856 if dtype == torch.half:
Shen Li10224432021-08-12 11:39:31 -07004857 return torch.zeros(shape, device=device).uniform_().to(dtype=torch.half)
Mike Ruberryb45f1b92019-10-01 19:17:06 -07004858 return torch.zeros(shape, device=device, dtype=dtype).uniform_()
4859 elif len(shape) == 1:
vishwakftwae6af8d2019-11-20 13:04:02 -08004860 if dtype == torch.half:
Shen Li10224432021-08-12 11:39:31 -07004861 return torch.zeros((shape + [5]), device=device).uniform_().to(dtype=torch.half)[:, 2]
4862 return torch.zeros((shape + [5]), device=device, dtype=dtype).uniform_()[:, 2]
Mike Ruberryb45f1b92019-10-01 19:17:06 -07004863 else:
4864 # num dim = 2
4865 new_shape = [2, shape[1], 7, 1, shape[0], 1, 10]
vishwakftwae6af8d2019-11-20 13:04:02 -08004866 if dtype == torch.half:
Shen Li10224432021-08-12 11:39:31 -07004867 prob_dist = torch.zeros(new_shape, device=device).uniform_().to(dtype=torch.half)
vishwakftwae6af8d2019-11-20 13:04:02 -08004868 else:
Shen Li10224432021-08-12 11:39:31 -07004869 prob_dist = torch.zeros(new_shape, device=device, dtype=dtype).uniform_()
Mike Ruberryb45f1b92019-10-01 19:17:06 -07004870 prob_dist = prob_dist.transpose(1, 4)
4871 prob_dist = prob_dist[1, :, 5, 0, :, 0, 4]
4872 assert not prob_dist.is_contiguous() # sanity check
4873 return prob_dist
4874
4875 for is_contiguous in (True, False):
4876 # with replacement
4877 n_row = 3
4878 for n_col in range(4, 5 + 1):
4879 prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
4880 # indices that shouldn't be sampled (<0 means none)
4881 zero_prob_indices = torch.LongTensor(n_row).random_(-2, n_col).tolist()
4882 for i, j in enumerate(zero_prob_indices):
4883 if j >= 0:
4884 prob_dist[i, j] = 0
4885 n_sample = n_col * 3
4886 sample_indices = torch.multinomial(prob_dist, n_sample, True)
4887 self.assertEqual(prob_dist.dim(), 2)
4888 self.assertEqual(sample_indices.size(1), n_sample)
4889 for i in range(n_row):
4890 zero_prob_idx = zero_prob_indices[i]
4891 if zero_prob_idx < 0:
4892 continue
4893 for j in range(n_sample):
Shen Li10224432021-08-12 11:39:31 -07004894 self.assertNotEqual(sample_indices[i, j], zero_prob_idx,
4895 msg="sampled an index with zero probability")
Mike Ruberryb45f1b92019-10-01 19:17:06 -07004896
4897 # without replacement
4898 n_row = 3
4899 for n_col in range(2, 10 + 1, 2):
4900 prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
4901 # indices that shouldn't be sampled (<0 means none)
4902 zero_prob_indices = torch.LongTensor(n_row).random_(-1, n_col).tolist()
4903 for i, j in enumerate(zero_prob_indices):
4904 if j >= 0:
4905 prob_dist[i, j] = 0
4906 n_sample = max(1, n_col - 2)
4907 sample_indices = torch.multinomial(prob_dist, n_sample, False)
4908 self.assertEqual(prob_dist.dim(), 2)
4909 self.assertEqual(sample_indices.size(1), n_sample)
4910 for i in range(n_row):
4911 row_samples = {}
4912 zero_prob_idx = zero_prob_indices[i]
4913 for j in range(n_sample):
4914 sample_idx = sample_indices[i, j]
4915 if zero_prob_idx >= 0:
Shen Li10224432021-08-12 11:39:31 -07004916 self.assertNotEqual(sample_idx, zero_prob_idx,
4917 msg="sampled an index with zero probability")
4918 self.assertNotIn(sample_idx, row_samples, "sampled an index twice")
Mike Ruberryb45f1b92019-10-01 19:17:06 -07004919 row_samples[sample_idx] = True
4920
4921 # vector
4922 n_col = 4
4923 prob_dist = make_prob_dist([n_col], is_contiguous).fill_(1)
4924 zero_prob_idx = 1 # index that shouldn't be sampled
4925 prob_dist[zero_prob_idx] = 0
4926 n_sample = 20
4927 sample_indices = torch.multinomial(prob_dist, n_sample, True)
4928 for sample_index in sample_indices:
Shen Li10224432021-08-12 11:39:31 -07004929 self.assertNotEqual(sample_index, zero_prob_idx, msg="sampled an index with zero probability")
Mike Ruberryb45f1b92019-10-01 19:17:06 -07004930 s_dim = sample_indices.dim()
Mike Ruberry13120bf2020-05-27 06:28:05 -07004931 self.assertEqual(sample_indices.dim(), 1, msg="wrong number of dimensions")
Shen Li10224432021-08-12 11:39:31 -07004932 self.assertEqual(prob_dist.dim(), 1, msg="wrong number of prob_dist dimensions")
4933 self.assertEqual(sample_indices.size(0), n_sample, msg="wrong number of samples")
Mike Ruberryb45f1b92019-10-01 19:17:06 -07004934
Yukio Siraichicf17fd62021-04-27 12:01:09 -07004935 # CUDA misalignment issue (#46702)
4936 n_row, n_col = 2, 3
4937 prob_dist = make_prob_dist([n_row, n_col], True)
4938 n_sample = 1
4939 sample_indices = torch.multinomial(prob_dist, n_sample, True)
4940 self.assertEqual(sample_indices.dim(), 2, msg="wrong number of dimensions")
Shen Li10224432021-08-12 11:39:31 -07004941 self.assertEqual(sample_indices.size(1), n_sample, msg="wrong number of samples")
Yukio Siraichicf17fd62021-04-27 12:01:09 -07004942
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004943 # FIXME: move to test distributions
Yukio Siraichicf17fd62021-04-27 12:01:09 -07004944 @onlyCUDA
4945 @dtypes(torch.float, torch.double, torch.half)
4946 def test_multinomial_deterministic(self, device, dtype):
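# With the generator re-seeded to the same value, repeated draws must produce identical samples.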
4947 gen = torch.Generator(device=device)
4948
4949 trials = 5
4950 seed = 0
4951 prob_dist = torch.rand(10000, 1000, device=device, dtype=dtype)
4952 n_sample = 1
4953
4954 for i in range(trials):
4955 gen.manual_seed(seed)
4956 samples_1 = torch.multinomial(prob_dist, n_sample, True, generator=gen)
4957
4958 gen.manual_seed(seed)
4959 samples_2 = torch.multinomial(prob_dist, n_sample, True, generator=gen)
4960
4961 self.assertEqual(samples_1, samples_2)
4962 self.assertEqual(samples_1.dim(), 2, msg="wrong number of dimensions")
4963 self.assertEqual(samples_1.size(1), n_sample, msg="wrong number of samples")
4964
Mike Ruberrye0d829a2022-01-24 01:28:07 -08004965 # FIXME: move to test distributions
Natalia Gimelshein3d968082020-05-12 22:30:45 -07004966 @slowTest
4967 @dtypes(torch.float)
4968 def test_multinomial_rng_state_advance(self, device, dtype):
4969 corpus_size = 100000
4970 freqs = torch.ones(corpus_size, dtype=torch.float, device=device)
4971 n_sample = 100
4972 samples1 = torch.multinomial(freqs, n_sample, replacement=True)
4973 samples2 = torch.multinomial(freqs, n_sample, replacement=True)
4974 samples = torch.cat([samples1, samples2])
4975 # expect no more than 2 repeated elements generated in 2 attempts
4976 # the probability of at least one element being repeated is surprisingly large, 18%
4977 self.assertLessEqual(2 * n_sample - samples.unique().size(0), 2)
4978 samples1 = torch.multinomial(freqs, n_sample, replacement=False)
4979 samples2 = torch.multinomial(freqs, n_sample, replacement=False)
4980 samples = torch.cat([samples1, samples2])
4981 # expect no more than 1 repeated element generated in 2 attempts
4982 self.assertLessEqual(2 * n_sample - samples.unique().size(0), 1)
4983
Shen Li10224432021-08-12 11:39:31 -07004984 def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
4985 memory_format, compare_data=True, default_is_preserve=False):
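# Verifies that transformation_fn honors memory_format arguments: preserve_format keeps the
# input's channels_last(_3d) layout, contiguous_format produces a dense tensor, and omitting
# the argument behaves according to default_is_preserve.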
lixinyuf9f135c2020-03-06 05:59:20 -08004986
Shen Li10224432021-08-12 11:39:31 -07004987 assert memory_format in (torch.channels_last, torch.channels_last_3d)
lixinyuf9f135c2020-03-06 05:59:20 -08004988
4989 # xc is a channels last tensor
4990 xc = input_generator_fn(device)
4991 # xc is not memory dense, but looks like channels last
4992 if memory_format == torch.channels_last:
4993 xc = xc[..., ::2, ::2]
4994 else:
4995 xc = xc[..., ::2, ::2, ::2]
4996
4997 clone = transformation_fn(xc, memory_format=torch.preserve_format)
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07004998 self.assertFalse(clone.is_contiguous())
lixinyuf9f135c2020-03-06 05:59:20 -08004999 self.assertTrue(clone.is_contiguous(memory_format=memory_format))
5000 self.assertFalse(xc.is_contiguous())
5001 self.assertFalse(xc.is_contiguous(memory_format=memory_format))
Vitaly Fedyuninbaf84882019-10-25 07:26:52 -07005002 if compare_data:
lixinyuf9f135c2020-03-06 05:59:20 -08005003 self.assertEqual(xc, clone.to(xc))
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005004
lixinyuf9f135c2020-03-06 05:59:20 -08005005 xc = input_generator_fn(device)
5006 clone = transformation_fn(xc, memory_format=torch.contiguous_format)
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005007 self.assertTrue(clone.is_contiguous())
lixinyuf9f135c2020-03-06 05:59:20 -08005008 self.assertFalse(clone.is_contiguous(memory_format=memory_format))
Vitaly Fedyuninbaf84882019-10-25 07:26:52 -07005009 if compare_data:
lixinyuf9f135c2020-03-06 05:59:20 -08005010 self.assertEqual(xc, clone.to(xc))
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005011
lixinyuf9f135c2020-03-06 05:59:20 -08005012 xc = input_generator_fn(device)
5013 clone = transformation_fn(xc)
Igor Fedan5835ad02019-10-28 08:18:11 -07005014
5015 if default_is_preserve:
5016 self.assertFalse(clone.is_contiguous())
lixinyuf9f135c2020-03-06 05:59:20 -08005017 self.assertTrue(clone.is_contiguous(memory_format=memory_format))
Igor Fedan5835ad02019-10-28 08:18:11 -07005018 else:
5019 self.assertTrue(clone.is_contiguous())
lixinyuf9f135c2020-03-06 05:59:20 -08005020 self.assertFalse(clone.is_contiguous(memory_format=memory_format))
Vitaly Fedyuninbaf84882019-10-25 07:26:52 -07005021 if compare_data:
lixinyuf9f135c2020-03-06 05:59:20 -08005022 self.assertEqual(xc, clone.to(xc))
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005023
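# Randomly permuted (hence non-contiguous) inputs must keep their strides under preserve_format.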
5024 x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
5025 for _ in range(10):
5026 permutation = list(range(len(x.shape)))
5027 random.shuffle(permutation)
5028 x = x.permute(permutation)
Shen Li10224432021-08-12 11:39:31 -07005029 self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride())
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005030
5031 def test_memory_format_to(self, device):
lixinyuf9f135c2020-03-06 05:59:20 -08005032 def get_generator(memory_format, shape):
5033 def input_generator_fn(device):
Shen Li10224432021-08-12 11:39:31 -07005034 return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
lixinyuf9f135c2020-03-06 05:59:20 -08005035 return input_generator_fn
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005036
5037 def transformation_fn(tensor, **kwargs):
5038 return tensor.to(dtype=torch.float64, **kwargs)
5039
lixinyuf9f135c2020-03-06 05:59:20 -08005040 formats_shapes = (
5041 (torch.channels_last, (4, 3, 8, 8)),
Shen Li10224432021-08-12 11:39:31 -07005042 (torch.channels_last_3d, (4, 3, 8, 8, 8)))
lixinyuf9f135c2020-03-06 05:59:20 -08005043
5044 for mf, shape in formats_shapes:
5045 self._test_memory_format_transformations(
Shen Li10224432021-08-12 11:39:31 -07005046 device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True)
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005047
5048 def test_memory_format_type(self, device):
lixinyuf9f135c2020-03-06 05:59:20 -08005049 def get_generator(memory_format, shape):
5050 def input_generator_fn(device):
Shen Li10224432021-08-12 11:39:31 -07005051 return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
lixinyuf9f135c2020-03-06 05:59:20 -08005052 return input_generator_fn
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005053
5054 def transformation_fn(tensor, **kwargs):
Ailing Zhang7c13a072020-05-12 13:32:26 -07005055 return tensor.to(torch.float64, **kwargs)
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005056
lixinyuf9f135c2020-03-06 05:59:20 -08005057 formats_shapes = (
5058 (torch.channels_last, (4, 3, 8, 8)),
Shen Li10224432021-08-12 11:39:31 -07005059 (torch.channels_last_3d, (4, 3, 8, 8, 8)))
lixinyuf9f135c2020-03-06 05:59:20 -08005060
5061 for mf, shape in formats_shapes:
5062 self._test_memory_format_transformations(
Shen Li10224432021-08-12 11:39:31 -07005063 device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True)
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005064
5065 def test_memory_format_clone(self, device):
lixinyuf9f135c2020-03-06 05:59:20 -08005066 def get_generator(memory_format, shape):
5067 def input_generator_fn(device):
Shen Li10224432021-08-12 11:39:31 -07005068 return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
lixinyuf9f135c2020-03-06 05:59:20 -08005069 return input_generator_fn
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005070
5071 def transformation_fn(tensor, **kwargs):
5072 return tensor.clone(**kwargs)
5073
lixinyuf9f135c2020-03-06 05:59:20 -08005074 formats_shapes = (
5075 (torch.channels_last, (4, 3, 8, 8)),
Shen Li10224432021-08-12 11:39:31 -07005076 (torch.channels_last_3d, (4, 3, 8, 8, 8)))
lixinyuf9f135c2020-03-06 05:59:20 -08005077
5078 for mf, shape in formats_shapes:
5079 self._test_memory_format_transformations(
Shen Li10224432021-08-12 11:39:31 -07005080 device, get_generator(mf, shape), transformation_fn, mf, True, default_is_preserve=True)
Vitaly Fedyunind39ab032019-10-15 12:54:18 -07005081
Vitaly Fedyunin927588d2019-12-13 15:28:32 -08005082 def test_memory_format_factory_like_functions_preserve(self, device):
lixinyuf9f135c2020-03-06 05:59:20 -08005083 def get_generator(memory_format, shape):
5084 def input_generator_fn(device):
Shen Li10224432021-08-12 11:39:31 -07005085 return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
lixinyuf9f135c2020-03-06 05:59:20 -08005086 return input_generator_fn
Vitaly Fedyuninc258cd02019-10-25 07:26:52 -07005087
Vitaly Fedyunin4bfe2f02019-10-31 13:19:31 -07005088 transformation_fns = [
5089 lambda t, **kwargs: torch.zeros_like(t, **kwargs),
5090 lambda t, **kwargs: torch.ones_like(t, **kwargs),
5091 lambda t, **kwargs: torch.randint_like(t, 10, 100, **kwargs),
5092 lambda t, **kwargs: torch.randint_like(t, 100, **kwargs),
5093 lambda t, **kwargs: torch.randn_like(t, **kwargs),
5094 lambda t, **kwargs: torch.rand_like(t, **kwargs),
5095 lambda t, **kwargs: torch.full_like(t, 7, **kwargs),
Shen Li10224432021-08-12 11:39:31 -07005096 lambda t, **kwargs: torch.empty_like(t, **kwargs)]
Vitaly Fedyuninc258cd02019-10-25 07:26:52 -07005097
lixinyuf9f135c2020-03-06 05:59:20 -08005098 formats_shapes = (
5099 (torch.channels_last, (4, 3, 8, 8)),
Shen Li10224432021-08-12 11:39:31 -07005100 (torch.channels_last_3d, (4, 3, 8, 8, 8)))
lixinyuf9f135c2020-03-06 05:59:20 -08005101
Shen Li10224432021-08-12 11:39:31 -07005102 for mf, shape in formats_shapes:
lixinyuf9f135c2020-03-06 05:59:20 -08005103 for transformation_fn in transformation_fns:
5104 self._test_memory_format_transformations(
Shen Li10224432021-08-12 11:39:31 -07005105 device, get_generator(mf, shape), transformation_fn, mf, compare_data=False, default_is_preserve=True)
Vitaly Fedyunin7ff272c2019-10-26 13:05:34 -07005106
Vitaly Fedyunin951dd032019-10-17 09:12:39 -07005107 def test_memory_format_type_shortcuts(self, device):
lixinyuf9f135c2020-03-06 05:59:20 -08005108 def get_generator(memory_format, shape, dtype):
5109 def input_generator_fn(device):
Shen Li10224432021-08-12 11:39:31 -07005110 return torch.randn(shape, device=device, dtype=dtype).clamp(0, 1) \
5111 .round().contiguous(memory_format=memory_format)
Zsolt Dollensteinb0043072021-08-12 10:56:55 -07005112 return input_generator_fn
Vitaly Fedyunin951dd032019-10-17 09:12:39 -07005113
Shen Li10224432021-08-12 11:39:31 -07005114
Vitaly Fedyunin951dd032019-10-17 09:12:39 -07005115 def get_fn(fn_name):
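# Returns a transformation that calls the dtype-shortcut method (e.g. t.half()) with the given memory_format kwargs.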
5116 def transformation_fn(tensor, **kwargs):
5117 fn = getattr(tensor, fn_name)
5118 return fn(**kwargs)
5119 return transformation_fn
5120
Shen Li10224432021-08-12 11:39:31 -07005121 shortcuts = ['byte', 'char', 'double', 'bool', 'half', 'int', 'long', 'short']
5122 if device == 'cpu':
5123 shortcuts += ['bfloat16']
Vitaly Fedyunin951dd032019-10-17 09:12:39 -07005124
lixinyuf9f135c2020-03-06 05:59:20 -08005125 formats_shapes = (
5126 (torch.channels_last, (4, 3, 8, 8)),
Shen Li10224432021-08-12 11:39:31 -07005127 (torch.channels_last_3d, (4, 3, 8, 8, 8)))
lixinyuf9f135c2020-03-06 05:59:20 -08005128
5129 for mf, shape in formats_shapes:
5130 for fn_name in shortcuts:
5131 self._test_memory_format_transformations(
Shen Li10224432021-08-12 11:39:31 -07005132 device, get_generator(mf, shape, torch.float32), get_fn(fn_name), mf, default_is_preserve=True)
Vitaly Fedyunin951dd032019-10-17 09:12:39 -07005133
5134 # Test 'float' separately to avoid float->float no-op.
lixinyuf9f135c2020-03-06 05:59:20 -08005135 for mf, shape in formats_shapes:
5136 self._test_memory_format_transformations(
Shen Li10224432021-08-12 11:39:31 -07005137 device, get_generator(mf, shape, torch.float64), get_fn('float'), mf, default_is_preserve=True)
Vitaly Fedyunin951dd032019-10-17 09:12:39 -07005138
Vitaly Fedyunin15df3712019-10-17 09:12:39 -07005139 @onlyCUDA
5140 def test_memory_format_cpu_and_cuda_ops(self, device):
lixinyuf9f135c2020-03-06 05:59:20 -08005141 def get_generator(memory_format, shape):
5142 def input_generator_fn(device):
Shen Li10224432021-08-12 11:39:31 -07005143 return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
lixinyuf9f135c2020-03-06 05:59:20 -08005144 return input_generator_fn
Vitaly Fedyunin15df3712019-10-17 09:12:39 -07005145
5146 def transformation_cpu_fn(tensor, **kwargs):
5147 return tensor.cpu(**kwargs)
5148
5149 def transformation_cuda_fn(tensor, **kwargs):
5150 return tensor.cuda(**kwargs)
5151
lixinyuf9f135c2020-03-06 05:59:20 -08005152 formats_shapes = (
5153 (torch.channels_last, (4, 3, 8, 8)),
Shen Li10224432021-08-12 11:39:31 -07005154 (torch.channels_last_3d, (4, 3, 8, 8, 8)))
lixinyuf9f135c2020-03-06 05:59:20 -08005155
5156 for mf, shape in formats_shapes:
5157 self._test_memory_format_transformations(
Shen Li10224432021-08-12 11:39:31 -07005158 'cuda', get_generator(mf, shape), transformation_cpu_fn, mf, default_is_preserve=True)
lixinyuf9f135c2020-03-06 05:59:20 -08005159 self._test_memory_format_transformations(
Shen Li10224432021-08-12 11:39:31 -07005160 'cpu', get_generator(mf, shape), transformation_cuda_fn, mf, default_is_preserve=True)
Vitaly Fedyunin15df3712019-10-17 09:12:39 -07005161
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005162 # FIXME: move to test_serialization
Michael Carilli25f91852020-05-14 09:11:39 -07005163 def test_pickle_gradscaler(self, device):
5164 # This test is not in test_cuda.py because it should pass in 3 cases:
5165 # 1. cuda is not available.
5166 # 2. cuda is available but device is not cuda.
5167 # 3. cuda is available and device is cuda.
5168 # In case 1, a and b disable themselves on construction and shouldn't try to pickle workhorse attributes.
5169 # In case 2, a and b are enabled. Workhorse attributes participate in pickling, but none are lazy-inited
5170 # to cuda Tensors, because I don't want to do cuda things if device is not cuda.
5171 # In case 3, a and b are enabled and we may also try lazy-initing _scale to a cuda tensor.
5172 device = torch.device(device)
5173 try_lazy_inits = (True, False) if device.type == "cuda" else (False,)
5174 for lazy_init_scale in try_lazy_inits:
Shen Li10224432021-08-12 11:39:31 -07005175 a = torch.cuda.amp.GradScaler(init_scale=3., growth_factor=4., backoff_factor=.5, growth_interval=2)
5176 self.assertTrue(not a.is_enabled() if torch.cuda.amp.common.amp_definitely_not_available() else a.is_enabled())
Michael Carilli25f91852020-05-14 09:11:39 -07005177 if lazy_init_scale:
5178 # Dummy a.scale() call lazy-inits a._scale Tensor.
5179 a.scale(torch.tensor([4.0], dtype=torch.float32, device=device))
5180 self.assertTrue(isinstance(a._scale, torch.cuda.FloatTensor))
5181 # The following three lines should work whether or not cuda is available.
5182 serialized = pickle.dumps(a)
5183 b = pickle.loads(serialized)
5184 self.assertEqual(b.is_enabled(), a.is_enabled())
5185 if a.is_enabled():
Shen Li10224432021-08-12 11:39:31 -07005186 self.assertEqual(b.get_scale(), 3.)
5187 self.assertEqual(b.get_growth_factor(), 4.)
5188 self.assertEqual(b.get_backoff_factor(), .5)
Michael Carilli25f91852020-05-14 09:11:39 -07005189 self.assertEqual(b.get_growth_interval(), 2)
5190 self.assertEqual(b._init_growth_tracker, 0)
5191 # supplies a dummy key to test the defaultdict's default_factory
Shen Li10224432021-08-12 11:39:31 -07005192 self.assertEqual(b._per_optimizer_states["fdsa"],
5193 torch.cuda.amp.grad_scaler._refresh_per_optimizer_state())
Michael Carilli25f91852020-05-14 09:11:39 -07005194 if lazy_init_scale:
Shen Li10224432021-08-12 11:39:31 -07005195 self.assertEqual(b.scale(torch.tensor([4.0], dtype=torch.float32, device=device)), 12.0)
Michael Carilli25f91852020-05-14 09:11:39 -07005196
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005197 # FIXME: convert to ErrorInputs
Kulin Sethe011a8e2022-05-13 18:28:53 +00005198 @skipIfMps
kshitij1234597dfdaa2020-06-10 20:39:24 -07005199 def test_multinomial_invalid(self, device):
5200 def test(probs):
Shen Li10224432021-08-12 11:39:31 -07005201 with self.assertRaisesRegex(RuntimeError,
5202 'probability tensor contains either `inf`, `nan` or element < 0'):
Alexander Grund71d95922022-06-27 14:49:38 +00005203 out = torch.multinomial(probs.to(device), 2)
5204 if out.is_cuda:
5205 torch.cuda.synchronize()
kshitij1234597dfdaa2020-06-10 20:39:24 -07005206
Shen Li10224432021-08-12 11:39:31 -07005207 test(torch.tensor([1., -1., 1.]))
5208 test(torch.tensor([1., inf, 1.]))
5209 test(torch.tensor([1., -inf, 1.]))
5210 test(torch.tensor([1., 1., nan]))
kshitij1234597dfdaa2020-06-10 20:39:24 -07005211
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005212 # FIXME: convert to ErrorInputs
Kulin Sethe011a8e2022-05-13 18:28:53 +00005213 @skipIfMps
kshitij1234597dfdaa2020-06-10 20:39:24 -07005214 def test_multinomial_invalid_distribution(self, device):
5215 def test(probs, replacement):
Shen Li10224432021-08-12 11:39:31 -07005216 with self.assertRaisesRegex(RuntimeError,
5217 r"invalid multinomial distribution \(sum of probabilities <= 0\)"):
Alexander Grund71d95922022-06-27 14:49:38 +00005218 out = torch.multinomial(probs, 2, replacement)
5219 if out.is_cuda:
5220 torch.cuda.synchronize()
kshitij1234597dfdaa2020-06-10 20:39:24 -07005221
5222 x = torch.zeros(3, device=device)
5223 y = torch.zeros(3, 3, device=device)
5224 z = torch.zeros(3, 3, device=device)
5225 z[1, :] = 1
5226
5227 test(x, False)
5228 test(y, False)
5229 test(z, False)
5230
5231            # Verify replacement=True only on CPU, because on CUDA it
5232            # triggers a device-side assert rather than a catchable error.
Shen Li10224432021-08-12 11:39:31 -07005233 if self.device_type == 'cpu':
kshitij1234597dfdaa2020-06-10 20:39:24 -07005234 test(x, True)
5235 test(y, True)
5236 test(z, True)
5237
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005238 # FIXME: move to test distributions
kshitij123450394c5a2020-08-31 11:53:23 -07005239 def _test_multinomial_empty(self, device, replacement, num_samples):
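        # Drawing from an empty (0 x 3) probability tensor should return an
        # empty (0 x num_samples) long tensor rather than erroring out.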
5240 probs = torch.ones(0, 3, device=device)
Natalia Gimelsheinf59e3892020-06-11 17:24:22 -07005241 expected = torch.empty(0, num_samples, dtype=torch.int64)
kshitij123450394c5a2020-08-31 11:53:23 -07005242 out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
5243 self.assertEqual(out, expected)
5244
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005245 # FIXME: move to test distributions
kshitij123450394c5a2020-08-31 11:53:23 -07005246 def test_multinomial_empty_w_replacement(self, device):
5247 self._test_multinomial_empty(device, True, 1)
5248 self._test_multinomial_empty(device, True, 2)
5249
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005250 # FIXME: move to test distributions
kshitij123450394c5a2020-08-31 11:53:23 -07005251 def test_multinomial_empty_wo_replacement(self, device):
5252 self._test_multinomial_empty(device, False, 1)
5253 self._test_multinomial_empty(device, False, 2)
Natalia Gimelsheinf59e3892020-06-11 17:24:22 -07005254
Beilei Zheng332086c2022-04-14 15:42:18 +00005255 @dtypesIfCUDA(torch.float, torch.double, torch.half)
5256 @dtypesIfCPU(torch.float, torch.double, torch.bfloat16)
5257 @dtypes(torch.float, torch.double)
5258 def test_multinomial_cpu(self, device, dtype):
5259 def make_prob_dist(shape, is_contiguous):
5260 if is_contiguous:
5261 if dtype == torch.half or dtype == torch.bfloat16:
5262 return torch.zeros(shape, device=device).uniform_().to(dtype=dtype)
5263 return torch.zeros(shape, device=device, dtype=dtype).uniform_()
5264 elif len(shape) == 1:
5265 if dtype == torch.half or dtype == torch.bfloat16:
5266 return torch.zeros((shape + [5]), device=device).uniform_().to(dtype=dtype)[:, 2]
5267 return torch.zeros((shape + [5]), device=device, dtype=dtype).uniform_()[:, 2]
5268 else:
5269 # num dim = 2
5270 new_shape = [2, shape[1], 7, 1, shape[0], 1, 10]
5271 if dtype == torch.half or dtype == torch.bfloat16:
5272 prob_dist = torch.zeros(new_shape, device=device).uniform_().to(dtype=dtype)
5273 else:
5274 prob_dist = torch.zeros(new_shape, device=device, dtype=dtype).uniform_()
5275 prob_dist = prob_dist.transpose(1, 4)
5276 prob_dist = prob_dist[1, :, 5, 0, :, 0, 4]
5277 assert not prob_dist.is_contiguous() # sanity check
5278 return prob_dist
5279
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005280 # FIXME: move to elementwise ternary test suite
kshitij1234531d41f92020-07-30 22:14:23 -07005281    # Runs only on native device types, as the test fails on XLA (RuntimeError not raised)
kshitij12345885a8e52021-11-01 09:21:20 -07005282 @onlyNativeDeviceTypes
Natalia Gimelsheince762442022-05-03 04:40:04 +00005283 def test_where_scalar_handcrafted_values(self, device):
5284 # Tests ScalarxScalar, ScalarxTensor and TensorxScalar
5285 # variant of `where` against NumPy version with
5286 # handcrafted values.
5287 condition_shape = (5, 5)
5288 dtypes = (
5289 torch.bool, torch.uint8, torch.int8, torch.int16, torch.int64,
5290 torch.float16, torch.float32, torch.float64,
5291 torch.complex64, torch.complex128,
5292 )
5293 shapes = ((), (5,), (1, 5),)
kshitij1234531d41f92020-07-30 22:14:23 -07005294
Natalia Gimelsheince762442022-05-03 04:40:04 +00005295 with torch.no_grad():
5296 tensors = (torch.empty(shape, dtype=dtype, device=device).fill_(17)
5297 for shape, dtype in product(shapes, dtypes))
5298
5299 # Use different values for `x` and `y`
5300 # as they are the output values which are compared.
5301 x_vals = (True, 3, 7.0, 1 + 0.5j)
5302 y_vals = itertools.chain((False, 4, 8.0, 2 + 0.5j), tensors)
5303 for x in x_vals:
5304 for y in y_vals:
5305 condition = torch.empty(*condition_shape, dtype=torch.bool, device=device).bernoulli_()
5306 common_dtype = torch.result_type(x, y)
5307
5308 def check_equal(condition, x, y):
5309 condition_np = condition.cpu().numpy()
5310 x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x
5311 y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
5312
5313                        # NumPy aggressively promotes to double, hence cast the output to the correct dtype
5314 expected = torch.from_numpy(np.where(condition_np, x_np, y_np)).to(common_dtype)
5315 result = torch.where(condition, x, y)
5316 self.assertEqual(expected, result)
5317
5318 check_equal(condition, x, y)
5319 check_equal(condition, y, x)
5320
kshitij1234531d41f92020-07-30 22:14:23 -07005321
kshitij123455e9bcf92021-07-09 08:35:59 -07005322 def test_hook_remove(self, device):
5323 # Reference: https://github.com/pytorch/pytorch/issues/58354
5324 def _test_helper(remove_hook):
5325 def install_hook(tensor):
5326 handle = None
5327
5328 def hook(tensor):
5329 if remove_hook:
5330 handle.remove()
5331 return torch.zeros_like(tensor)
5332 handle = tensor.register_hook(hook)
5333
5334 t = torch.ones((1, 5), device=device, requires_grad=True)
5335 install_hook(t)
5336
5337 # First call to backward
5338 t.mean().backward()
5339 self.assertEqual(t.grad, torch.zeros_like(t))
5340
5341 # Second call to backward
5342 t.mean().backward()
5343 if remove_hook:
5344 # After removing the hook, make sure the usual gradient is returned
5345 self.assertEqual(t.grad, 0.2 * torch.ones_like(t))
5346 else:
5347 self.assertEqual(t.grad, torch.zeros_like(t))
5348
5349 _test_helper(remove_hook=True)
5350 _test_helper(remove_hook=False)
5351
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005352 # FIXME: get PyTorch/XLA to run test_testing
kshitij123455b2586f2021-12-02 14:42:13 -08005353 # This test should ideally be in test_testing.py,
5354 # but since pytorch/xla runs tests from test_torch.py, we have it here.
kshitij12345c00806b2021-10-29 19:51:52 -07005355 @skipXLA
5356 def test_skip_xla(self, device):
5357 if self.device_type == 'xla':
5358 # Should not reach here!
5359 self.assertTrue(False)
5360
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005361 # FIXME: get PyTorch/XLA to run test_testing
kshitij123455b2586f2021-12-02 14:42:13 -08005362 # This test should ideally be in test_testing.py,
5363 # but since pytorch/xla runs tests from test_torch.py, we have it here.
kshitij12345c00806b2021-10-29 19:51:52 -07005364 @expectedFailureXLA
5365 def test_expected_failure_xla(self, device):
5366 if self.device_type == 'xla':
5367 self.assertTrue(False)
5368
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005369 # FIXME: get PyTorch/XLA to run test_testing
kshitij123455b2586f2021-12-02 14:42:13 -08005370 # This test should ideally be in test_testing.py,
5371 # but since pytorch/xla runs tests from test_torch.py, we have it here.
5372 def test_assertRaisesRegex_ignore_msg_non_native_device(self, device):
5373 # Verify that self.assertRaisesRegex only checks the Error and ignores
5374 # message for non-native devices.
5375 x = torch.randn((10, 3), device=device)
5376 t = torch.empty(10, dtype=torch.int64, device=device).random_(0, 3)
5377 invalid_weight = torch.randn(4, device=device)
5378 msg = "weight tensor should be defined either for all 3 classes or no classes"
5379
5380 # XLA raises RuntimeError with a different message.
5381 with self.assertRaisesRegex(RuntimeError, msg):
5382 torch.nn.functional.nll_loss(x, t, weight=invalid_weight)
5383
kshitij12345f7ee3082022-03-23 21:42:59 +00005384 @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32))
5385 def test_copy_(self, device, dtype):
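        # copy_ should accept a source of any dtype; after copying src -> t (this
        # test's dtype) -> dst (src's dtype), values must round-trip whenever the
        # cast is safe, with looser tolerances for half/bfloat16/complex32.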
5386 def can_cast(src_dtype, dst_dtype):
5387            # torch.can_cast(torch.int16, torch.uint8) returns True,
5388            # but that is not actually a safe cast.
5389            # This helper returns False in that case.
5390 def is_unsigned_int(dtype):
5391 return dtype is torch.uint8
5392
5393 if is_unsigned_int(dst_dtype):
5394 return is_unsigned_int(src_dtype)
5395 return torch.can_cast(src_dtype, dst_dtype)
5396
5397 def make_tensor_wrapper(shape, dtype):
5398 if dtype is not torch.complex32:
5399                # make_tensor does not support generating
5400                # complex32 tensors
5401 return make_tensor(shape, device=device, dtype=dtype)
5402 return torch.randn(shape, device=device, dtype=dtype)
5403
5404 t = make_tensor_wrapper((50,), dtype)
5405 src_dtypes = all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32)
5406 for src_dtype in src_dtypes:
5407 src = make_tensor_wrapper((50,), dtype=src_dtype)
5408 t.copy_(src)
5409 dst = make_tensor_wrapper((50, ), dtype=src_dtype)
5410 if can_cast(src_dtype, dtype):
5411 rtol = None
5412 atol = None
5413 if dtype in (torch.half, torch.complex32):
5414 rtol = 1e-3
5415 atol = 1e-3
5416 if dtype in (torch.bfloat16,):
5417 rtol = 1e-2
5418 atol = 1e-2
5419 self.assertEqual(src, dst.copy_(t), rtol=rtol, atol=atol)
5420
kshitij1234565b65af2022-04-01 15:19:05 +00005421 @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32))
5422 def test_item(self, device, dtype):
5423 t = torch.ones((), device=device, dtype=dtype)
5424 self.assertEqual(1, t.item())
5425
kshitij12345c00806b2021-10-29 19:51:52 -07005426
Mike Ruberryea414e42019-09-30 19:07:28 -07005427# Tests that compare a device's computation with the (gold-standard) CPU's.
5428class TestDevicePrecision(TestCase):
Edward Yangba1bd412020-03-03 14:33:40 -08005429 exact_dtype = True
Richard Zou5c423ca2020-01-09 07:34:10 -08005430
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005431 # FIXME: move to indexing test suite
rohithkrn2f32b922020-03-12 11:26:25 -07005432 @onlyCUDA
rohithkrn2f32b922020-03-12 11:26:25 -07005433 def test_index_add_bfloat16(self, device):
Shen Li10224432021-08-12 11:39:31 -07005434 inp_tensor = torch.randn(5, 3, device='cpu').bfloat16()
5435 t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.bfloat16, device='cpu')
5436 index = torch.tensor([0, 4, 2], device='cpu')
rohithkrn2f32b922020-03-12 11:26:25 -07005437 out_cpu = inp_tensor.index_add(0, index, t)
5438
5439 inp_tensor = inp_tensor.to(device=device)
5440 t = t.to(device=device)
5441 index = index.to(device=device)
5442 out_gpu = inp_tensor.index_add(0, index, t)
5443
Mike Ruberry13120bf2020-05-27 06:28:05 -07005444 self.assertEqual(out_cpu, out_gpu, atol=1e-2, rtol=0)
rohithkrn2f32b922020-03-12 11:26:25 -07005445
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005446 # FIXME: move to serialization test suite
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005447 def test_device_serialization(self, device):
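        # Saving and loading a tensor should round-trip its values, type and device.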
5448 x = torch.randn(4, 4, device=device)
5449
5450 with tempfile.NamedTemporaryFile() as f:
5451 torch.save(x, f)
5452 f.seek(0)
5453 x_copy = torch.load(f)
5454
5455 self.assertEqual(x_copy, x)
5456 self.assertIs(type(x_copy), type(x))
5457 self.assertEqual(x_copy.device, x.device)
5458
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005459 # FIXME: move to serialization test suite
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005460 @deviceCountAtLeast(2)
5461 def test_multidevice_serialization(self, devices):
Shen Li10224432021-08-12 11:39:31 -07005462 x = [torch.randn(4, 4, device=devices[0]),
5463 torch.randn(4, 4, device=devices[1])]
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005464
5465 with tempfile.NamedTemporaryFile() as f:
5466 torch.save(x, f)
5467 f.seek(0)
5468 x_copy = torch.load(f)
5469
5470 for original, cp in zip(x, x_copy):
5471 self.assertEqual(cp, original)
5472 self.assertIs(type(cp), type(original))
5473 self.assertEqual(cp.device, original.device)
5474
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005475 # FIXME: move to data movement test suite
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005476 @deviceCountAtLeast(1)
5477 def test_copy_noncontig(self, devices):
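        # Strided (non-contiguous) copies across devices must also perform the
        # implicit dtype conversion (float source into an integer destination here).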
5478 def do_test(d0, d1):
5479 x = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5], device=d0)
5480 y = torch.tensor([0, 0, 0, 0, 0, 0], device=d1)
5481 self.assertNotEqual(x.dtype, y.dtype)
5482
5483 y[::2].copy_(x[::2])
5484 self.assertEqual(y, [1, 0, 3, 0, 5, 0])
5485
Shen Li10224432021-08-12 11:39:31 -07005486 do_test('cpu', devices[0])
5487 do_test(devices[0], 'cpu')
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005488
5489 if len(devices) > 1:
5490 do_test(devices[0], devices[1])
5491
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005492 @deviceCountAtLeast(2)
5493 def test_type_conversions_same_device(self, devices):
5494 x = torch.randn(5, 5, device=devices[1])
5495 self.assertEqual(x.int().device, torch.device(devices[1]))
5496 self.assertEqual(x.type(torch.int).device, torch.device(devices[1]))
5497 self.assertEqual(x.to(torch.int).device, torch.device(devices[1]))
5498
Shen Li10224432021-08-12 11:39:31 -07005499 @dtypesIfCUDA(torch.half, torch.float, torch.double,
5500 torch.int8, torch.short, torch.int, torch.long,
5501 torch.uint8)
5502 @dtypes(torch.float, torch.double,
5503 torch.int8, torch.short, torch.int, torch.long,
5504 torch.uint8)
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005505 def test_from_sequence(self, device, dtype):
5506 seq = [list(range(i * 4, i * 4 + 4)) for i in range(5)]
5507 reference = torch.arange(0, 20).resize_(5, 4)
Shen Li10224432021-08-12 11:39:31 -07005508 self.assertEqual(torch.tensor(seq, dtype=dtype, device=device), reference, exact_dtype=False)
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005509
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005510    # FIXME: move to indexing test suite
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005511 @deviceCountAtLeast(1)
Nikita Shulga8811e4d2020-06-04 13:37:44 -07005512 def test_advancedindex_mixed_cpu_devices(self, devices) -> None:
5513 def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005514 # test getitem
Shen Li10224432021-08-12 11:39:31 -07005515 self.assertEqual(x[:, ia, None, ib, 0].cpu(),
5516 x.cpu()[:, ia.cpu(), None, ib.cpu(), 0])
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005517 self.assertEqual(x[ia], x.cpu()[ia.cpu()])
5518 # test setitem
5519 x_clone1 = x.clone()
5520 x_clone2 = x.clone()
5521 first_shape = x[:, ia, None, ib, 0].shape
5522 second_shape = x[ia].shape
5523 x_clone1[:, ia, None, ib, 0] = torch.randn(first_shape).to(x_clone1)
5524 x_clone2[ia] = torch.randn(second_shape).to(x_clone2)
5525
Shen Li10224432021-08-12 11:39:31 -07005526 cpu = torch.device('cpu')
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005527 for device in devices:
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005528 x = torch.randn(3, 4, 4, 4, 3)
Brian Hirsh7b3a0ff2022-06-10 09:21:51 -07005529 ia = torch.tensor([0, 2, 1])
5530 ib = torch.tensor([0, 2, 1])
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005531
5532 # Index device tensor with cpu tensor
5533 x = x.to(device)
5534 ia = ia.to(cpu)
5535 ib = ib.to(cpu)
5536 test(x, ia, ib)
5537
PyTorch MergeBot4b82ef72022-06-08 20:16:10 +00005538 # Index device tensor with mixed cpu, device tensors
5539 x = x.to(device)
5540 ia = ia.to(cpu)
5541 ib = ib.to(device)
5542 test(x, ia, ib)
Brian Hirshcfd84122022-06-08 07:29:15 -07005543
Brian Hirsh7b3a0ff2022-06-10 09:21:51 -07005544 @deviceCountAtLeast(1)
5545 def test_advancedindex_mixed_devices_error(self, devices) -> None:
5546 def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
5547 # test getitem
5548 with self.assertRaisesRegex(RuntimeError, fr"indices should be either .* \({x.device}\)"):
5549 value = x[:, ia, None, ib, 0]
5550 with self.assertRaisesRegex(RuntimeError, fr"indices should be either .* \({x.device}\)"):
5551 value = x[ib]
5552
5553 cpu = torch.device('cpu')
5554 for device in devices:
5555 # Index cpu tensor with device tensor
5556 x = torch.randn(3, 4, 4, 4, 3)
5557 ia = torch.tensor([0, 2, 1]).to(device)
5558 ib = torch.tensor([0, 2, 1]).to(device)
5559 test(x, ia, ib)
5560
5561 # Index cpu tensor with mixed cpu, device tensors
5562 x = x.to(cpu)
5563 ia = ia.to(cpu)
5564 ib = ib.to(device)
5565 test(x, ia, ib)
5566
PyTorch MergeBot4b82ef72022-06-08 20:16:10 +00005567 if len(devices) > 1:
Brian Hirsh7b3a0ff2022-06-10 09:21:51 -07005568 other_device = devices[0] if device == devices[1] else devices[1]
5569
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005570 # Index device tensor with mixed cpu, device tensors on different devices
5571 x = x.to(device)
5572 ia = ia.to(cpu)
5573 ib = ib.to(other_device)
5574 test(x, ia, ib)
5575
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005576 # FIXME: move to data movement test suite
Nikita Shulga8811e4d2020-06-04 13:37:44 -07005577 def test_copy_broadcast(self, device) -> None:
Mike Ruberryb45f1b92019-10-01 19:17:06 -07005578 x = torch.randn(10, 5)
5579 y = torch.randn(5, device=device)
5580 x.copy_(y)
5581 self.assertEqual(x[3], y)
5582
5583 x = torch.randn(10, 5, device=device)
5584 y = torch.randn(5)
5585 x.copy_(y)
5586 self.assertEqual(x[3], y)
5587
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005588 # FIXME: move to an elementwise ternary test suite
Peter Bell33eea142021-05-03 12:51:16 -07005589 @dtypes(torch.int64, torch.float32, torch.float64)
5590 def test_clamp(self, device, dtype):
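        # Compares clamp against the equivalent max/min composition and np.clip,
        # for contiguous and non-contiguous inputs, including broadcast of the
        # bounds and of the input itself.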
5591 test_args = [
5592 *product(
5593 [(100, 50), (10, 64), (97,)], # shape
5594 (True, False), # non-contiguous
5595 )
5596 ]
5597
5598 for shape, noncontig in test_args:
Shen Li10224432021-08-12 11:39:31 -07005599 x = make_tensor(shape, device=device, dtype=dtype,
5600 noncontiguous=noncontig)
5601 ub = make_tensor(shape, device=device, dtype=dtype,
5602 noncontiguous=noncontig)
5603 lb = make_tensor(shape, device=device, dtype=dtype,
5604 noncontiguous=noncontig)
Peter Bell33eea142021-05-03 12:51:16 -07005605
5606 expect = x.max(lb).min(ub)
5607 actual = x.clamp(lb, ub)
5608 self.assertEqual(expect, actual)
5609
5610 expect = np.clip(x.cpu().numpy(), lb.cpu().numpy(), ub.cpu().numpy())
5611 self.assertEqual(expect, actual)
5612
5613 expect = x.max(lb)
5614 actual = x.clamp(min=lb)
5615 self.assertEqual(expect, actual)
5616
5617 expect = x.min(ub)
5618 actual = x.clamp(max=ub)
5619 self.assertEqual(expect, actual)
5620
5621 # Test broadcasting min & max
5622 expect = x.max(lb[0]).min(ub[..., :1])
5623 actual = x.clamp(lb[0], ub[..., :1])
5624 self.assertEqual(expect, actual)
5625
5626 # Test broadcasting x
5627 expect = x[..., :1].max(lb).min(ub)
5628 actual = x[..., :1].clamp(lb, ub)
5629 self.assertEqual(expect, actual)
5630
anjali411f607af12022-02-02 13:45:07 -08005631 def test_cuda_device_idx(self, device):
5632 x = torch.zeros(3, device=device)
5633 y = torch._efficientzerotensor(3, device=device)
5634 self.assertEqual(x.device, y.device)
Peter Bell33eea142021-05-03 12:51:16 -07005635
Edward Yangf05d5be2021-06-03 10:47:19 -07005636# we implemented custom deallocation for subclasses, so it behooves
5637# us to make sure all of these bits work. We'll use __del__ to
5638# track if objects die or not
5639class Tracker:
5640 def __init__(self, marker):
5641 self.marker = marker
5642
5643 @staticmethod
5644 def make():
5645 marker = [False]
5646 return marker, Tracker(marker)
5647
5648 def __del__(self):
5649 self.marker[0] = True
5650
5651@contextlib.contextmanager
5652def disable_gc():
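    # Temporarily disables the garbage collector (restoring it on exit) so that
    # object lifetimes in the surrounding test are driven by reference counts only.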
5653 if gc.isenabled():
5654 try:
5655 gc.disable()
5656 yield
5657 finally:
5658 gc.enable()
5659 else:
5660 yield
5661
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005662class TestTorch(TestCase):
Edward Yangba1bd412020-03-03 14:33:40 -08005663 exact_dtype = True
Tongzhou Wang6d2b3cc2018-11-01 19:04:17 -07005664
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005665 def test_dir(self):
5666 dir(torch)
5667
5668 def test_wildcard_import(self):
5669 exec('from torch import *')
5670
5671 def test_newaxis_numpy_comparison(self):
5672 def run_test(tensor, *idx):
5673 npt = tensor.numpy()
5674 self.assertEqual(tensor[idx], npt[idx])
5675
5676 # 1D Tensor Tests
5677 x = torch.arange(0, 10)
5678 cases = [
5679 [None],
5680 [None, None],
5681 [Ellipsis, None],
5682 [None, Ellipsis],
5683 [2, None],
5684 [None, 2],
5685 [Ellipsis, None, 2],
5686 [Ellipsis, 2, None],
5687 [2, Ellipsis, None],
5688 [2, None, Ellipsis],
5689 [None, 2, Ellipsis],
5690 [None, Ellipsis, 2],
5691 ]
5692
5693 for case in cases:
5694 run_test(x, *case)
5695
5696 # 2D Tensor Tests
5697 x = torch.arange(0, 12).view(3, 4)
5698 cases = [
5699 [None],
5700 [None, None],
5701 [None, None, None],
5702 [Ellipsis, None],
5703 [Ellipsis, None, None],
5704 [None, Ellipsis],
5705 [None, Ellipsis, None],
5706 [None, None, Ellipsis],
5707 [2, None],
5708 [2, None, Ellipsis],
5709 [2, Ellipsis, None],
5710 [None, 2, Ellipsis],
5711 [Ellipsis, 2, None],
5712 [Ellipsis, None, 2],
5713 [None, Ellipsis, 2],
5714 [1, 2, None],
5715 [1, 2, Ellipsis, None],
5716 [1, Ellipsis, 2, None],
5717 [Ellipsis, 1, None, 2],
5718 [Ellipsis, 1, 2, None],
5719 [1, None, 2, Ellipsis],
5720 [None, 1, Ellipsis, 2],
5721 [None, 1, 2, Ellipsis],
5722 ]
5723
5724 for case in cases:
5725 run_test(x, *case)
5726
5727 def _consecutive(self, size, start=1):
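        # Returns a tensor of the given size filled with consecutive values
        # start, start + 1, ... in row-major order.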
5728 sequence = torch.ones(torch.tensor(size).prod(0)).cumsum(0)
5729 sequence.add_(start - 1)
5730 return sequence.resize_(*size)
5731
5732 def test_newindex(self):
5733 reference = self._consecutive((3, 3, 3))
5734 # This relies on __index__() being correct - but we have separate tests for that
5735
5736 def checkPartialAssign(index):
5737 reference = torch.zeros(3, 3, 3)
5738 reference[index] = self._consecutive((3, 3, 3))[index]
5739 self.assertEqual(reference[index], self._consecutive((3, 3, 3))[index], atol=0, rtol=0)
5740 reference[index] = 0
5741 self.assertEqual(reference, torch.zeros(3, 3, 3), atol=0, rtol=0)
5742
5743 checkPartialAssign(0)
5744 checkPartialAssign(1)
5745 checkPartialAssign(2)
5746 checkPartialAssign((0, 1))
5747 checkPartialAssign((1, 2))
5748 checkPartialAssign((0, 2))
5749 checkPartialAssign(torch.LongTensor((0, 2)))
5750
5751 with self.assertRaises(IndexError):
5752 reference[1, 1, 1, 1] = 1
5753 with self.assertRaises(IndexError):
5754 reference[1, 1, 1, (1, 1)] = 1
5755 with self.assertRaises(IndexError):
5756 reference[3, 3, 3, 3, 3, 3, 3, 3] = 1
5757 with self.assertRaises(IndexError):
5758 reference[0.0] = 1
5759 with self.assertRaises(TypeError):
5760 reference[0.0:2.0] = 1
5761 with self.assertRaises(IndexError):
5762 reference[0.0, 0.0:2.0] = 1
5763 with self.assertRaises(IndexError):
5764 reference[0.0, :, 0.0:2.0] = 1
5765 with self.assertRaises(IndexError):
5766 reference[0.0, ..., 0.0:2.0] = 1
5767 with self.assertRaises(IndexError):
5768 reference[0.0, :, 0.0] = 1
5769
5770 # FIXME: move to indexing test suite
5771 def test_index_add(self):
5772 for device in get_all_device_types():
5773 for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
5774 for other_sizes in ((), (4, 5)):
5775 for dtype in [torch.int, torch.long]:
5776 num_copy, num_dest = 3, 3
5777 dest = torch.randn(num_dest, *other_sizes, device=device)
5778 if not dest_contig:
5779 dest = make_tensor(dest.shape, device=device, dtype=dest.dtype, noncontiguous=True)
5780 src = torch.randn(num_copy, *other_sizes, device=device)
5781 if not src_contig:
5782 src = torch.testing.make_non_contiguous(src)
5783 idx = torch.randperm(num_dest, dtype=dtype, device=device).narrow(0, 0, num_copy)
5784 if not index_contig:
5785 idx = torch.testing.make_non_contiguous(idx)
5786 # index_add_ without alpha argument
5787 dest2 = dest.clone()
5788 dest.index_add_(0, idx, src)
5789 for i in range(idx.size(0)):
5790 dest2[idx[i]] += src[i]
5791 self.assertEqual(dest, dest2)
5792 # index_add_ with alpha argument
5793 dest2 = dest.clone()
5794 dest.index_add_(0, idx, src, alpha=2)
5795 for i in range(idx.size(0)):
5796 dest2[idx[i]] += src[i] * 2
5797 self.assertEqual(dest, dest2)
5798
5799 # FIXME: resolve comment below and move this to indexing test suite
5800 # add coverage for issue with atomic add that appeared only for
5801 # specific dtypes on cuda:
5802 # https://github.com/pytorch/pytorch/issues/29153
5803 def test_index_add_all_dtypes(self):
5804 for device in get_all_device_types():
5805 for dtype in get_all_math_dtypes(device):
5806 for idx_dtype in [torch.int, torch.long]:
5807 size = [5, 5]
5808 if dtype.is_floating_point or dtype.is_complex:
5809 tensor = torch.rand(size, dtype=dtype, device=device)
5810 elif dtype.is_signed:
5811 tensor = torch.randint(-5, 15, size, dtype=dtype, device=device)
5812 else:
5813 tensor = torch.randint(0, 10, size, dtype=dtype, device=device)
5814
5815 # index_add calls atomicAdd on cuda.
5816 zeros = torch.zeros(size, dtype=dtype, device=device)
5817
5818 added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor)
5819 self.assertEqual(added, tensor)
5820
5821 added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor, alpha=-1)
5822 self.assertEqual(added, -tensor)
5823
5824 # FIXME: move to shape ops test suite
5825 def test_unflatten(self):
5826 # test args: tensor, int, sizes
5827 self.assertEqual(torch.tensor([]).unflatten(0, (0, 1)), torch.empty(0, 1))
5828 self.assertEqual(torch.tensor([1]).unflatten(0, (1, 1)), torch.tensor([[1]]))
5829 self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, (2, 2)), torch.tensor([[1, 2], [3, 4]]))
5830 self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, [2, 2]), torch.tensor([[1, 2], [3, 4]]))
5831 self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, torch.Size([2, 2])), torch.tensor([[1, 2], [3, 4]]))
5832 self.assertEqual(torch.ones(2, 10).unflatten(1, (5, 2)), torch.ones(2, 5, 2))
5833 self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, (-1, 2)),
5834 torch.tensor([[1, 2], [3, 4]]))
5835 self.assertEqual(torch.ones(2, 10).unflatten(1, (5, -1)),
5836 torch.ones(2, 5, 2))
5837 self.assertEqual(torch.ones(2, 10).unflatten(1, (-1,)),
5838 torch.ones(2, 10))
5839 self.assertEqual(torch.ones(2, 3 * 4 * 5 * 6).unflatten(1, (3, 4, -1, 6)),
5840 torch.ones(2, 3, 4, 5, 6))
5841 self.assertEqual(torch.ones(2, 0, 2).unflatten(1, (3, -1, 4, 5)),
5842 torch.ones(2, 3, 0, 4, 5, 2))
5843
5844 # test invalid args: tensor, str, sizes
5845 with self.assertRaisesRegex(TypeError, r"received an invalid combination of arguments"):
5846 torch.tensor([1]).unflatten('A', (1, 1))
5847
5848 # test invalid args: tensor, str, namedshape
5849 with self.assertRaisesRegex(RuntimeError, r"Name 'A' not found in Tensor\[None\]."):
5850 torch.ones(4).unflatten('A', (('A', 2), ('B', 2)))
5851
5852 # test other invalid arguments
5853 with self.assertRaisesRegex(RuntimeError, r"sizes must be non-empty"):
5854 torch.tensor([1]).unflatten(0, [])
5855 with self.assertRaisesRegex(RuntimeError, r"Provided sizes \[2, 2\] don't multiply up to the size of dim 0 \(1\)"):
5856 torch.tensor([1]).unflatten(0, [2, 2])
5857 with self.assertRaisesRegex(IndexError, r"dimension specified as 0 but tensor has no dimensions"):
5858 torch.tensor(1).unflatten(0, [0])
5859 with self.assertRaisesRegex(RuntimeError, r"only one dimension can be inferred"):
5860 torch.randn(5, 10).unflatten(1, (-1, -1))
5861 with self.assertRaisesRegex(RuntimeError,
5862 r"Provided sizes \[-1, 4\] don't multiply up to the size of dim 1 \(10\)"):
5863 torch.randn(5, 10).unflatten(1, (-1, 4))
5864 with self.assertRaisesRegex(RuntimeError,
5865 r"the unspecified dimension size -1 can be any value and is ambiguous"):
5866 torch.randn(2, 0).unflatten(1, (2, -1, 0))
5867
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005868 def test_structseq_repr(self):
5869 a = torch.arange(250).reshape(5, 5, 10)
5870 expected = """
5871 torch.return_types.max(
5872 values=tensor([[ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
5873 [ 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
5874 [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
5875 [190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
5876 [240, 241, 242, 243, 244, 245, 246, 247, 248, 249]]),
5877 indices=tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
5878 [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
5879 [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
5880 [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
5881 [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))"""
5882 self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip())
5883
5884 def test_is_same_size(self):
5885 t1 = torch.empty(3, 4, 9, 10)
5886 t2 = torch.empty(3, 4)
5887 t3 = torch.empty(1, 9, 3, 3)
5888 t4 = torch.empty(3, 4, 9, 10)
5889
5890 self.assertFalse(t1.is_same_size(t2))
5891 self.assertFalse(t1.is_same_size(t3))
5892 self.assertTrue(t1.is_same_size(t4))
5893
drisspgbdcee8f2022-06-10 13:56:40 -07005894 nt1 = torch.nested_tensor([torch.ones(2, 4), torch.ones(3, 4), torch.ones(5, 4)])
5895 nt2 = torch.nested_tensor([torch.ones(2, 4), torch.ones(2, 4), torch.ones(2, 4)])
5896 nt3 = torch.nested_tensor([torch.ones(2, 4, 5), torch.ones(2, 6, 5)])
5897 nt4 = torch.nested_tensor([torch.ones(2, 4), torch.ones(3, 4), torch.ones(5, 4)])
5898
5899 self.assertFalse(nt1.is_same_size(nt2))
5900 self.assertFalse(nt1.is_same_size(nt3))
5901 self.assertTrue(nt1.is_same_size(nt4))
5902 with self.assertRaisesRegex(RuntimeError, "Expected both self and other to be nested tensors."):
5903 t1.is_same_size(nt1)
5904
5905 with self.assertRaisesRegex(RuntimeError, "Expected both self and other to be nested tensors."):
5906 nt1.is_same_size(t1)
5907
Mike Ruberrye0d829a2022-01-24 01:28:07 -08005908 def test_tensor_set(self):
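        # set_ should make t1 share t2's storage, and should accept the
        # storage/offset/size/stride arguments both positionally and by keyword.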
5909 t1 = torch.tensor([])
5910 t2 = torch.empty(3, 4, 9, 10).uniform_()
5911 t1.set_(t2)
5912 self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
5913 size = torch.Size([9, 3, 4, 10])
5914 t1.set_(t2.storage(), 0, size)
5915 self.assertEqual(t1.size(), size)
5916 t1.set_(t2.storage(), 0, tuple(size))
5917 self.assertEqual(t1.size(), size)
5918 self.assertEqual(t1.stride(), (120, 40, 10, 1))
5919 stride = (10, 360, 90, 1)
5920 t1.set_(t2.storage(), 0, size, stride)
5921 self.assertEqual(t1.stride(), stride)
5922 t1.set_(t2.storage(), 0, size=size, stride=stride)
5923 self.assertEqual(t1.size(), size)
5924 self.assertEqual(t1.stride(), stride)
5925
5926 # test argument names
5927 t1 = torch.tensor([])
5928 # 1. case when source is tensor
5929 t1.set_(source=t2)
5930 self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
5931 # 2. case when source is storage
5932 t1.set_(source=t2.storage())
5933 self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
5934 # 3. case when source is storage, and other args also specified
5935 t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride)
5936 self.assertEqual(t1.size(), size)
5937 self.assertEqual(t1.stride(), stride)
5938
5939 t1 = torch.tensor([True, True], dtype=torch.bool)
5940 t2 = torch.tensor([False, False], dtype=torch.bool)
5941 t1.set_(t2)
5942 self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
5943
5944 def test_tensor_set_errors(self):
5945 f_cpu = torch.randn((2, 3), dtype=torch.float32)
5946 d_cpu = torch.randn((2, 3), dtype=torch.float64)
5947
5948 # change dtype
5949 self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage()))
5950 self.assertRaises(RuntimeError,
5951 lambda: f_cpu.set_(d_cpu.storage(), 0, d_cpu.size(), d_cpu.stride()))
5952 self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu))
5953
5954 # change device
5955 if torch.cuda.is_available():
5956 f_cuda = torch.randn((2, 3), dtype=torch.float32, device='cuda')
5957
5958 # cpu -> cuda
5959 self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda.storage()))
5960 self.assertRaises(RuntimeError,
5961 lambda: f_cpu.set_(f_cuda.storage(), 0, f_cuda.size(), f_cuda.stride()))
5962 self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda))
5963
5964 # cuda -> cpu
5965 self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu.storage()))
5966 self.assertRaises(RuntimeError,
5967 lambda: f_cuda.set_(f_cpu.storage(), 0, f_cpu.size(), f_cpu.stride()))
5968 self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu))
5969
5970    # FIXME: move this test to test_testing.py (along with allclose testing)
5971 # NOTE: test_equal will be deprecated in favor of torch.testing.assert_close
5972 # once torch.testing is out of beta
5973 def test_equal(self):
5974 # Contiguous, 1D
5975 t1 = torch.tensor((3., 4., 9., 10.))
5976 t2 = t1.contiguous()
5977 t3 = torch.tensor((1., 9., 3., 10.))
5978 t4 = torch.tensor((3., 4., 9.))
5979 t5 = torch.tensor([])
5980 self.assertTrue(t1.equal(t2))
5981 self.assertFalse(t1.equal(t3))
5982 self.assertFalse(t1.equal(t4))
5983 self.assertFalse(t1.equal(t5))
5984 self.assertTrue(torch.equal(t1, t2))
5985 self.assertFalse(torch.equal(t1, t3))
5986 self.assertFalse(torch.equal(t1, t4))
5987 self.assertFalse(torch.equal(t1, t5))
5988
5989 # Non contiguous, 2D
5990 s = torch.tensor(((1, 2, 3, 4), (5, 6, 7, 8)))
5991 s1 = s[:, 1:3]
5992 s2 = s1.clone()
5993 s3 = torch.tensor(((2, 3), (6, 7)))
5994 s4 = torch.tensor(((0, 0), (0, 0)))
5995
5996 self.assertFalse(s1.is_contiguous())
5997 self.assertTrue(s1.equal(s2))
5998 self.assertTrue(s1.equal(s3))
5999 self.assertFalse(s1.equal(s4))
6000 self.assertTrue(torch.equal(s1, s2))
6001 self.assertTrue(torch.equal(s1, s3))
6002 self.assertFalse(torch.equal(s1, s4))
6003
6004 def test_element_size(self):
6005 byte = torch.ByteStorage().element_size()
6006 char = torch.CharStorage().element_size()
6007 short = torch.ShortStorage().element_size()
6008 int = torch.IntStorage().element_size()
6009 long = torch.LongStorage().element_size()
6010 float = torch.FloatStorage().element_size()
6011 double = torch.DoubleStorage().element_size()
6012 bool = torch.BoolStorage().element_size()
6013 bfloat16 = torch.BFloat16Storage().element_size()
6014 complexfloat = torch.ComplexFloatStorage().element_size()
6015 complexdouble = torch.ComplexDoubleStorage().element_size()
6016
6017 self.assertEqual(byte, torch.ByteTensor().element_size())
6018 self.assertEqual(char, torch.CharTensor().element_size())
6019 self.assertEqual(short, torch.ShortTensor().element_size())
6020 self.assertEqual(int, torch.IntTensor().element_size())
6021 self.assertEqual(long, torch.LongTensor().element_size())
6022 self.assertEqual(float, torch.FloatTensor().element_size())
6023 self.assertEqual(double, torch.DoubleTensor().element_size())
6024 self.assertEqual(bool, torch.BoolTensor().element_size())
6025 self.assertEqual(bfloat16, torch.tensor([], dtype=torch.bfloat16).element_size())
6026 self.assertEqual(complexfloat, torch.tensor([], dtype=torch.complex64).element_size())
6027 self.assertEqual(complexdouble, torch.tensor([], dtype=torch.complex128).element_size())
6028
6029 self.assertGreater(byte, 0)
6030 self.assertGreater(char, 0)
6031 self.assertGreater(short, 0)
6032 self.assertGreater(int, 0)
6033 self.assertGreater(long, 0)
6034 self.assertGreater(float, 0)
6035 self.assertGreater(double, 0)
6036 self.assertGreater(bool, 0)
6037 self.assertGreater(bfloat16, 0)
6038 self.assertGreater(complexfloat, 0)
6039 self.assertGreater(complexdouble, 0)
6040
6041 # These tests are portable, not necessarily strict for your system.
6042 self.assertEqual(byte, 1)
6043 self.assertEqual(char, 1)
6044 self.assertEqual(bool, 1)
6045 self.assertGreaterEqual(short, 2)
6046 self.assertGreaterEqual(int, 2)
6047 self.assertGreaterEqual(int, short)
6048 self.assertGreaterEqual(long, 4)
6049 self.assertGreaterEqual(long, int)
6050 self.assertGreaterEqual(double, float)
6051
6052 def test_permute(self):
6053 orig = [1, 2, 3, 4, 5, 6, 7]
6054 perm = torch.randperm(7).tolist()
6055 x = torch.empty(*orig).fill_(0)
6056 new = [i - 1 for i in x.permute(*perm).size()]
6057 self.assertEqual(perm, new)
6058 self.assertEqual(x.size(), orig)
6059
Animesh Jain1d90d6e2022-07-07 18:57:31 +00006060 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Mike Ruberrye0d829a2022-01-24 01:28:07 -08006061 def test_reversed(self):
6062 val = torch.arange(0, 10)
6063 self.assertEqual(reversed(val), torch.arange(9, -1, -1))
6064
6065 val = torch.arange(1, 10).view(3, 3)
6066 self.assertEqual(reversed(val), torch.tensor([[7, 8, 9], [4, 5, 6], [1, 2, 3]]))
6067
6068 val = torch.tensor(42)
6069 self.assertEqual(reversed(val), torch.tensor(42))
6070
6071 def test_contains(self):
6072 x = torch.arange(0, 10)
6073 self.assertEqual(4 in x, True)
6074 self.assertEqual(12 in x, False)
6075
6076 x = torch.arange(1, 10).view(3, 3)
6077 val = torch.arange(1, 4)
6078 self.assertEqual(val in x, True)
6079 val += 10
6080 self.assertEqual(val in x, False)
6081
6082 self.assertRaisesRegex(
6083 RuntimeError,
6084 "Tensor.__contains__ only supports Tensor or scalar, but you passed in a {}.".format(type("foo")),
6085 lambda: "foo" in x)
6086 self.assertRaisesRegex(
6087 RuntimeError,
6088 "Tensor.__contains__ only supports Tensor or scalar, but you passed in a {}.".format(type([1, 2])),
6089 lambda: [1, 2] in x)
6090
6091 def test_deepcopy_parameter(self):
6092 from copy import deepcopy
6093 l = torch.nn.Linear(10, 1)
6094 s = l.state_dict(keep_vars=True)
6095 self.assertEqual(torch.nn.Parameter, type(s['weight']))
6096 self.assertEqual(torch.nn.Parameter, type(s['bias']))
6097
6098 s2 = deepcopy(s)
6099 self.assertEqual(torch.nn.Parameter, type(s2['weight']))
6100 self.assertEqual(torch.nn.Parameter, type(s2['bias']))
6101
6102 def test_pickle(self):
6103 import pickle
6104 a = torch.randn(5, 5)
6105 serialized = pickle.dumps(a)
6106 b = pickle.loads(serialized)
6107 self.assertEqual(a, b)
6108
Animesh Jain1d90d6e2022-07-07 18:57:31 +00006109 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Mike Ruberrye0d829a2022-01-24 01:28:07 -08006110 def test_pickle_parameter(self):
6111 import pickle
6112 a = torch.nn.Parameter(torch.randn(5, 5))
6113 serialized = pickle.dumps(a)
6114 b = pickle.loads(serialized)
6115 self.assertTrue(isinstance(b, torch.nn.Parameter))
6116 self.assertEqual(a.requires_grad, b.requires_grad)
6117 self.assertEqual(a, b)
6118
Animesh Jain1d90d6e2022-07-07 18:57:31 +00006119 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Mike Ruberrye0d829a2022-01-24 01:28:07 -08006120 def test_pickle_parameter_no_requires_grad(self):
6121 import pickle
6122 a = torch.nn.Parameter(torch.randn(5, 5), requires_grad=False)
6123 serialized = pickle.dumps(a)
6124 b = pickle.loads(serialized)
6125 self.assertTrue(isinstance(b, torch.nn.Parameter))
6126 self.assertEqual(a.requires_grad, b.requires_grad)
6127 self.assertEqual(a, b)
6128
6129 def test_pickle_dtype(self):
6130 t = torch.float32
6131 serialized = pickle.dumps(t)
6132 b = pickle.loads(serialized)
6133 self.assertTrue(isinstance(b, torch.dtype))
6134 self.assertEqual(id(b), id(t))
6135
6136 def test_pickle_size(self):
6137 a = torch.rand(10).size()
6138 serialized = pickle.dumps(a)
6139 b = pickle.loads(serialized)
6140 self.assertTrue(isinstance(b, torch.Size))
6141 self.assertEqual(a, b)
6142
6143 def test_pickle_function(self):
6144 # https://github.com/pytorch/pytorch/issues/37703
6145 a = torch.tanh
6146 serialized = pickle.dumps(a)
6147 b = pickle.loads(serialized)
6148 self.assertEqual(a, b)
6149
6150 def test_generator_cpu(self):
6151 # test default generators are equal
6152 self.assertEqual(torch.default_generator, torch.default_generator)
6153
6154 # tests Generator API
6155 # manual_seed, seed, initial_seed, get_state, set_state
6156 g1 = torch.Generator()
6157 g2 = torch.Generator()
6158 g1.manual_seed(12345)
6159 g2.manual_seed(12345)
6160 self.assertEqual(g1.initial_seed(), g2.initial_seed())
6161
6162 g1.seed()
6163 g2.seed()
6164 self.assertNotEqual(g1.initial_seed(), g2.initial_seed())
6165
6166 g1 = torch.Generator()
6167 g2_state = g2.get_state()
6168 g2_randn = torch.randn(1, generator=g2)
6169 g1.set_state(g2_state)
6170 g1_randn = torch.randn(1, generator=g1)
6171 self.assertEqual(g1_randn, g2_randn)
6172
6173 default_state = torch.default_generator.get_state()
6174 q = torch.empty(100)
6175 g1_normal = q.normal_()
6176 g2 = torch.Generator()
6177 g2.set_state(default_state)
6178 g2_normal = q.normal_(generator=g2)
6179 self.assertEqual(g1_normal, g2_normal)
6180
6181 def test_invalid_generator_raises(self):
6182 self.assertRaises(RuntimeError, lambda: torch.Generator('opengl'))
6183
6184 def _sobol_reference_samples(self, scramble: bool) -> torch.Tensor:
6185 if not scramble:
6186 # theoretical values from Joe Kuo 2010
6187 return torch.tensor(
6188 [
6189 [0., 0.],
6190 [0.5, 0.5],
6191 [0.75, 0.25],
6192 [0.25, 0.75],
6193 [0.375, 0.375],
6194 [0.875, 0.875],
6195 [0.625, 0.125],
6196 [0.125, 0.625],
6197 ],
6198 )
6199 else:
6200 # theoretical values unknown: convergence properties checked
6201 return torch.tensor(
6202 [
6203 [0.50860737, 0.29320504],
6204 [0.07116939, 0.89594537],
6205 [0.49354145, 0.11524881],
6206 [0.93097717, 0.70244044],
6207 [0.87266153, 0.23887917],
6208 [0.31021884, 0.57600391],
6209 [0.13687253, 0.42054182],
6210 [0.69931293, 0.77336788],
6211 ],
6212 )
6213
6214 def test_sobolengine_bounds(self, scramble: bool = False):
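        # Every Sobol sample must lie inside the unit hypercube [0, 1]^d.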
6215 engine = torch.quasirandom.SobolEngine(100, scramble=scramble, seed=123456)
6216 sample = engine.draw(512)
6217 self.assertTrue(torch.all(sample >= 0))
6218 self.assertTrue(torch.all(sample <= 1))
6219
6220 def test_sobolengine_bounds_scrambled(self):
6221 self.test_sobolengine_bounds(scramble=True)
6222
6223 def test_sobolengine_draw(self, scramble: bool = False):
6224 ref_sample = self._sobol_reference_samples(scramble=scramble)
6225 engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
6226 sample = engine.draw(n=len(ref_sample))
6227 self.assertEqual(sample, ref_sample)
6228 self.assertEqual(engine.num_generated, len(ref_sample))
6229
6230 def test_sobolengine_draw_scrambled(self):
6231 self.test_sobolengine_draw(scramble=True)
6232
6233 def test_sobolengine_first_point(self):
6234 for dtype in (torch.float, torch.double):
6235 engine = torch.quasirandom.SobolEngine(2, scramble=False)
6236 sample = engine.draw(1, dtype=dtype)
6237 self.assertTrue(torch.all(sample == 0))
6238 self.assertEqual(sample.dtype, dtype)
6239 for dtype in (torch.float, torch.double):
6240 engine = torch.quasirandom.SobolEngine(2, scramble=True, seed=123456)
6241 sample = engine.draw(1, dtype=dtype)
6242 self.assertTrue(torch.all(sample != 0))
6243 self.assertEqual(sample.dtype, dtype)
6244
6245 def test_sobolengine_continuing(self, scramble: bool = False):
6246 ref_sample = self._sobol_reference_samples(scramble=scramble)
6247 engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
6248 n_half = len(ref_sample) // 2
6249 _ = engine.draw(n=n_half)
6250 sample = engine.draw(n=n_half)
6251 torch.testing.assert_close(sample, ref_sample[n_half:])
6252
6253 def test_sobolengine_continuing_scrambled(self):
6254 self.test_sobolengine_continuing(scramble=True)
6255
6256 def test_sobolengine_reset(self, scramble: bool = False):
6257 ref_sample = self._sobol_reference_samples(scramble=scramble)
6258 engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
6259 _ = engine.draw(n=len(ref_sample) // 2)
6260 engine.reset()
6261 self.assertEqual(engine.num_generated, 0)
6262 sample = engine.draw(n=len(ref_sample))
6263 torch.testing.assert_close(sample, ref_sample)
6264
6265 def test_sobolengine_reset_scrambled(self):
6266 self.test_sobolengine_reset(scramble=True)
6267
6268 def test_sobolengine_fast_forward(self, scramble: bool = False):
6269 ref_sample = self._sobol_reference_samples(scramble=scramble)
6270 engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
6271 engine.fast_forward(4)
6272 sample = engine.draw(n=4)
6273 torch.testing.assert_close(sample, ref_sample[4:])
6274 # alternate fast forwarding with sampling
6275 engine.reset()
6276 even_draws = []
6277 for i in range(8):
6278 if i % 2 == 0:
6279 even_draws.append(engine.draw())
6280 else:
6281 engine.fast_forward(1)
6282 torch.testing.assert_close(
6283 ref_sample[[i for i in range(8) if i % 2 == 0]],
6284 torch.from_numpy(np.concatenate(even_draws)),
6285 )
6286
6287 def test_sobolengine_fast_forward_scrambled(self):
6288 self.test_sobolengine_fast_forward(scramble=True)
6289
6290 def test_sobolengine_distribution(self, scramble=False):
6291 d = 50
6292 engine = torch.quasirandom.SobolEngine(d, scramble=scramble, seed=123456)
6293 sample = engine.draw(1024)
6294 torch.testing.assert_close(
6295 torch.mean(sample, dim=0), torch.full((d,), 0.5), atol=2, rtol=2
6296 )
6297 torch.testing.assert_close(
6298 np.percentile(sample, 25, axis=0), np.repeat(0.25, d), atol=2, rtol=2
6299 )
6300 torch.testing.assert_close(
6301 np.percentile(sample, 75, axis=0), np.repeat(0.75, d), atol=2, rtol=2
6302 )
6303
6304 def test_sobolengine_distribution_scrambled(self):
6305 self.test_sobolengine_distribution(scramble=True)
6306
6307 def test_sobolengine_draw_base2(self, scramble=False):
6308 ref_sample = self._sobol_reference_samples(scramble=scramble)
6309 engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
6310 sample = engine.draw_base2(2)
6311 self.assertEqual(ref_sample[:4], sample)
6312        # a second draw_base2 call still returns N = 2**n samples, continuing the sequence
6313 sample = engine.draw_base2(2)
6314 self.assertEqual(ref_sample[4:8], sample)
6315
6316 def test_sobolengine_draw_base2_scrambled(self):
6317 self.test_sobolengine_draw_base2(scramble=True)
6318
6319 def test_sobolengine_raise(self):
6320 maxdim = torch.quasirandom.SobolEngine.MAXDIM
6321 with self.assertRaises(ValueError):
6322 torch.quasirandom.SobolEngine(maxdim + 1)
6323
6324 def test_sobolengine_high_dim(self):
6325 engine = torch.quasirandom.SobolEngine(1111, scramble=False, seed=123456)
6326 samples1 = engine.draw()
6327 vals1, counts1 = torch.unique(samples1, return_counts=True)
6328 samples2 = engine.draw()
6329 vals2, counts2 = torch.unique(samples2, return_counts=True)
6330 self.assertEqual(vals1.item(), 0.0)
6331 self.assertEqual(counts1.item(), 1111)
6332 self.assertEqual(vals2.item(), 0.5)
6333        self.assertEqual(counts2.item(), 1111)
6334
6335 def test_parsing_int64(self):
6336 # accepts integer arguments
6337 x = torch.cumsum(torch.ones(5, 5), 0)
6338 self.assertEqual(x, torch.cumsum(torch.ones(5, 5), torch.tensor(0)))
6339 # doesn't accept floating point variables
6340 self.assertRaises(TypeError, lambda: torch.cumsum(torch.ones(5, 5), torch.tensor(0.)))
6341
6342 def test_parsing_double(self):
6343 # accepts floating point and integer arguments
6344 x = torch.randn(2, 3)
6345 torch.isclose(x, x, 1, 1)
6346 self.assertTrue(torch.isclose(x, x, 1, 1).all())
6347 self.assertTrue(torch.isclose(x, x, 1.5, 1.).all())
6348 # accepts floating point and integer tensors
6349 self.assertTrue(torch.isclose(x, x, torch.tensor(1), torch.tensor(1)).all())
6350 self.assertTrue(torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1.)).all())
6351 # doesn't accept variables with requires_grad
6352 self.assertRaises(TypeError,
6353 lambda: torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1., requires_grad=True)).all())
6354
6355 def test_parsing_intlist(self):
6356 # parse with integer variables
6357 self.assertEqual(torch.Size([3, 4]), torch.ones((torch.tensor(3), torch.tensor(4))).shape)
6358 self.assertEqual(torch.Size([3, 4]), torch.ones(torch.tensor(3), torch.tensor(4)).shape)
6359 # parse with numpy integers
6360 self.assertEqual(torch.Size([3, 4]), torch.ones((np.array(3), np.int64(4))).shape)
6361 self.assertEqual(torch.Size([3, 4]), torch.ones(np.array(3), np.int64(4)).shape)
6362 self.assertEqual(torch.Size([3, 4]), torch.ones((np.int64(3), np.array(4))).shape)
6363 self.assertEqual(torch.Size([3, 4]), torch.ones(np.int64(3), np.array(4)).shape)
6364
6365 # fail parse with float variables
6366 self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3.), torch.tensor(4))))
6367 # fail parse with numpy floats
6368 self.assertRaises(TypeError, lambda: torch.ones((np.float(3.), torch.tensor(4))))
6369 self.assertRaises(TypeError, lambda: torch.ones((np.array(3.), torch.tensor(4))))
6370
6371 # fail parse with > 1 element variables
6372 self.assertRaises(TypeError, lambda: torch.ones(torch.tensor(3, 3)))
6373 self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3, 3))))
6374 self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3)))
6375 self.assertRaises(TypeError, lambda: torch.ones((np.array(3, 3))))
6376
6377 # fail parse with additional positional args after intlist arg
6378 self.assertRaisesRegex(TypeError,
6379 "received an invalid combination of arguments",
6380 lambda: torch.LongTensor((6, 0), 1, 1, 0))
6381 self.assertRaisesRegex(TypeError,
6382 "missing 1 required positional arguments",
6383 lambda: torch.tensor().new_zeros((5, 5), 0))
6384
6385 def test_from_buffer(self):
6386 a = bytearray([1, 2, 3, 4])
6387 self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])
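        # big-endian int16 pairs: bytes (0x01, 0x02) -> 258 and (0x03, 0x04) -> 772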
6388 shorts = torch.ShortStorage.from_buffer(a, 'big')
6389 self.assertEqual(shorts.size(), 2)
6390 self.assertEqual(shorts.tolist(), [258, 772])
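        # little-endian int32 from bytes [1, 2, 3, 4]: 0x04030201 == 67305985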
6391 ints = torch.IntStorage.from_buffer(a, 'little')
6392 self.assertEqual(ints.size(), 1)
6393 self.assertEqual(ints[0], 67305985)
6394 f = bytearray([0x40, 0x10, 0x00, 0x00])
6395 floats = torch.FloatStorage.from_buffer(f, 'big')
6396 self.assertEqual(floats.size(), 1)
6397 self.assertEqual(floats[0], 2.25)
6398
6399 f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
6400 bools = torch.BoolStorage.from_buffer(f, 'big')
6401 self.assertEqual(bools.size(), 8)
6402 self.assertEqual(bools.tolist(), [False, True, True, True, True, True, True, True])
6403 self.assertEqual(bools.type(), 'torch.BoolStorage')
Kurt Mohler79ddc722022-03-22 16:35:42 -07006404 self.assertTrue(isinstance(bools, torch.BoolStorage))
Mike Ruberrye0d829a2022-01-24 01:28:07 -08006405
6406 f = bytearray(b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9')
6407 bools = torch.BoolStorage.from_buffer(f, 'big')
6408 self.assertEqual(bools.size(), 19)
6409
6410 f = bytearray(b'\0x4A')
6411 bools = torch.BoolStorage.from_buffer(f, 'big')
6412 self.assertEqual(bools.size(), 4)
6413 self.assertEqual(bools.tolist(), [False, True, True, True])
6414 bytes = torch.ByteStorage.from_buffer(a)
6415 self.assertEqual(bytes.nbytes(), 4)
6416 self.assertEqual(bytes.tolist(), [1, 2, 3, 4])
Kurt Mohler79ddc722022-03-22 16:35:42 -07006417 self.assertTrue(isinstance(bytes, torch.ByteStorage))
6418
6419 def test_storage_error(self):
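        # Exercises the error paths of the legacy <type>Storage constructors and
        # of the _TypedStorage constructor.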
6420 quantized_storages = [
6421 torch.QInt32Storage,
6422 torch.QInt8Storage,
6423 torch.QUInt2x4Storage,
6424 torch.QUInt4x2Storage,
6425 torch.QUInt8Storage,
6426 ]
6427
6428 with self.assertRaisesRegex(RuntimeError, r"Only child classes of _LegacyStorage can be instantiated"):
6429 torch.storage._LegacyStorage()
6430
6431 for storage_class in torch._storage_classes:
Kurt Mohleraea6e2c2022-05-19 13:54:37 +00006432 if storage_class in [torch._UntypedStorage, torch._TypedStorage]:
Kurt Mohler79ddc722022-03-22 16:35:42 -07006433 continue
6434
6435 device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu'
6436 dtype = storage_class.dtype
6437
6438 if device == 'cuda' and not torch.cuda.is_available():
6439 continue
6440
6441 # Legacy <type>Storage constructor errors
6442 with self.assertRaisesRegex(RuntimeError, r"'device' cannot be specified"):
6443 storage_class(device='cpu')
6444
6445 with self.assertRaisesRegex(RuntimeError, r"'dtype' cannot be specified"):
6446 storage_class(dtype=torch.float)
6447
6448 with self.assertRaisesRegex(TypeError, r"got an unexpected keyword"):
6449 storage_class(sdlkjf=torch.float)
6450
6451 with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
6452 storage_class(0, 0)
6453
6454 with self.assertRaisesRegex(TypeError, r"invalid data type"):
6455 storage_class('string')
6456
6457 with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
6458 storage_class(torch.tensor([]))
6459
6460 s = storage_class()
6461
6462 with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
6463 storage_class(0, wrap_storage=s._untyped())
6464
6465 with self.assertRaisesRegex(TypeError, r"must be _UntypedStorage"):
6466 storage_class(wrap_storage=s)
6467
6468 if torch.cuda.is_available():
6469 if storage_class in quantized_storages:
6470 with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
6471 s.cuda()
6472
6473 else:
6474
6475 if s.is_cuda:
6476 s_other_device = s.cpu()
6477 else:
6478 s_other_device = s.cuda()
6479
6480 with self.assertRaisesRegex(RuntimeError, r"Device of 'wrap_storage' must be"):
6481 storage_class(wrap_storage=s_other_device._untyped())
6482
6483 # _TypedStorage constructor errors
6484 with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
6485 torch._TypedStorage(0, wrap_storage=s._untyped(), dtype=dtype)
6486
6487 with self.assertRaisesRegex(RuntimeError, r"Argument 'dtype' must be specified"):
6488 torch._TypedStorage(wrap_storage=s._untyped())
6489
6490 with self.assertRaisesRegex(TypeError, r"Argument 'dtype' must be torch.dtype"):
6491 torch._TypedStorage(wrap_storage=s._untyped(), dtype=0)
6492
6493 with self.assertRaisesRegex(RuntimeError, r"Argument 'device' should not be specified"):
6494 torch._TypedStorage(wrap_storage=s._untyped(), dtype=dtype, device=device)
6495
6496 with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be _UntypedStorage"):
6497 torch._TypedStorage(wrap_storage=s, dtype=dtype)
6498
6499 with self.assertRaisesRegex(RuntimeError, r"Storage device not recognized"):
6500 torch._TypedStorage(dtype=dtype, device='xla')
6501
6502 if torch.cuda.is_available():
6503 if storage_class in quantized_storages:
6504 with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
6505 torch._TypedStorage(dtype=dtype, device='cuda')
6506
6507 with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
6508 torch._TypedStorage(torch.tensor([]), dtype=dtype, device=device)
6509
6510 with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
6511 torch._TypedStorage(0, 0, dtype=dtype, device=device)
6512
Kurt Mohlere9afb432022-05-28 15:33:45 +00006513 if isinstance(s, torch._TypedStorage):
6514 s_other = torch._TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
6515
6516 with self.assertRaisesRegex(RuntimeError, r'cannot set item'):
6517 s.fill_(s_other)
6518
Kurt Mohler79ddc722022-03-22 16:35:42 -07006519 def test_storage_error_no_attribute(self):
6520 storage_classes = [
6521 torch.cuda.ByteStorage,
6522 torch.cuda.FloatStorage,
Kurt Mohler79ddc722022-03-22 16:35:42 -07006523 ]
6524 for storage_class in storage_classes:
6525 with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
6526 storage_class.from_buffer()
6527
Kurt Mohleraea6e2c2022-05-19 13:54:37 +00006528 with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
6529 storage_class._new_with_weak_ptr()
Kurt Mohler79ddc722022-03-22 16:35:42 -07006530
6531 with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
Kurt Mohlercecb2ad2022-05-20 02:03:34 +00006532 storage_class._new_shared_filename(0, 0, 0)
Mike Ruberrye0d829a2022-01-24 01:28:07 -08006533
6534 def test_storage_casts(self):
6535 storage = torch.IntStorage([-1, 0, 1, 2, 3, 4])
6536 self.assertEqual(storage.size(), 6)
6537 self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4])
6538 self.assertEqual(storage.type(), 'torch.IntStorage')
6539 self.assertIs(storage.dtype, torch.int32)
6540
6541 floatStorage = storage.float()
6542 self.assertEqual(floatStorage.size(), 6)
6543 self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4])
6544 self.assertEqual(floatStorage.type(), 'torch.FloatStorage')
6545 self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
6546 self.assertIs(floatStorage.dtype, torch.float32)
6547
6548 halfStorage = storage.half()
6549 self.assertEqual(halfStorage.size(), 6)
6550 self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4])
6551 self.assertEqual(halfStorage.type(), 'torch.HalfStorage')
6552 self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
6553 self.assertIs(halfStorage.dtype, torch.float16)
6554
6555 bfloat16Storage = storage.bfloat16()
6556 self.assertEqual(bfloat16Storage.size(), 6)
6557 self.assertEqual(bfloat16Storage.tolist(), [-1, 0, 1, 2, 3, 4])
6558 self.assertEqual(bfloat16Storage.type(), 'torch.BFloat16Storage')
6559 self.assertEqual(bfloat16Storage.int().tolist(), [-1, 0, 1, 2, 3, 4])
6560 self.assertIs(bfloat16Storage.dtype, torch.bfloat16)
6561
6562 longStorage = storage.long()
6563 self.assertEqual(longStorage.size(), 6)
6564 self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
6565 self.assertEqual(longStorage.type(), 'torch.LongStorage')
6566 self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
6567 self.assertIs(longStorage.dtype, torch.int64)
6568
6569 shortStorage = storage.short()
6570 self.assertEqual(shortStorage.size(), 6)
6571 self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4])
6572 self.assertEqual(shortStorage.type(), 'torch.ShortStorage')
6573 self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
6574 self.assertIs(shortStorage.dtype, torch.int16)
6575
6576 doubleStorage = storage.double()
6577 self.assertEqual(doubleStorage.size(), 6)
6578 self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
6579 self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage')
6580 self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
6581 self.assertIs(doubleStorage.dtype, torch.float64)
6582
6583 charStorage = storage.char()
6584 self.assertEqual(charStorage.size(), 6)
6585 self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
6586 self.assertEqual(charStorage.type(), 'torch.CharStorage')
6587 self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
6588 self.assertIs(charStorage.dtype, torch.int8)
6589
6590 byteStorage = storage.byte()
6591 self.assertEqual(byteStorage.size(), 6)
6592 self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4])
6593 self.assertEqual(byteStorage.type(), 'torch.ByteStorage')
6594 self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4])
6595 self.assertIs(byteStorage.dtype, torch.uint8)
6596
6597 boolStorage = storage.bool()
6598 self.assertEqual(boolStorage.size(), 6)
6599 self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True])
6600 self.assertEqual(boolStorage.type(), 'torch.BoolStorage')
6601 self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1])
6602 self.assertIs(boolStorage.dtype, torch.bool)
6603
6604 complexfloat_storage = torch.ComplexFloatStorage([-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j])
6605 self.assertEqual(complexfloat_storage.size(), 6)
6606 self.assertEqual(complexfloat_storage.tolist(), [-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j])
6607 self.assertEqual(complexfloat_storage.type(), 'torch.ComplexFloatStorage')
6608 self.assertIs(complexfloat_storage.dtype, torch.complex64)
6609
6610 complexdouble_storage = complexfloat_storage.complex_double()
6611 self.assertEqual(complexdouble_storage.size(), 6)
6612 self.assertEqual(complexdouble_storage.tolist(), [-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j])
6613 self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
6614 self.assertIs(complexdouble_storage.dtype, torch.complex128)
6615
6616 def test_from_file(self):
6617 def assert_with_filename(filename):
6618 size = 10000
6619 s1 = torch.FloatStorage.from_file(filename, True, size)
6620 t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
6621 self.assertEqual(s1.data_ptr(), torch.FloatTensor(s1).data_ptr())
6622
6623 # check mapping
6624 s2 = torch.FloatStorage.from_file(filename, True, size)
6625 t2 = torch.FloatTensor(s2)
6626 self.assertEqual(t1, t2, atol=0, rtol=0)
6627
6628 # check changes to t1 from t2
6629 rnum = random.uniform(-1, 1)
6630 t1.fill_(rnum)
6631 self.assertEqual(t1, t2, atol=0, rtol=0)
6632
6633 # check changes to t2 from t1
6634 rnum = random.uniform(-1, 1)
6635 t2.fill_(rnum)
6636 self.assertEqual(t1, t2, atol=0, rtol=0)
6637
6638 # release the tensors
6639 del s1, t1, s2, t2
6640
6641 with TemporaryFileName() as fname:
6642 assert_with_filename(fname)
6643
6644 if IS_FILESYSTEM_UTF8_ENCODING:
6645 with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname:
6646 assert_with_filename(fname)
6647
6648 def test_torch_from_file(self):
6649 def assert_with_filename(filename):
6650 size = 10000
6651 s1 = torch.from_file(filename, True, size, dtype=torch.float)
6652 t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
6653
6654 # check mapping
6655 s2 = torch.from_file(filename, True, size, dtype=torch.float)
6656 t2 = torch.FloatTensor(s2)
6657 self.assertEqual(t1, t2, atol=0, rtol=0)
6658
6659 # check changes to t1 from t2
6660 rnum = random.uniform(-1, 1)
6661 t1.fill_(rnum)
6662 self.assertEqual(t1, t2, atol=0, rtol=0)
6663
6664 # check changes to t2 from t1
6665 rnum = random.uniform(-1, 1)
6666 t2.fill_(rnum)
6667 self.assertEqual(t1, t2, atol=0, rtol=0)
6668
6669 # release the tensors
6670 del s1, t1, s2, t2
6671
6672 with TemporaryFileName() as fname:
6673 assert_with_filename(fname)
6674
6675 if IS_FILESYSTEM_UTF8_ENCODING:
6676 with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname:
6677 assert_with_filename(fname)
6678
6679 def test_print(self):
6680 default_type = torch.tensor([]).type()
6681 for t in torch._tensor_classes:
6682 if t == torch.HalfTensor:
6683 continue # HalfTensor does not support fill
6684 if t.is_sparse:
6685 continue
6686 if t.is_cuda and not torch.cuda.is_available():
6687 continue
6688 obj = t(100, 100).fill_(1)
6689 obj.__repr__()
6690 str(obj)
6691 # test half tensor
6692 obj = torch.rand(100, 100, device='cpu').half()
6693 obj.__repr__()
6694 str(obj)
6695 for t in torch._storage_classes:
6696 if t == torch.BFloat16Storage:
6697 continue # Fix once fill is enabled for bfloat16
6698 if t.is_cuda and not torch.cuda.is_available():
6699 continue
6700 if t == torch.BoolStorage or t == torch.cuda.BoolStorage:
6701 obj = t(100).fill_(True)
6702 else:
6703 obj = t(100).fill_(1)
6704 obj.__repr__()
6705 str(obj)
6706
6707 # test complex tensor
6708 # complex tensor printing uses two formatters, one for the real values
6709 # and the other for the imaginary values. This is consistent with NumPy.
6710 x = torch.tensor([2.3 + 4j, 7 + 6j])
6711 self.assertEqual(x.__repr__(), str(x))
6712 self.assertExpectedInline(str(x), '''tensor([2.3000+4.j, 7.0000+6.j])''')
6713
kshitij12345e36d25f2022-05-01 12:46:09 +00006714 # test complex half tensor
6715 x = torch.tensor([1.25 + 4j, -7. + 6j], dtype=torch.chalf)
6716 self.assertEqual(x.__repr__(), str(x))
6717 self.assertExpectedInline(str(x), '''tensor([ 1.2500+4.j, -7.0000+6.j], dtype=torch.complex32)''')
6718
Mike Ruberrye0d829a2022-01-24 01:28:07 -08006719 # test scientific notation for complex tensors
6720 x = torch.tensor([1e28 + 2j , -1e-28j])
6721 self.assertEqual(x.__repr__(), str(x))
6722 self.assertExpectedInline(str(x), '''tensor([1.0000e+28+2.0000e+00j, -0.0000e+00-1.0000e-28j])''')
6723
6724 # test big integer
6725 x = torch.tensor(2341234123412341)
6726 self.assertEqual(x.__repr__(), str(x))
6727 self.assertExpectedInline(str(x), '''tensor(2341234123412341)''')
6728
6729 # test scientific notation
6730 x = torch.tensor([1e28, 1e-28])
6731 self.assertEqual(x.__repr__(), str(x))
6732 self.assertExpectedInline(str(x), '''tensor([1.0000e+28, 1.0000e-28])''')
6733
6734 # test scientific notation using set_printoptions
6735 x = torch.tensor([1e2, 1e-2])
6736 torch.set_printoptions(sci_mode=True)
6737 self.assertEqual(x.__repr__(), str(x))
6738 self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''')
6739 torch.set_printoptions(sci_mode=False)
6740 self.assertEqual(x.__repr__(), str(x))
6741 self.assertExpectedInline(str(x), '''tensor([ 100.0000, 0.0100])''')
6742 torch.set_printoptions(sci_mode=None) # reset to the default value
6743
6744 # test no leading space if all elements positive
6745 x = torch.tensor([1, 2])
6746 self.assertEqual(x.__repr__(), str(x))
6747 self.assertExpectedInline(str(x), '''tensor([1, 2])''')
6748
6749 # test for leading space if there are negative elements
6750 x = torch.tensor([1, -2])
6751 self.assertEqual(x.__repr__(), str(x))
6752 self.assertExpectedInline(str(x), '''tensor([ 1, -2])''')
6753
6754 # test inf and nan
6755 x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1])
6756 self.assertEqual(x.__repr__(), str(x))
6757 self.assertExpectedInline(str(x), '''tensor([4.0000, inf, 1.5000, -inf, 0.0000, nan, 1.0000])''')
6758
6759 y = torch.tensor([4, inf, complex(1.5, inf), complex(-inf, 4), 0, complex(nan, inf), complex(3, nan)])
6760 self.assertEqual(y.__repr__(), str(y))
6761 expected_str = '''\
6762tensor([4.0000+0.j, inf+0.j, 1.5000+infj, -inf+4.j, 0.0000+0.j, nan+infj,
6763 3.0000+nanj])'''
6764 self.assertExpectedInline(str(y), expected_str)
6765
6766 # test dtype
6767 torch.set_default_dtype(torch.float)
6768 x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64)
6769 self.assertEqual(x.__repr__(), str(x))
6770 expected_str = '''\
6771tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
6772 inf], dtype=torch.float64)'''
6773 self.assertExpectedInline(str(x), expected_str)
6774
6775 # test changing default dtype
6776 torch.set_default_dtype(torch.float64)
6777 self.assertEqual(x.__repr__(), str(x))
6778 expected_str = '''\
6779tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
6780 inf])'''
6781 self.assertExpectedInline(str(x), expected_str)
6782
6783 # test summary
6784 x = torch.zeros(10000)
6785 self.assertEqual(x.__repr__(), str(x))
6786 self.assertExpectedInline(str(x), '''tensor([0., 0., 0., ..., 0., 0., 0.])''')
6787
6788 # test internal summary function
6789 x = torch.rand(1, 20, 5, 30)
6790 summary = torch._tensor_str.get_summarized_data(x)
6791 self.assertEqual(summary.shape, (1, 6, 5, 6))
6792 first_and_last = [0, 1, 2, -3, -2, -1]
6793 self.assertEqual(summary, x[:, first_and_last][..., first_and_last])
6794
6795 # test device
6796 if torch.cuda.is_available():
6797 x = torch.tensor([123], device='cuda:0')
6798 self.assertEqual(x.__repr__(), str(x))
6799 self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''')
6800
6801 # test changing default to cuda
6802 torch.set_default_tensor_type(torch.cuda.FloatTensor)
6803 self.assertEqual(x.__repr__(), str(x))
6804 self.assertExpectedInline(str(x), '''tensor([123])''')
6805
6806 # test printing a tensor on a different gpu than current one.
6807 if torch.cuda.device_count() >= 2:
6808 with torch.cuda.device(1):
6809 self.assertEqual(x.__repr__(), str(x))
6810 self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''')
6811
6812 # test printing cpu tensor when default device is cuda
6813 y = torch.tensor([123], device='cpu')
6814 self.assertEqual(y.__repr__(), str(y))
6815 self.assertExpectedInline(str(y), '''tensor([123], device='cpu')''')
6816 torch.set_default_tensor_type(default_type)
6817
6818
6819 # test integral floats and requires_grad
6820 x = torch.tensor([123.], requires_grad=True)
6821 self.assertEqual(x.__repr__(), str(x))
6822 self.assertExpectedInline(str(x), '''tensor([123.], requires_grad=True)''')
6823
6824 # test non-contiguous print
6825 # sliced tensor should have > PRINT_OPTS.threshold elements
6826 x = torch.ones(100, 2, 2, 10)
6827 y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1))
6828 self.assertEqual(str(y), y.__repr__())
6829 expected_str = '''\
6830tensor([[[1., 1., 1., ..., 1., 1., 1.],
6831 [1., 1., 1., ..., 1., 1., 1.]],
6832
6833 [[1., 1., 1., ..., 1., 1., 1.],
6834 [1., 1., 1., ..., 1., 1., 1.]],
6835
6836 [[1., 1., 1., ..., 1., 1., 1.],
6837 [1., 1., 1., ..., 1., 1., 1.]],
6838
6839 ...,
6840
6841 [[1., 1., 1., ..., 1., 1., 1.],
6842 [1., 1., 1., ..., 1., 1., 1.]],
6843
6844 [[1., 1., 1., ..., 1., 1., 1.],
6845 [1., 1., 1., ..., 1., 1., 1.]],
6846
6847 [[1., 1., 1., ..., 1., 1., 1.],
6848 [1., 1., 1., ..., 1., 1., 1.]]])\
6849'''
6850
6851 self.assertExpectedInline(str(y), expected_str)
6852
6853 x = torch.ones(100, 2, 2, 10) * (1 + 1j)
6854 y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1))
6855 self.assertEqual(str(y), y.__repr__())
6856 expected_str = '''\
6857tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
6858 [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]],
6859
6860 [[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
6861 [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]],
6862
6863 [[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
6864 [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]],
6865
6866 ...,
6867
6868 [[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
6869 [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]],
6870
6871 [[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
6872 [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]],
6873
6874 [[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
6875 [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]]])\
6876'''
6877 self.assertExpectedInline(str(y), expected_str)
6878
6879 # test printing a 0-dim tensor; we match NumPy's arrayprint style for it
6880 x = torch.tensor(0.00002)
6881 self.assertEqual(x.__repr__(), str(x))
6882 self.assertExpectedInline(str(x), '''tensor(2.0000e-05)''')
6883
6884 # test print boolean tensor
6885 x = torch.tensor([True])
6886 self.assertEqual(x.__repr__(), str(x))
6887 self.assertExpectedInline(str(x), '''tensor([True])''')
6888
6889 x = torch.tensor(True)
6890 self.assertEqual(x.__repr__(), str(x))
6891 self.assertExpectedInline(str(x), '''tensor(True)''')
6892
6893 # [Numpy] test print float in sci_mode when min < 0.0001.
6894 x = torch.tensor([0.00002])
6895 self.assertEqual(x.__repr__(), str(x))
6896 self.assertExpectedInline(str(x), '''tensor([2.0000e-05])''')
6897
6898 # [Numpy] test print complex in sci_mode when real_min < 0.0001 and (or) imag_min < 0.0001.
6899 x = torch.tensor([0.00002]) * (1 + 1j)
6900 self.assertEqual(x.__repr__(), str(x))
6901 self.assertExpectedInline(str(x), '''tensor([2.0000e-05+2.0000e-05j])''')
6902
6903 # [Numpy] test print float in sci_mode when max > 1e8.
6904 # TODO: Pytorch uses fixed precision to print, while Numpy uses dragon4_scientific
6905 # to do automatic trimming and padding.
6906 x = torch.tensor([123456789.])
6907 self.assertEqual(x.__repr__(), str(x))
6908 self.assertExpectedInline(str(x), '''tensor([1.2346e+08])''')
6909
6910 # [Numpy] test print float in sci_mode when max / min > 1000.
6911 x = torch.tensor([0.01, 11])
6912 self.assertEqual(x.__repr__(), str(x))
6913 self.assertExpectedInline(str(x), '''tensor([1.0000e-02, 1.1000e+01])''')
6914
6915 # [Numpy] test print int max / min > 1000, no sci_mode
6916 x = torch.tensor([1, 1010])
6917 self.assertEqual(x.__repr__(), str(x))
6918 self.assertExpectedInline(str(x), '''tensor([ 1, 1010])''')
6919
6920 # [Numpy] test print int > 1e8, no sci_mode
6921 x = torch.tensor([1000000000]) # 1e9
6922 self.assertEqual(x.__repr__(), str(x))
6923 self.assertExpectedInline(str(x), '''tensor([1000000000])''')
6924
6925 # [Numpy] test printing float in int_mode
6926 x = torch.tensor([1., 1000.])
6927 self.assertEqual(x.__repr__(), str(x))
6928 self.assertExpectedInline(str(x), '''tensor([ 1., 1000.])''')
6929
6930 # [Numpy] test printing float in int_mode in sci format when max / min > 1000.
6931 x = torch.tensor([1., 1010.])
6932 self.assertEqual(x.__repr__(), str(x))
6933 self.assertExpectedInline(str(x), '''tensor([1.0000e+00, 1.0100e+03])''')
6934
6935 def test_sizeof(self) -> None:
6936 sizeof_empty = torch.randn(0).storage().__sizeof__()
6937 sizeof_10 = torch.randn(10).storage().__sizeof__()
6938 sizeof_100 = torch.randn(100).storage().__sizeof__()
6939 self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
6940 self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
6941
6942 sizeof_empty = torch.randn(0).to(torch.uint8).storage().__sizeof__()
6943 sizeof_10 = torch.randn(10).to(torch.uint8).storage().__sizeof__()
6944 sizeof_100 = torch.randn(100).to(torch.uint8).storage().__sizeof__()
6945 self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
6946 self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
6947
6948 def test_iter(self) -> None:
6949 x = torch.randn(5, 5)
6950 for i, sub in enumerate(x):
6951 self.assertEqual(sub, x[i])
6952
6953 x = torch.tensor([])
6954 self.assertEqual(list(x), [])
6955
6956 def test_new(self) -> None:
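# Tensor.new keeps the dtype and device of the source tensor: with no
# arguments it returns an empty tensor, with integer sizes an uninitialized
# tensor of that shape, and with a sequence, storage, or tensor it wraps or
# copies the given data.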
6957 x = torch.autograd.Variable(torch.tensor([]))
6958 y = torch.autograd.Variable(torch.randn(4, 4))
6959 z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
6960 self.assertEqual(x.new().shape, [0])
6961 self.assertEqual(x.new(), x)
6962 self.assertEqual(x.new(1, 2).shape, [1, 2])
6963 self.assertEqual(x.new(torch.Size([3, 4])).shape, [3, 4])
6964 self.assertEqual(x.new([3, 4]).shape, [2])
6965 self.assertEqual(x.new([3, 4]).tolist(), [3, 4])
6966 self.assertEqual(x.new((3, 4)).tolist(), [3, 4])
6967 self.assertEqual(x.new([np.int32(3), np.float64(4)]).tolist(), [3, 4])
6968 self.assertEqual(x.new(np.array((3, 4))).tolist(), [3, 4])
6969 self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4])
6970 self.assertEqual(x.new(size=(3, 4)).shape, [3, 4])
6971 self.assertEqual(x.new(()).shape, [0])
6972 self.assertEqual(x.new(y.storage()).data_ptr(), y.data_ptr())
6973 self.assertEqual(x.new(y).data_ptr(), y.data_ptr())
6974 self.assertIsNot(x.new(y), y)
6975
6976 self.assertRaises(TypeError, lambda: x.new(z))
6977 # TypeError would be better
6978 self.assertRaises(RuntimeError, lambda: x.new(z.storage()))
6979
6980 @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
6981 def test_pin_memory(self):
6982 x = torch.randn(3, 5)
6983 self.assertFalse(x.is_pinned())
6984 if not torch.cuda.is_available():
6985 self.assertRaises(RuntimeError, lambda: x.pin_memory())
6986 else:
6987 pinned = x.pin_memory()
6988 self.assertTrue(pinned.is_pinned())
6989 self.assertEqual(pinned, x)
6990 self.assertNotEqual(pinned.data_ptr(), x.data_ptr())
6991 # test that pin_memory on already pinned tensor has no effect
6992 self.assertIs(pinned, pinned.pin_memory())
6993 self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
6994
6995 def test_error_msg_type_translation(self):
6996 with self.assertRaisesRegex(
6997 RuntimeError,
6998 # message includes both Double and Long
6999 '(?=.*Double)(?=.*Long)'):
7000
7001 # Calls model with a LongTensor input but DoubleTensor weights
7002 input = torch.zeros(1, 1, 1, 6, dtype=torch.long)
7003 weight = torch.nn.Parameter(torch.zeros(1, 1, 1, 3, dtype=torch.double))
7004 model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False)
7005 model.weight = weight
7006 out = model(input)
7007
7008 def test_apply(self):
7009 x = torch.arange(1, 6)
7010 res = x.clone().apply_(lambda k: k + k)
7011 self.assertEqual(res, x * 2)
7012 self.assertRaises(TypeError, lambda: x.apply_(lambda k: "str"))
7013
7014 def test_map(self):
7015 x = torch.autograd.Variable(torch.randn(3, 3))
7016 y = torch.autograd.Variable(torch.randn(3))
7017 res = x.clone()
7018 res.map_(y, lambda a, b: a + b)
7019 self.assertEqual(res, x + y)
7020 self.assertRaisesRegex(TypeError, "not callable", lambda: res.map_(y, "str"))
7021
7022 def test_map2(self):
7023 x = torch.autograd.Variable(torch.randn(3, 3))
7024 y = torch.autograd.Variable(torch.randn(3))
7025 z = torch.autograd.Variable(torch.randn(1, 3))
7026 res = x.clone()
7027 res.map2_(y, z, lambda a, b, c: a + b * c)
7028 self.assertEqual(res, x + y * z)
7029 z.requires_grad = True
7030 self.assertRaisesRegex(
7031 RuntimeError, "requires grad",
7032 lambda: res.map2_(y, z, lambda a, b, c: a + b * c))
7033
7034 def test_Size(self):
7035 x = torch.Size([1, 2, 3])
7036 self.assertIsInstance(x, tuple)
7037 self.assertEqual(x[0], 1)
7038 self.assertEqual(x[1], 2)
7039 self.assertEqual(x[2], 3)
7040 self.assertEqual(len(x), 3)
7041 self.assertRaises(TypeError, lambda: torch.Size(torch.ones(3)))
7042
7043 self.assertIsInstance(x * 2, torch.Size)
7044 self.assertIsInstance(x[:-1], torch.Size)
7045 self.assertIsInstance(x + x, torch.Size)
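
# Since torch.Size subclasses tuple (asserted above), it also compares equal
# to a plain tuple and unpacks like one; a quick illustrative check:
self.assertTrue(x == (1, 2, 3))
d0, d1, d2 = x
self.assertEqual([d0, d1, d2], [1, 2, 3])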
7046
7047 def test_Size_scalar(self):
7048 three = torch.tensor(3)
7049 two = torch.tensor(2)
7050 x = torch.Size([0, 1, two, three, 4])
7051 for i in range(1, 5):
7052 self.assertEqual(x[i], i)
7053
7054 def test_Size_iter(self):
7055 for sizes in [iter([1, 2, 3, 4, 5]), range(1, 6)]:
7056 x = torch.Size(sizes)
7057 for i in range(0, 5):
7058 self.assertEqual(x[i], i + 1)
7059
7060 def test_t_not_2d_error(self):
7061 self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t())
7062 self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t_())
7063
7064 # skip this test for now as it affects all tests
7065 @unittest.skipIf(True, "flush_denormal not supported")
7066 def test_set_flush_denormal(self):
7067 tiny_float = 1e-42
7068 tiny_double = 1e-320
7069 float_tensor = torch.FloatTensor([1.0, tiny_float])
7070 double_tensor = torch.DoubleTensor([1.0, tiny_float, tiny_double])
7071
7072 self.assertEqual(float_tensor[0], 1.0, atol=0.0, rtol=0)
7073 self.assertEqual(float_tensor[1], tiny_float, atol=tiny_float / 16, rtol=0)
7074 self.assertEqual(double_tensor[0], 1.0, atol=0.0, rtol=0)
7075 self.assertEqual(double_tensor[1], tiny_float, atol=0.0, rtol=0)
7076 self.assertEqual(double_tensor[2], tiny_double, atol=0.0, rtol=0)
7077
7078 torch.set_flush_denormal(True)
7079 self.assertEqual(float_tensor[0], 1.0, atol=0.0, rtol=0)
7080 self.assertEqual(float_tensor[1], 0.0, atol=0.0, rtol=0) # tiny_float to zero
7081 self.assertEqual(double_tensor[0], 1.0, atol=0.0, rtol=0)
7082 # tiny_float is not converted to zero in double type
7083 self.assertEqual(double_tensor[1], tiny_float, atol=0.0, rtol=0)
7084 self.assertEqual(double_tensor[2], 0.0, atol=0.0, rtol=0) # tiny_double to zero
7085 torch.set_flush_denormal(False)
7086
7087 def test_show_config(self):
7088 # We can't usefully test the output; just make sure this doesn't crash
7089 torch.__config__.show()
7090
7091 @unittest.skipIf(IS_FBCODE, "CXX_FLAGS is only for OSS build.")
7092 def test_cxx_flags(self):
7093 torch.__config__._cxx_flags()
7094
7095 def test_parallel_info(self):
7096 torch.__config__.parallel_info()
7097
7098 @slowTest
7099 def test_slow_test(self):
7100 # Just a smoketest to make sure our slowTest decorator works.
7101 pass
7102
7103 def test_is_nonzero(self):
7104 with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
7105 torch.tensor([]).is_nonzero()
7106 with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"):
7107 torch.tensor([0, 0]).is_nonzero()
7108 self.assertFalse(torch.tensor(0).is_nonzero())
7109 self.assertTrue(torch.tensor(1).is_nonzero())
7110 self.assertFalse(torch.tensor([0]).is_nonzero())
7111 self.assertTrue(torch.tensor([1]).is_nonzero())
7112 self.assertFalse(torch.tensor([[0]]).is_nonzero())
7113 self.assertTrue(torch.tensor([[1]]).is_nonzero())
7114 self.assertTrue(torch.tensor(0.1).is_nonzero())
7115 self.assertTrue(torch.tensor(-0.1).is_nonzero())
7116 self.assertFalse(torch.tensor(0.0).is_nonzero())
7117 self.assertTrue(torch.tensor(True).is_nonzero())
7118 self.assertFalse(torch.tensor(False).is_nonzero())
7119 self.assertFalse(torch.tensor(0 + 0j).is_nonzero())
7120 self.assertTrue(torch.tensor(0 + 0.1j).is_nonzero())
7121
7122 def test_assert_async(self):
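# Like is_nonzero above, _assert_async requires a tensor with exactly one
# element; instead of returning a bool it raises when that element is zero
# (0, 0.0, False, or 0 + 0j).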
7123 with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
7124 torch._assert_async(torch.tensor([]))
7125 with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"):
7126 torch._assert_async(torch.tensor([0, 0]))
7127 with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
7128 torch._assert_async(torch.tensor(0))
7129 torch._assert_async(torch.tensor(1))
7130 torch._assert_async(torch.tensor(0.1))
7131 torch._assert_async(torch.tensor(-0.1))
7132 with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
7133 torch._assert_async(torch.tensor(0.0))
7134 torch._assert_async(torch.tensor(True))
7135 with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
7136 torch._assert_async(torch.tensor(False))
7137 torch._assert_async(torch.tensor(0 + 0.1j))
7138 with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
7139 torch._assert_async(torch.tensor(0 + 0j))
7140
7141 # NB: we must not be built with CUDA; if we are built with CUDA but no CUDA
7142 # is available, we get a different error.
7143 @unittest.skipIf(torch.backends.cuda.is_built() or IS_SANDCASTLE, "CUDA is built, can't test CUDA not built error")
7144 def test_cuda_not_built(self):
7145 msg = "Torch not compiled with CUDA enabled"
7146 self.assertRaisesRegex(AssertionError, msg, lambda: torch.cuda.current_device())
7147 self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1], device="cuda"))
7148 self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).cuda())
7149 self.assertRaisesRegex(TypeError, msg, lambda: torch.cuda.FloatTensor())
7150 self.assertRaisesRegex(TypeError, msg, lambda: torch.set_default_tensor_type(torch.cuda.FloatTensor))
7151 self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).to(device="cuda"))
7152
7153 def test_has_internal_overlap(self):
7154 OVERLAP_NO = 0
7155 OVERLAP_YES = 1
7156 OVERLAP_TOO_HARD = 2
7157
7158 # Check for contiguous tensors
7159 a = torch.randn(3, 3)
7160 self.assertEqual(torch._debug_has_internal_overlap(a), OVERLAP_NO)
7161
7162 # Checks for zero strides
7163 b = torch.randn(1, 3)
7164 b_expanded = b.expand(4, 3)
7165 self.assertEqual(torch._debug_has_internal_overlap(b_expanded), OVERLAP_YES)
7166
7167 # Check for zero strided, size 1 axis, in non-contiguous storage (gh-33812)
7168 c = torch.randn(10).as_strided([2, 1, 5], [1, 0, 2])
7169 self.assertEqual(torch._debug_has_internal_overlap(c), OVERLAP_NO)
7170 c = torch.randn(2, 1, 10)[::2].as_strided((2, 1, 5), (10, 0, 2))
7171 self.assertEqual(torch._debug_has_internal_overlap(c), OVERLAP_TOO_HARD)
7172
7173 def test_allow_tensor_metadata_change(self):
7174 def do_test(t):
7175 with self.assertRaisesRegex(
7176 RuntimeError,
7177 "set_sizes_contiguous is not allowed on a Tensor created from .data or .detach()"):
7178 t.resize_((2, 1))
7179 with self.assertRaisesRegex(
7180 RuntimeError,
7181 "set_storage is not allowed on a Tensor created from .data or .detach()"):
7182 t.set_()
7183 with self.assertRaisesRegex(
7184 RuntimeError,
7185 "set_storage_offset is not allowed on a Tensor created from .data or .detach()"):
7186 t.set_(t.storage(), 0, t.size(), list(t.stride()))
7187
7188 do_test(torch.tensor([[1, 2]]).data)
7189 do_test(torch.tensor([[1, 2]]).detach())
7190
7191 @skipIfNotRegistered("LayerNorm", "Skipping as LayerNorm is not registered")
7192 def test_c10_layer_norm(self):
7193 # test that we can call c10 ops and they return a reasonable result
7194 X = torch.rand(5, 5, dtype=torch.float)
7195 weight = torch.rand(*X.size()[1:], dtype=torch.float)
7196 bias = torch.rand(*X.size()[1:], dtype=torch.float)
7197 epsilon = 1e-4
7198
7199 expected_norm = torch.nn.functional.layer_norm(
7200 X, X.size()[1:], weight=weight, bias=bias, eps=epsilon)
7201 actual_norm, actual_mean, actual_stdev = \
7202 torch.ops._caffe2.LayerNorm(torch.tensor(X), torch.tensor(
7203 weight), torch.tensor(bias), 1, epsilon, True)
7204 torch.testing.assert_close(expected_norm, actual_norm)
7205
7206 def test_memory_format(self):
7207 def test_helper(x, memory_format):
7208 y = x.contiguous(memory_format=memory_format)
7209 self.assertFalse(y.is_contiguous())
7210 self.assertTrue(y.is_contiguous(memory_format=memory_format))
7211 self.assertEqual(y, x)
7212
7213 test_helper(torch.randn(4, 3, 8, 8), torch.channels_last)
7214 test_helper(torch.randn(4, 3, 8, 8, 8), torch.channels_last_3d)
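
# Sanity check of the stride layout channels_last implies: an NCHW-shaped
# (4, 3, 8, 8) tensor stored NHWC has strides (C * H * W, 1, W * C, C).
y = torch.randn(4, 3, 8, 8).contiguous(memory_format=torch.channels_last)
self.assertEqual(y.stride(), (192, 1, 24, 3))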
7215
7216 def test_memory_format_contiguous_returns_same_tensor_if_already_satisfies(self):
7217 def test_helper(x, memory_format):
7218 alias = x.contiguous(memory_format=memory_format)
7219 alias.fill_(7)
7220 self.assertEqual(x, alias)
7221
7222 test_helper(torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2), torch.channels_last)
7223 test_helper(torch.randn(4, 8, 8, 8, 3).permute(0, 4, 1, 2, 3), torch.channels_last_3d)
7224
7225 def test_memory_format_empty(self):
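# torch.empty only accepts channels_last for 4d sizes (channels_last_3d for
# 5d sizes); requesting it for a lower-dimensional size raises.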
7226 def test_helper(dim1, dim2, memory_format):
7227 with self.assertRaises(RuntimeError):
7228 x = torch.empty(dim1, memory_format=memory_format)
7229 x = torch.empty(dim2, memory_format=memory_format)
7230 self.assertTrue(x.is_contiguous(memory_format=memory_format))
7231
7232 test_helper((3, 3), (3, 3, 3, 3), torch.channels_last)
7233 test_helper((3, 3, 3), (3, 3, 3, 3, 3), torch.channels_last_3d)
7234
7235 def test_subclass_tensors(self):
7236 # raise an error when trying to subclass FloatTensor
7237 with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"):
7238 class Foo1(torch.FloatTensor):
7239 pass
7240
7241 # but allow subclassing Tensor:
7242 class Foo2(torch.Tensor):
7243 def foo(self):
7244 return 5
7245 f = Foo2()
7246 self.assertEqual(f.foo(), 5)
7247
7248 def test_ndim(self):
7249 a = torch.randn(1, 2, 3)
7250 self.assertEqual(3, a.ndim)
7251 b = torch.randn(())
7252 self.assertEqual(0, b.ndim)
7253 c = torch.randn(1, 0)
7254 self.assertEqual(2, c.ndim)
7255
7256 def test_fill_diagonal(self):
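# fill_diagonal_ writes the value along the main diagonal; for tall matrices,
# wrap=True restarts the diagonal after every (ncols + 1) rows, as the 7x3
# cases below illustrate.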
7257 a1 = torch.randn(7, 3)
7258 a2 = a1.clone()
7259 v = 1
7260 for i in range(3):
7261 a2[i][i] = v
7262 a1.fill_diagonal_(v)
7263 self.assertEqual(a1, a2)
7264
7265 b1 = torch.randn(7, 3)
7266 b2 = b1.clone()
7267 for i in range(3):
7268 b2[i][i] = v
7269 b2[i + 4][i] = v
7270 b1.fill_diagonal_(v, wrap=True)
7271 self.assertEqual(b1, b2)
7272
7273 c1 = torch.rand(3, 3, 3)
7274 c2 = c1.clone()
7275 for i in range(3):
7276 c2[i][i][i] = v
7277 c1.fill_diagonal_(v)
7278 self.assertEqual(c1, c2)
7279
7280 # non-contiguous tensor
7281 d1 = torch.rand(3, 3, 3)[:, 1, ...]
7282 d2 = d1.clone()
7283 for i in range(3):
7284 d2[i][i] = v
7285 d1.fill_diagonal_(v)
7286 self.assertEqual(d1, d2)
7287
7288 e1 = torch.rand(7, 3, 3)[:, 1, ...]
7289 e2 = e1.clone()
7290 for i in range(3):
7291 e2[i][i] = v
7292 e2[i + 4][i] = v
7293 e1.fill_diagonal_(v, wrap=True)
7294 self.assertEqual(e1, e2)
7295
anjali41137e0d2e2022-03-08 10:53:20 -08007296 def test_setting_real_imag_to_a_number(self):
7297 x = torch.randn(4, dtype=torch.cfloat)
7298 x.real = 0
7299 x.imag = 0
7300 zeros = torch.zeros(4)
7301 self.assertEqual(x.real, zeros)
7302 self.assertEqual(x.imag, zeros)
7303
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007304 def test_batch_norm_cpu_inference(self):
7305 # input nchw in (2,1,1,1), (2,2,2,2)
7306 inputs = [
7307 torch.tensor([[[[-0.5000]]], [[[0.5000]]]]),
7308 torch.tensor([
7309 [
7310 [[-0.5000, 0.5000], [-1.0000, 1.0000]],
7311 [[-0.2500, -0.5000], [0.2500, 0.5000]]
7312 ],
7313 [
7314 [[0.1000, 1.0000], [1.0000, 0.1000]],
7315 [[1.0000, 0.5000], [1.5000, -1.5000]]
7316 ]])]
7317 # output nchw in (2,1,1,1), (2,2,2,2)
7318 outputs = [
7319 torch.tensor([
7320 [[[-0.499997496604919433593750000]]],
7321 [[[0.499997496604919433593750000]]]]),
7322 torch.tensor([
7323 [[[-0.499997496604919433593750000, 0.499997496604919433593750000],
7324 [-0.999994993209838867187500000, 0.999994993209838867187500000]],
7325 [[-0.249998748302459716796875000, -0.499997496604919433593750000],
7326 [0.249998748302459716796875000, 0.499997496604919433593750000]]],
7327 [[[0.099999502301216125488281250, 0.999994993209838867187500000],
7328 [0.999994993209838867187500000, 0.099999502301216125488281250]],
7329 [[0.999994993209838867187500000, 0.499997496604919433593750000],
7330 [1.499992489814758300781250000, -1.499992489814758300781250000]]]])]
7331
7332
7333 for i in range(len(inputs)):
7334 for affine in [False, True]:
7335 m = torch.nn.BatchNorm2d(inputs[i].size()[1], 1e-05, 0.1, affine=affine)
7336 m.eval()
7337 # contiguous case
7338 input1 = inputs[i].contiguous()
7339 output1 = m(input1)
7340 # non-contiguous case
7341 input2 = input1.permute(0, 1, 3, 2)
7342 output2 = m(input2).permute(0, 1, 3, 2)
7343 # channels last case
7344 input3 = input1.contiguous(memory_format=torch.channels_last)
7345 output3 = m(input3)
7346 self.assertEqual(output3, outputs[i])
7347 self.assertEqual(output3, output1)
7348 self.assertEqual(output3, output2)
7349
7350 # FIXME: move these meta tests to their own test suite/class or
7351 # distribute them among the appropriate test suites for their ops
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007352 def test_empty_meta(self):
7353 x = torch.empty(2 ** 20, 2 ** 20, device='meta')
7354 y = torch.empty(2 ** 20, device='meta')
7355 z = x + y
7356 self.assertEqual(z.size(), (2 ** 20, 2 ** 20))
7357 self.assertRaises(RuntimeError, lambda: z[0][0].item())
7358
Edward Z. Yang51e7a342022-03-22 19:21:54 -04007359 def test_format_scalar_meta(self):
7360 x = torch.empty((), device='meta')
7361 self.assertEqual(format(x), repr(x))
7362
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007363 def test_upsample_nearest1d_meta(self):
7364 # TODO: this test should be triggered by test_nn.py but right
7365 # now meta is not enabled (and even if it was, we are probably
7366 # missing too many meta functions to get through the test unmolested)
7367
7368 # NB: Can't make the exponent too big, or it will overflow
7369 # signed 64-bit integer
7370 x = torch.empty(2 * 10 ** 8, 3, 2 * 10 ** 8, device='meta')
7371 z = torch.nn.functional.interpolate(x, scale_factor=2)
7372 self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
7373 self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
7374
7375 # TODO: the out tests cannot be triggered by test_nn.py because
7376 # we don't actually do out= arguments for nn functions, so there
7377 # is no public API by which to get the out version
7378
7379 # interpolate doesn't seem to support out=
7380 # (not sure why passing None here doesn't work? How strange...)
7381 z = torch.empty(0, device='meta')
7382 torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
7383 self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
7384 self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
7385
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007386 def test_upsample_nearest2d_meta(self):
7387 # TODO: the out tests cannot be triggered by test_nn.py because
7388 # we don't actually do out= arguments for nn functions, so there
7389 # is no public API by which to get the out version
7390
7391 # Make sure we don't clobber strides of out tensor. NB: this
7392 # test must be done on 2d/3d, because 1d doesn't have any meaningful
7393 # layout support
7394 x = torch.empty(4, 3, 8, 8, device='meta')
7395 out = torch.empty(4, 3, 16, 16, device='meta', memory_format=torch.channels_last)
7396 torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
7397 self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
7398
7399 x = torch.empty(4, 3, 8, 8, device='meta', memory_format=torch.channels_last)
7400 out = torch.empty(4, 3, 16, 16, device='meta')
7401 torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
7402 self.assertTrue(out.is_contiguous())
7403
7404 # But if resize occurs, do clobber
7405 x = torch.empty(4, 3, 8, 8, device='meta', memory_format=torch.channels_last)
7406 out = torch.empty(0, device='meta')
7407 torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
7408 self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
7409
7410 # Complain if out dtype mismatch
7411 x = torch.empty(4, 3, 8, 8, device='meta', dtype=torch.float)
7412 out = torch.empty(4, 3, 16, 16, device='meta', dtype=torch.double)
7413 self.assertExpectedRaisesInline(
7414 RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
7415 """Expected out tensor to have dtype float, but got double instead"""
7416 )
7417
7418 # Complain if out device mismatch
7419 x = torch.empty(0, 3, 8, 8, device='meta')
7420 out = torch.empty(0, 3, 16, 16, device='cpu')
7421 self.assertExpectedRaisesInline(
7422 RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
7423 """Expected out tensor to have device meta, but got cpu instead"""
7424 )
7425
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007426 def test_add_meta_scalar(self):
7427 # From https://github.com/pytorch/pytorch/issues/53815
7428 x = torch.empty(2, device='meta')
7429 y = x + 2
7430 self.assertEqual(y.size(), x.size())
7431
7432 def test_normal_shape(self):
7433 warned = False
7434 for device in get_all_device_types():
7435 tensor1 = torch.rand(1, device=device)
7436 tensor4 = torch.rand(4, device=device)
7437 tensor120 = torch.rand(120, device=device)
7438 tensor2145 = torch.rand(2, 1, 4, 5, device=device)
7439 tensor2345 = torch.rand(2, 3, 4, 5, device=device)
7440 tensor2345_non_contiguous = torch.rand(2, 4, 3, 5, device=device).permute(0, 2, 1, 3)
7441 tensor2345_channels_last = tensor2345.contiguous(memory_format=torch.channels_last)
7442 output2345 = torch.zeros(2, 3, 4, 5, device=device)
7443 output345 = torch.zeros(3, 4, 5, device=device)
7444
7445 # inputs have same size
7446 self.assertEqual(torch.normal(tensor2345, tensor2345).size(), (2, 3, 4, 5))
7447 self.assertEqual(torch.normal(tensor2345_non_contiguous, tensor2345).size(), (2, 3, 4, 5))
7448 self.assertEqual(torch.normal(tensor2345, tensor2345_channels_last).size(), (2, 3, 4, 5))
7449 self.assertEqual(torch.normal(tensor2345_non_contiguous, tensor2345_channels_last).size(), (2, 3, 4, 5))
7450
7451 # scalar case
7452 self.assertEqual(torch.normal(tensor2345, 2).size(), (2, 3, 4, 5))
7453 self.assertEqual(torch.normal(2, tensor2345).size(), (2, 3, 4, 5))
7454
7455 # inputs are expandable tensors
7456 self.assertEqual(torch.normal(tensor2345, tensor1).size(), (2, 3, 4, 5))
7457 self.assertEqual(torch.normal(tensor2145, tensor2345).size(), (2, 3, 4, 5))
7458
7459 # inputs are non-expandable tensors, but they have the same number of elements
Nikita Karetnikoveb0d3702022-03-01 15:07:13 -08007460 with self.assertRaisesRegex(
7461 RuntimeError,
7462 r"The size of tensor a \(120\) must match the size of "
7463 r"tensor b \(5\) at non-singleton dimension 3"):
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007464 self.assertEqual(torch.normal(tensor120, tensor2345).size(), (120,))
Nikita Karetnikoveb0d3702022-03-01 15:07:13 -08007465 with self.assertRaisesRegex(
7466 RuntimeError,
7467 r"The size of tensor a \(5\) must match the size of "
7468 r"tensor b \(120\) at non-singleton dimension 3"):
7469 self.assertEqual(torch.normal(tensor2345, tensor120).size(), (2, 3, 4, 5))
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007470
7471 # inputs are non-expandable tensors and they don't have the same number of elements
Nikita Karetnikoveb0d3702022-03-01 15:07:13 -08007472 with self.assertRaisesRegex(
7473 RuntimeError,
7474 r"The size of tensor a \(5\) must match the size of "
7475 r"tensor b \(4\) at non-singleton dimension 3"):
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007476 torch.normal(tensor2345, tensor4)
7477
7478 # output and inputs are size compatible
7479 self.assertEqual(torch.normal(tensor2345, tensor2345, out=output2345).size(), (2, 3, 4, 5))
7480
7481 # output and inputs are not size compatible
Nikita Karetnikoveb0d3702022-03-01 15:07:13 -08007482 with self.assertWarnsRegex(
7483 UserWarning,
7484 "This behavior is deprecated, and in a future PyTorch "
7485 "release outputs will not be resized unless they have "
7486 "zero elements"):
7487 self.assertEqual(torch.normal(tensor2345, tensor2145, out=output345).size(), (2, 3, 4, 5))
7488 with self.assertRaisesRegex(
7489 RuntimeError,
7490 r"The size of tensor a \(5\) must match the size of "
7491 r"tensor b \(120\) at non-singleton dimension 3"):
7492 # inputs are not expandable, output size is not the same as mean
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007493 torch.normal(tensor2345, tensor120, out=output345)
7494
7495 def test_tensoriterator_output_setup(self):
7496 # Test whether the output's memory layout is correct
7497 def test_memory_layout(x, y, scale, zero_point, out):
7498 self.assertEqual(x.dim(), 4)
7499 self.assertEqual(x.size(), y.size())
7500 self.assertEqual(y.size(), out.size())
7501
7502 shape = x.size()
7503 for n in range(shape[0]):
7504 for c in range(shape[1]):
7505 for h in range(shape[2]):
7506 for w in range(shape[3]):
7507 if scale is not None and zero_point is not None:
7508 self.assertEqual(
7509 out[n][c][h][w],
7510 torch.ops.quantized.add(x[n][c][h][w], y[n][c][h][w], scale, zero_point))
7511 else:
7512 self.assertEqual(out[n][c][h][w], x[n][c][h][w] + y[n][c][h][w])
7513
7514 xraw = torch.rand(2, 3, 4, 4)
7515 yraw = torch.rand(2, 3, 4, 4)
7516 qxraw = torch.quantize_per_tensor(xraw, 0.1, 5, torch.quint8)
7517 qyraw = torch.quantize_per_tensor(yraw, 0.1, 5, torch.quint8)
7518
7519 # contiguous case fast setup
7520 test_memory_layout(xraw, yraw, None, None, xraw + yraw)
7521 test_memory_layout(qxraw, qyraw, 0.1, 5, torch.ops.quantized.add(qxraw, qyraw, 0.1, 5))
7522
7523 # channels last case fast setup
7524 x = xraw.contiguous(memory_format=torch.channels_last)
7525 y = yraw.contiguous(memory_format=torch.channels_last)
7526 test_memory_layout(x, y, None, None, x + y)
7527 qx = qxraw.contiguous(memory_format=torch.channels_last)
7528 qy = qyraw.contiguous(memory_format=torch.channels_last)
7529 test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5))
7530
7531 # non contiguous case fast setup (dense, non-overlapping, same shape and strides)
7532 x = xraw.permute(0, 2, 3, 1)
7533 y = yraw.permute(0, 2, 3, 1)
7534 test_memory_layout(x, y, None, None, x + y)
7535 qx = qxraw.permute(0, 2, 3, 1)
7536 qy = qyraw.permute(0, 2, 3, 1)
7537 test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5))
7538
7539 # non contiguous case fast setup (dense, non-overlapping)
7540 # the input tensors have the same shape and strides
7541 # the output tensor has the same shape as the input tensors but different strides
7542 # the output tensor should preserve its strides in this case
7543 x = xraw.permute(0, 2, 3, 1)
7544 y = yraw.permute(0, 2, 3, 1)
7545 out = torch.empty_like(xraw)
7546 out = out.permute(0, 3, 2, 1)
7547 expected_stride = out.stride()
7548 test_memory_layout(x, y, None, None, torch.add(x, y, out=out))
7549 self.assertEqual(expected_stride, out.stride())
7550
7551 # non contiguous case non fast setup
7552 x = xraw.permute(0, 2, 3, 1)
7553 y = yraw.permute(0, 3, 2, 1)
7554 test_memory_layout(x, y, None, None, x + y)
7555 qx = qxraw.permute(0, 2, 3, 1)
7556 qy = qyraw.permute(0, 3, 2, 1)
7557 test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5))
7558
7559 # Tests to make sure we still handle .data properly until it is removed
7560 def test_dot_data_use(self):
7561 # .data allows to change the Tensors types inplace, check that we still
7562 # raise a nice error.
7563 with self.assertRaisesRegex(
7564 RuntimeError,
yuguo68efdb4192022-05-31 18:11:31 -07007565 # message includes both Double and ComplexFloat
7566 '(?=.*Double)(?=.*ComplexFloat)'):
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007567
7568 # Calls model with a DoubleTensor input but ComplexFloat weights
7569 input = torch.randn(1, 1, 1, 6, dtype=torch.double)
yuguo68efdb4192022-05-31 18:11:31 -07007570 weight = torch.zeros(1, 1, 1, 3, dtype=torch.complex64)
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007571 model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False)
7572 model.weight.data = weight
7573 out = model(input)
7574
7575 def test_empty_storage_view(self):
7576 # we should be able to "modify" slices of a 0-element
7577 # array without an error being raised due to
7578 # trying to resize its storage
7579 t = torch.from_numpy(np.empty((0, 4)))
7580 t[:, 1::2] *= 1
7581
7582 def test_has_storage(self):
7583 self.assertIsNotNone(torch.tensor([]).storage())
7584 self.assertIsNotNone(torch.empty(0).storage())
7585 self.assertIsNotNone(torch.tensor([]).clone().storage())
7586 self.assertIsNotNone(torch.tensor([0, 0, 0]).nonzero().storage())
7587 self.assertIsNotNone(torch.tensor([]).new().storage())
7588
7589 # FIXME: Extend this test and put in a TensorProperties test class
7590 def test_numel(self):
7591 b = torch.ByteTensor(3, 100, 100)
7592 self.assertEqual(b.nelement(), 3 * 100 * 100)
7593 self.assertEqual(b.numel(), 3 * 100 * 100)
7594
7595 # Verifies that (deep)copies of dtypes are the same objects
7596 def test_copy_dtypes(self):
Nikita Shulgabfac65d2022-03-30 14:13:21 -07007597 for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007598 copied_dtype = copy.deepcopy(dtype)
7599 self.assertIs(dtype, copied_dtype)
7600
7601 def test_dtype_is_signed(self):
Nikita Shulgabfac65d2022-03-30 14:13:21 -07007602 for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007603 self.assertEqual(dtype.is_signed, torch.is_signed(torch.tensor(0, dtype=dtype)))
7604
7605 self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.quint8.is_signed)
7606 self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.qint8.is_signed)
7607 self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.qint32.is_signed)
7608
7609 # FIXME: Put the following random tests into their own test class or test suite
7610 def test_RNGState(self):
7611 state = torch.get_rng_state()
7612 stateCloned = state.clone()
7613 before = torch.rand(1000)
7614
7615 self.assertEqual(state.ne(stateCloned).long().sum(), 0, atol=0, rtol=0)
7616
7617 torch.set_rng_state(state)
7618 after = torch.rand(1000)
7619 self.assertEqual(before, after, atol=0, rtol=0)
7620
7621 def test_RNGStateAliasing(self):
7622 # Fork the random number stream at this point
7623 gen = torch.Generator()
7624 gen.set_state(torch.get_rng_state())
7625 self.assertEqual(gen.get_state(), torch.get_rng_state())
7626
7627 target_value = torch.rand(1000)
7628 # Dramatically alter the internal state of the main generator
7629 _ = torch.rand(100000)
7630 forked_value = torch.rand(1000, generator=gen)
7631 self.assertEqual(target_value, forked_value, atol=0, rtol=0, msg="RNG has not forked correctly.")
7632
7633 def test_RNG_after_pickle(self):
7634 torch.random.manual_seed(100)
7635 before = torch.rand(10)
7636
7637 torch.random.manual_seed(100)
7638 buf = io.BytesIO()
7639 tensor = torch.tensor([1, 2, 3])
7640 ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(tensor)
7641 after = torch.rand(10)
7642
7643 self.assertEqual(before, after, atol=0, rtol=0)
7644
7645 def test_boxMullerState(self):
7646 torch.manual_seed(123)
7647 odd_number = 101
7648 seeded = torch.randn(odd_number)
7649 state = torch.get_rng_state()
7650 midstream = torch.randn(odd_number)
7651 torch.set_rng_state(state)
7652 repeat_midstream = torch.randn(odd_number)
7653 torch.manual_seed(123)
7654 reseeded = torch.randn(odd_number)
7655 self.assertEqual(midstream, repeat_midstream, atol=0, rtol=0,
7656 msg='get_rng_state/set_rng_state not generating same sequence of normally distributed numbers')
7657 self.assertEqual(seeded, reseeded, atol=0, rtol=0,
7658 msg='repeated calls to manual_seed not generating same sequence of normally distributed numbers')
7659
7660 def test_manual_seed(self):
7661 rng_state = torch.get_rng_state()
7662 torch.manual_seed(2)
7663 x = torch.randn(100)
7664 self.assertEqual(torch.initial_seed(), 2)
7665 torch.manual_seed(2)
7666 y = torch.randn(100)
7667 self.assertEqual(x, y)
7668
7669 max_int64 = 0x7fff_ffff_ffff_ffff
7670 min_int64 = -max_int64 - 1
7671 max_uint64 = 0xffff_ffff_ffff_ffff
7672 # Check all boundary cases of valid seed value inputs
7673 test_cases = [
7674 # (seed, expected_initial_seed)
7675 # Positive seeds should be unchanged
7676 (max_int64, max_int64),
7677 (max_int64 + 1, max_int64 + 1),
7678 (max_uint64, max_uint64),
7679 (0, 0),
7680 # Negative seeds wrap around starting from the largest seed value
7681 (-1, max_uint64),
7682 (min_int64, max_int64 + 1)
7683 ]
7684 for seed, expected_initial_seed in test_cases:
7685 torch.manual_seed(seed)
7686 actual_initial_seed = torch.initial_seed()
7687 msg = "expected initial_seed() = %x after calling manual_seed(%x), but got %x instead" % (
7688 expected_initial_seed, seed, actual_initial_seed)
7689 self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg)
7690 for invalid_seed in [min_int64 - 1, max_uint64 + 1]:
7691 with self.assertRaisesRegex(RuntimeError, r'Overflow when unpacking long'):
7692 torch.manual_seed(invalid_seed)
7693
7694 torch.set_rng_state(rng_state)
7695
7696 # FIXME: Describe this test and port to the generic device framework in a more
7697 # appropriate test suite for the copy operation
7698 def test_copy_transpose(self):
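# Copy from a transposed (non-contiguous) source into a contiguous
# destination and verify, column by column, that elements land in the right
# positions for float, double, and complex dtypes.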
7699 x = torch.arange(100 * 100, dtype=torch.float).reshape(100, 100).t()
7700 y = torch.empty(100, 100, dtype=torch.float)
7701 y.copy_(x)
7702 self.assertEqual(y[:, 0], range(100))
7703 self.assertEqual(y[:, 40], range(4000, 4100))
7704
7705 y = torch.empty(100, 100, dtype=torch.double)
7706 y.copy_(x)
7707 self.assertEqual(y[:, 0], range(100))
7708 self.assertEqual(y[:, 40], range(4000, 4100))
7709
7710 # Validates regression reported in https://github.com/pytorch/pytorch/issues/45269
7711 x = torch.arange(100 * 100).reshape(100, 100).to(dtype=torch.cfloat).t()
7712 y = torch.empty(100, 100, dtype=torch.cfloat)
7713 y.copy_(x)
7714 self.assertEqual(y[:, 0], range(100))
7715 self.assertEqual(y[:, 40], range(4000, 4100))
7716
kshitij12345f7ee3082022-03-23 21:42:59 +00007717 x = torch.arange(100 * 100).reshape(100, 100).to(dtype=torch.complex32).t()
7718 y = torch.empty(100, 100, dtype=torch.complex32)
7719 y.copy_(x)
7720 self.assertEqual(y[:, 0], range(100))
7721 self.assertEqual(y[:, 40], range(4000, 4100))
7722
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007723 # FIXME: Port to a more appropriate test suite
7724 def test_copy_broadcast(self):
7725 torch.zeros(5, 6).copy_(torch.zeros(6))
7726 self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30)))
7727
7728 # FIXME: Port to a more appropriate test suite
7729 def test_copy_many_to_one(self):
7730 # Test that an in-place copy into a destination whose memory overlaps
7731 # itself (many elements sharing one storage location) raises a RuntimeError
7732 self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6)))
7733
7734 # FIXME: Port to a more appropriate test suite
Christian Puhrschce9a4772022-05-05 21:59:50 +00007735 def _test_to_with_layout(self, layout):
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007736 def test_copy_behavior(t, non_blocking=False):
7737 self.assertIs(t, t.to(t, non_blocking=non_blocking))
7738 self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
7739 self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
7740 self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
7741 self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
7742 self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True))
7743
7744 devices = [t.device]
7745 if t.device.type == 'cuda':
7746 if t.device.index == -1:
7747 devices.append('cuda:{}'.format(torch.cuda.current_device()))
7748 elif t.device.index == torch.cuda.current_device():
7749 devices.append('cuda')
7750 for device in devices:
7751 self.assertIs(t, t.to(device, non_blocking=non_blocking))
7752 self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
7753 self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
7754 self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True))
7755
7756 a = torch.tensor(5)
Christian Puhrschce9a4772022-05-05 21:59:50 +00007757 if layout == torch.sparse_csr:
7758 a = torch.tensor([[0, 1, 2], [2, 0, 3]]).to_sparse_csr()
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007759 test_copy_behavior(a)
7760 self.assertEqual(a.device, a.to('cpu').device)
7761 self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device)
7762 self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype)
7763 self.assertEqual(a.device, a.to(torch.float32).device)
7764 self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype)
Christian Puhrschce9a4772022-05-05 21:59:50 +00007765
7766 def test_data_ptr(getter):
7767 self.assertEqual(getter(a), getter(a.to('cpu')))
7768 self.assertEqual(getter(a), getter(a.to(dtype=a.dtype, device=a.device, copy=False)))
7769 self.assertEqual(getter(a), getter(a.to('cpu', copy=False)))
7770 self.assertNotEqual(getter(a), getter(a.to('cpu', copy=True)))
7771 if layout == torch.sparse_csr:
7772 # TODO: compressed sparse tensors currently don't support data_ptr.
7773 # Exercising failure will allow us to widen coverage of this test once it does.
7774 with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer of Tensor that doesn't have storage"):
7775 a.data_ptr()
7776 # While compressed sparse tensors don't have a concept of data_ptr,
7777 # their underlying component tensors do. The implementation of to()
7778 # forwards the call to those components, which is what we test here.
7779 test_data_ptr(lambda a: a.values().data_ptr())
7780 test_data_ptr(lambda a: a.crow_indices().data_ptr())
7781 test_data_ptr(lambda a: a.col_indices().data_ptr())
7782 else:
7783 test_data_ptr(lambda a: a.data_ptr())
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007784
7785 if torch.cuda.is_available():
7786 for non_blocking in [True, False]:
7787 for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
7788 b = torch.tensor(5., device=cuda)
7789 test_copy_behavior(b, non_blocking)
7790 self.assertEqual(b.device, b.to(cuda, non_blocking=non_blocking).device)
7791 self.assertEqual(a.device, b.to('cpu', non_blocking=non_blocking).device)
7792 self.assertEqual(b.device, a.to(cuda, non_blocking=non_blocking).device)
7793 self.assertIs(torch.int32, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype)
7794 self.assertEqual(a.device, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device)
7795 self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype)
7796 self.assertEqual(b.device, b.to(dtype=torch.int32).device)
7797
Christian Puhrschce9a4772022-05-05 21:59:50 +00007798 def test_to(self):
7799 self._test_to_with_layout(torch.strided)
Eric Sauser2d4291f2022-05-19 14:04:13 +00007800 is_cuda10_2_or_higher = (
7801 (torch.version.cuda is not None)
7802 and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2]))
7803 if is_cuda10_2_or_higher: # in cuda10_1 sparse_csr is beta
7804 self._test_to_with_layout(torch.sparse_csr)
Christian Puhrschce9a4772022-05-05 21:59:50 +00007805
Mike Ruberrye0d829a2022-01-24 01:28:07 -08007806 # FIXME: describe this test
7807 def test_as_subclass(self):
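# as_subclass reinterprets a tensor as the given Tensor subclass without
# copying: the result shares data with the original, sees the subclass's
# class attributes, and keeps participating in autograd.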
7808 class SubTensor(torch.Tensor):
7809 member_var = object()
7810
7811 t0 = torch.tensor(0)
7812 t1 = torch.tensor([1, 2])
7813 t2 = torch.tensor([[3, 4], [5, 6]])
7814
7815 s0 = t0.as_subclass(SubTensor)
7816 s1 = t1.as_subclass(SubTensor)
7817 s2 = t2.as_subclass(SubTensor)
7818
7819 # Check that the correct type is returned.
7820 self.assertTrue(type(s0) is SubTensor)
7821 self.assertTrue(type(s1) is SubTensor)
7822 self.assertTrue(type(s2) is SubTensor)
7823
7824 # Check that the data is equal.
7825 self.assertEqual(t0, s0)
7826 self.assertEqual(t1, s1)
7827 self.assertEqual(t2, s2)
7828
7829 t0[()] = 1
7830 t1[1] = 3
7831 t2[1, 1] = 7
7832
7833 # Check that the data is equal even after modification.
7834 self.assertEqual(t0, s0)
7835 self.assertEqual(t1, s1)
7836 self.assertEqual(t2, s2)
7837
7838 # Check that member variables are passed through.
7839 self.assertTrue(s0.member_var is SubTensor.member_var)
7840 self.assertTrue(s1.member_var is SubTensor.member_var)
7841 self.assertTrue(s2.member_var is SubTensor.member_var)
7842
7843 # Test that autograd is propagated.
7844 t = torch.tensor(5, dtype=torch.float32, requires_grad=True)
7845
7846 # Run a calculation on the tensor.
7847 exp_t = torch.exp(t)
7848
7849 # Cast exp_t to a subclass.
7850 exp_s = exp_t.as_subclass(SubTensor)
7851
7852 # Make sure that t.grad was initially None
7853 self.assertTrue(t.grad is None)
7854
7855 # Run the autograd calculation.
7856 exp_s.backward()
7857
7858 # Make sure autograd was propagated to the original tensor
7859 # declared with requires_grad.
7860 self.assertTrue(t.grad is not None)
7861
7862 # Make sure invalid subclasses raise nice errors
7863 class BadSubTensor():
7864 member_var = object()
7865
7866 err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor"
7867 with self.assertRaisesRegex(RuntimeError, err_msg):
7868 s0 = t0.as_subclass(BadSubTensor)
7869
7870 # FIXME: Port to a test suite that better fits slicing
7871 def test_slice(self):
7872 empty = torch.empty(0, 4)
7873 x = torch.arange(0., 16).view(4, 4)
7874 self.assertEqual(x[:], x)
7875 self.assertEqual(x[:4], x)
7876 # start and stop are clamped to the size of dim
7877 self.assertEqual(x[:5], x)
7878 # if start >= stop then the result is empty
7879 self.assertEqual(x[2:1], empty)
7880 self.assertEqual(x[2:2], empty)
7881 # out of bounds is also empty
7882 self.assertEqual(x[10:12], empty)
7883 # additional correctness checks
7884 self.assertEqual(x[:1].tolist(), [[0, 1, 2, 3]])
7885 self.assertEqual(x[:-3].tolist(), [[0, 1, 2, 3]])
7886 self.assertEqual(x[:, -2:3].tolist(), [[2], [6], [10], [14]])
7887 self.assertEqual(x[0:-1:2].tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]])
7888
7889 def test_type(self):
7890 x = torch.randn(3, 3).double()
7891 self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32)
7892 self.assertEqual(x.type(torch.FloatTensor).dtype, torch.float32)
7893 self.assertEqual(x.int().type(torch.Tensor).dtype, torch.get_default_dtype())
7894 self.assertEqual(x.type(torch.int32).dtype, torch.int32)
7895
7896 # FIXME: port to a quantization test suite
7897 def test_qengine(self):
7898 qengines = torch.backends.quantized.supported_engines
7899 original_qe = torch.backends.quantized.engine
7900 for qe in qengines:
7901 torch.backends.quantized.engine = qe
7902 assert torch.backends.quantized.engine == qe, 'qengine not set successfully'
7903 torch.backends.quantized.engine = original_qe
7904
7905 # FIXME: port to a distributed test suite -- also... how could this be OOMing on Windows CUDA?
7906 @slowTest
7907 @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
7908 don't support multiprocessing with spawn start method")
7909 @unittest.skipIf(IS_WINDOWS, 'FIXME: CUDA OOM error on Windows')
7910 def test_multinomial_invalid_probs(self):
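        # The helper below runs `method(arg)` in a separate process started with the
        # 'spawn' start method and asserts that it reported the expected failure.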
7911 def _spawn_method(self, method, arg):
7912 try:
7913 mp.set_start_method('spawn')
7914 except RuntimeError:
7915 pass
7916 with mp.Pool(1) as pool:
7917 out: list = pool.map(method, [arg])
7918 self.assertTrue(out[0])
7919
7920 def _test_multinomial_invalid_probs(probs):
7921 try:
7922 # n_sample = 1 is a special case, test n_sample=2 which is more general
7923 torch.multinomial(probs.to('cpu'), 2)
7924 return False # Should not be reached
7925 except RuntimeError as e:
7926 return 'probability tensor contains either `inf`, `nan` or element < 0' in str(e)
7927
7928 _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., -1., 1.]))
7929 _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., inf, 1.]))
7930 _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., -inf, 1.]))
7931 _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., 1., nan]))
7932
7933 # FIXME: port to more appropriate test suite
7934 def test_to_with_tensor(self):
7935 a = torch.tensor(5)
7936 self.assertEqual(a.device, a.to(a).device)
7937
7938 if torch.cuda.is_available():
7939 for non_blocking in [True, False]:
7940 for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
7941 b = torch.tensor(5., device=cuda)
7942 self.assertEqual(b.device, b.to(b, non_blocking=non_blocking).device)
7943 self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device)
7944 self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device)
7945
7946 def test_device(self):
7947 cpu = torch.device('cpu')
7948 self.assertEqual('cpu', str(cpu))
7949 self.assertEqual('cpu', cpu.type)
7950 self.assertEqual(None, cpu.index)
7951
7952 cpu0 = torch.device('cpu:0')
7953 self.assertEqual('cpu:0', str(cpu0))
7954 self.assertEqual('cpu', cpu0.type)
7955 self.assertEqual(0, cpu0.index)
7956
7957 cpu0 = torch.device('cpu', 0)
7958 self.assertEqual('cpu:0', str(cpu0))
7959 self.assertEqual('cpu', cpu0.type)
7960 self.assertEqual(0, cpu0.index)
7961
7962 cuda = torch.device('cuda')
7963 self.assertEqual('cuda', str(cuda))
7964 self.assertEqual('cuda', cuda.type)
7965 self.assertEqual(None, cuda.index)
7966
7967 cuda1 = torch.device('cuda:1')
7968 self.assertEqual('cuda:1', str(cuda1))
7969 self.assertEqual('cuda', cuda1.type)
7970 self.assertEqual(1, cuda1.index)
7971
7972 cuda1 = torch.device('cuda', 1)
7973 self.assertEqual('cuda:1', str(cuda1))
7974 self.assertEqual('cuda', cuda1.type)
7975 self.assertEqual(1, cuda1.index)
7976
7977 cuda90 = torch.device('cuda', 90)
7978 self.assertEqual('cuda:90', str(cuda90))
7979 self.assertEqual('cuda', cuda90.type)
7980 self.assertEqual(90, cuda90.index)
7981
7982 self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1'))
7983 self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1'))
7984 self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 '))
7985 self.assertRaises(RuntimeError, lambda: torch.device('cuda: 2'))
7986 self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 2'))
7987 self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.'))
7988 self.assertRaises(RuntimeError, lambda: torch.device('cuda:2?'))
7989 self.assertRaises(RuntimeError, lambda: torch.device('cuda:?2'))
7990 self.assertRaises(RuntimeError, lambda: torch.device('cuda:'))
7991 self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.232'))
7992 self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 cuda:3'))
7993 self.assertRaises(RuntimeError, lambda: torch.device('cuda:2+cuda:3'))
7994 self.assertRaises(RuntimeError, lambda: torch.device('cuda:2cuda:3'))
7995 self.assertRaises(RuntimeError, lambda: torch.device(-1))
7996
7997 self.assertRaises(RuntimeError, lambda: torch.device('other'))
7998 self.assertRaises(RuntimeError, lambda: torch.device('other:0'))
7999
8000 device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
8001 device_hash_set = set()
8002 for device in list(device_set):
8003 device_hash_set.add(hash(torch.device(device)))
8004 self.assertEqual(len(device_set), len(device_hash_set))
8005
8006 def get_expected_device_repr(device):
8007 if device.index is not None:
8008 return "device(type='{type}', index={index})".format(
8009 type=device.type, index=device.index)
8010
8011 return "device(type='{type}')".format(type=device.type)
8012
8013 for device in device_set:
8014 dev = torch.device(device)
8015 self.assertEqual(repr(dev), get_expected_device_repr(dev))
8016
8017 # Tests that torch.use_deterministic_algorithms() and the deterministic debug mode can be set as expected
8018 @wrapDeterministicFlagAPITest
8019 def test_deterministic_flag(self):
8020 for deterministic, warn_only in product([True, False], [True, False]):
8021 torch.use_deterministic_algorithms(deterministic, warn_only=warn_only)
8022 self.assertEqual(deterministic, torch.are_deterministic_algorithms_enabled())
8023 self.assertEqual(warn_only, torch.is_deterministic_algorithms_warn_only_enabled())
8024
8025 if deterministic:
8026 if warn_only:
8027 debug_mode = 1
8028 else:
8029 debug_mode = 2
8030 else:
8031 debug_mode = 0
8032
8033 self.assertEqual(debug_mode, torch.get_deterministic_debug_mode())
8034
8035 for debug_mode in [0, 1, 2]:
8036 torch.set_deterministic_debug_mode(debug_mode)
8037 self.assertEqual(debug_mode, torch.get_deterministic_debug_mode())
8038 deterministic = debug_mode in [1, 2]
8039 warn_only = debug_mode == 1
8040
8041 self.assertEqual(deterministic, torch.are_deterministic_algorithms_enabled())
8042 self.assertEqual(warn_only, torch.is_deterministic_algorithms_warn_only_enabled())
8043
8044 for debug_mode, debug_mode_str in [(0, 'default'), (1, 'warn'), (2, 'error')]:
8045 torch.set_deterministic_debug_mode(debug_mode_str)
8046 self.assertEqual(debug_mode, torch.get_deterministic_debug_mode())
8047
8048 with self.assertRaisesRegex(
8049 TypeError,
8050 r"_set_deterministic_algorithms\(\): argument 'mode' \(position 1\) must be bool, not int"):
8051 torch.use_deterministic_algorithms(1)
8052
8053 with self.assertRaisesRegex(
8054 TypeError,
8055 r"_set_deterministic_algorithms\(\): argument 'warn_only' must be bool, not int"):
8056 torch.use_deterministic_algorithms(False, warn_only=1)
8057
8058 def test_type_conversion_via_dtype_name(self):
8059 x = torch.tensor([1])
8060 self.assertEqual(x.byte().dtype, torch.uint8)
8061 self.assertEqual(x.bool().dtype, torch.bool)
8062 self.assertEqual(x.char().dtype, torch.int8)
8063 self.assertEqual(x.double().dtype, torch.float64)
8064 self.assertEqual(x.float().dtype, torch.float32)
8065 self.assertEqual(x.half().dtype, torch.float16)
8066 self.assertEqual(x.int().dtype, torch.int32)
8067 self.assertEqual(x.bfloat16().dtype, torch.bfloat16)
8068 cfloat = x.cfloat()
8069 self.assertEqual(cfloat.dtype, torch.complex64)
8070 self.assertEqual(cfloat.real, x.float())
8071 self.assertEqual(cfloat.imag, torch.zeros_like(cfloat.imag))
8072 cdouble = x.cdouble()
8073 self.assertEqual(cdouble.dtype, torch.complex128)
8074 self.assertEqual(cdouble.real, x.double())
8075 self.assertEqual(cdouble.imag, torch.zeros_like(cdouble.imag))
kshitij12345aa517042022-04-20 23:44:47 +00008076 chalf = x.chalf()
8077 self.assertEqual(chalf.dtype, torch.complex32)
8078 self.assertEqual(chalf.real, x.half())
8079 self.assertEqual(chalf.imag, torch.zeros_like(chalf.imag))
8080
8081 def test_type_alias(self):
8082 type_alias_map = {torch.float64: torch.double,
8083 torch.float32: torch.float,
8084 torch.int32: torch.int,
8085 torch.int64: torch.long,
8086 torch.int16: torch.short,
8087 torch.float16: torch.half,
8088 torch.complex32: torch.chalf,
8089 torch.complex64: torch.cfloat}
8090 for dtype, alias in type_alias_map.items():
8091 self.assertIs(alias, dtype)
Mike Ruberrye0d829a2022-01-24 01:28:07 -08008092
Mike Ruberrye0d829a2022-01-24 01:28:07 -08008093 def test_doc_template(self) -> None:
Huy Doedf18682022-07-21 16:28:29 +00008094 """
8095 Test that all public API doc strings use the same standard template for
8096 all common arguments such as tensor or dim
8097 """
Mike Ruberrye0d829a2022-01-24 01:28:07 -08008098 from torch._torch_docs import __file__ as doc_file
8099 from torch._torch_docs import multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args
8100
8101 with open(doc_file, "r", encoding="utf-8") as f:
8102 doc_strs = f.read()
8103
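        # Collect every add_docstr(<target>, """<docstring>""") entry from the
        # torch._torch_docs source, capturing the target name and the docstring body.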
Huy Doedf18682022-07-21 16:28:29 +00008104 matches = re.findall(
8105 r'add_docstr\(([^,]+?),[^"\']*?(?:"""|\'\'\')(.*?)(?:"""|\'\'\')(?:\.|,?[^,\)]*?\))',
8106 doc_strs,
8107 re.MULTILINE | re.DOTALL,
8108 )
8109 self.assertTrue(matches)
8110
8111 for m in matches:
8112 func = m[0].strip()
8113 desc = m[1].strip()
8114
Mike Ruberrye0d829a2022-01-24 01:28:07 -08008115 for common_args in [multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args]:
8116 for k, v in common_args.items():
Huy Doedf18682022-07-21 16:28:29 +00008117 self.assertNotIn(v, desc, 'The argument description "{}" in {} can be '
8118 'replaced by {{{}}}'.format(v, func, k))
Mike Ruberrye0d829a2022-01-24 01:28:07 -08008119
8120 def test_doc(self):
8121 checked_types = (types.MethodType, types.FunctionType,
8122 types.BuiltinFunctionType, types.BuiltinMethodType)
8123
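        # Checks that every public function or method attribute of `ns` has a
        # docstring, except for names matching one of the `skips` patterns
        # (which are asserted to still be undocumented).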
8124 def _test_namespace(ns, *skips):
8125 if isinstance(ns, types.ModuleType):
8126 ns_name = ns.__name__
8127 else:
8128 ns_name = ns.__class__.__name__
8129 skip_regexes = []
8130 for r in skips:
8131 if isinstance(r, string_classes):
8132 skip_regexes.append(re.compile('^{}$'.format(re.escape(r))))
8133 else:
8134 skip_regexes.append(r)
8135
8136 for name in dir(ns):
8137 if name.startswith('_'):
8138 continue
8139 if name in ['real', 'imag']:
8140 y = torch.randn(1, dtype=torch.cfloat)
8141 var = getattr(y, name)
8142 elif name in ["H", "mT", "mH"]:
8143 y = torch.randn(1, 1)
8144 var = getattr(y, name)
8145 else:
8146 var = getattr(ns, name)
8147 if not isinstance(var, checked_types):
8148 continue
8149 doc = var.__doc__
8150 has_doc = doc is not None and len(doc.strip()) > 0
8151 full_name = ns_name + '.' + name
8152 if any(r.match(name) for r in skip_regexes):
8153 self.assertFalse(has_doc,
8154 'New docs have been added for {}, please remove '
8155 'it from the skipped list in TestTorch.test_doc'.format(full_name))
8156 else:
8157 self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
8158
8159 # FIXME: All of the following should be marked as expected failures
8160 # so that it is easier to tell when missing documentation has been added.
8161 # FIXME: fix all the skipped ones below!
8162 _test_namespace(torch.randn(1),
8163 'as_strided_',
8164 re.compile('^clamp_(min|max)_?$'),
8165 'is_distributed',
8166 'is_nonzero',
8167 'is_same_size',
8168 'log_softmax',
8169 'map2_',
8170 'new',
8171 'reinforce',
8172 'relu',
8173 'relu_',
8174 'prelu',
8175 'resize',
8176 'resize_as',
8177 'softmax',
8178 'split_with_sizes',
8179 'unsafe_split_with_sizes',
8180 '_autocast_to_fp16',
8181 '_autocast_to_fp32',
8182 )
8183
8184 _test_namespace(torch.nn)
8185 _test_namespace(torch.nn.functional, 'assert_int_or_pair')
8186 # TODO: add torch.* tests when we have proper namespacing on ATen functions
8187 # test_namespace(torch)
8188
8189 # FIXME: deprecate torch.Tensor constructor
Edward Yang17fb6512021-05-26 11:31:47 -07008190 def test_tensor_ctor_scalar(self):
8191 x = torch.Tensor(torch.tensor(1.0))
8192 self.assertEqual(x, torch.tensor(1.0))
8193
mattip345844d2021-01-26 16:17:40 -08008194 def test_deepcopy_gradient(self):
8195 from copy import deepcopy
8196 a = torch.zeros(10)
8197 a.grad = torch.ones(10)
8198 self.assertEqual(a.grad, deepcopy(a).grad)
8199 s = torch.zeros(10).to_sparse()
8200 s.grad = torch.ones(10).to_sparse()
8201 self.assertEqual(s.grad, deepcopy(s).grad)
8202
8203 # ensure sharing is not broken
8204 c = deepcopy([a, a.grad])
8205 self.assertTrue(c[0].grad is c[1])
8206
Edward Yange362ee62021-04-28 09:23:07 -07008207 def test_tensor_base_init(self):
8208 # Direct construction not OK
Edward Yangda8cc352021-05-05 09:03:37 -07008209 self.assertRaises(RuntimeError, lambda: torch._C._TensorBase())
Edward Yange362ee62021-04-28 09:23:07 -07008210
8211 # But construction of subclass is OK
8212 class T(torch._C._TensorBase):
8213 pass
8214
8215 T()
8216
Edward Yangda8cc352021-05-05 09:03:37 -07008217 def test_tensor_base_new(self):
8218
8219 # OK to call super().__new__, see
8220 # https://github.com/pytorch/pytorch/issues/57421
8221 class TestTensor(torch._C._TensorBase):
8222 @staticmethod
8223 def __new__(cls, x, *args, **kwargs):
8224 return super().__new__(cls, x, *args, **kwargs)
8225
8226 x = torch.ones(5)
8227 test_tensor = TestTensor(x)
8228
Edward Yangf05d5be2021-06-03 10:47:19 -07008229 def test_pyobj_preserved(self):
8230 x = torch.empty(2)
8231 x.foo = 2 # put something on __dict__
8232 y = torch.empty(2)
8233 y.grad = x
8234 del x # x is dead in Python
8235 self.assertEqual(y.grad.foo, 2)
8236 z = y.grad # it's live
8237 del z # it's dead again
8238 self.assertEqual(y.grad.foo, 2)
8239
8240 def test_subclass_preserved(self):
Victor Quach8131bc82021-09-13 16:39:55 -07008241 class MyTensor(torch.Tensor):
Edward Yangf05d5be2021-06-03 10:47:19 -07008242 pass
8243
8244 x = MyTensor(torch.empty(2))
8245 y = torch.empty(2)
8246 y.grad = x
8247 del x # x is dead in Python
8248 self.assertEqual(type(y.grad), MyTensor)
8249 z = y.grad # it's live
8250 del z # it's dead again
8251 self.assertEqual(type(y.grad), MyTensor)
8252
8253 def test_tensor_slot_dealloc(self):
Shen Li10224432021-08-12 11:39:31 -07008254
Edward Yangf05d5be2021-06-03 10:47:19 -07008255 class SlotTensor1(torch._C._TensorBase):
Shen Li10224432021-08-12 11:39:31 -07008256 __slots__ = ['slot1']
Edward Yangf05d5be2021-06-03 10:47:19 -07008257
8258 class SlotTensor2(SlotTensor1):
Shen Li10224432021-08-12 11:39:31 -07008259 __slots__ = ['slot2']
Edward Yangf05d5be2021-06-03 10:47:19 -07008260
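        # Tracker.make() (defined earlier in this file) returns a one-element flag
        # list and an object whose deallocation flips the flag to True, letting these
        # tests observe exactly when attributes attached to a tensor are freed.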
8261 m1, t1 = Tracker.make()
8262 m2, t2 = Tracker.make()
8263 slot_tensor = SlotTensor2(torch.empty(2))
8264 slot_tensor.slot1 = t1
8265 slot_tensor.slot2 = t2
8266 del t1
8267 del t2
8268 self.assertFalse(m1[0])
8269 self.assertFalse(m2[0])
8270 del slot_tensor
8271 self.assertTrue(m1[0])
8272 self.assertTrue(m2[0])
8273
Animesh Jain1d90d6e2022-07-07 18:57:31 +00008274 @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
Edward Yangf05d5be2021-06-03 10:47:19 -07008275 def test_tensor_dict_dealloc(self):
8276 m, t = Tracker.make()
8277 x = torch.empty(2)
8278 x.arf = t
8279 del t
8280 self.assertFalse(m[0])
8281 del x
8282 self.assertTrue(m[0])
8283
8284 def test_tensor_finalizer_dealloc(self):
8285 m = [False]
8286
8287 class FinalizerTensor(torch._C._TensorBase):
8288 def __del__(self):
8289 m[0] = True
8290
8291 fin_tensor = FinalizerTensor(torch.empty(2))
8292 self.assertFalse(m[0])
8293 del fin_tensor
8294 self.assertTrue(m[0])
8295
Animesh Jain1d90d6e2022-07-07 18:57:31 +00008296 @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
Edward Yangf05d5be2021-06-03 10:47:19 -07008297 def test_tensor_weakref_dealloc(self):
8298
8299 x = torch.empty(2)
8300 m = [False]
8301
8302 def cb(r):
8303 m[0] = True
8304
8305 wref = weakref.ref(x, cb)
8306 del x
8307 self.assertTrue(m[0])
8308 self.assertEqual(wref(), None)
8309
Animesh Jain1d90d6e2022-07-07 18:57:31 +00008310 @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
Edward Yangf05d5be2021-06-03 10:47:19 -07008311 def test_tensor_cycle_via_dict(self):
8312 m1, t1 = Tracker.make()
8313 x = torch.empty(2)
8314 x._tracker = t1
8315 del t1
8316
8317 m2, t2 = Tracker.make()
8318 y = torch.empty(2)
8319 y._tracker = t2
8320 del t2
8321
8322 x._loop = y
8323 y._loop = x
8324
8325 # C++ reference should keep the cycle live!
8326 # This exercises THPVariable_subtype_traverse.
8327 # NB: Because z.grad is a reference done entirely in C++, cycles
8328 # involving it directly are NOT broken by Python GC; you've
8329 # set up a good old C++ reference cycle which we cannot safely
8330 # break (because C++ references are allowed to be accessed
8331 # multithreaded-ly) (TODO: except maybe if you can prove that
8332 # only Python has access to the C++ object, in which case you can
8333 # also prove that no multithreaded access occurs)
8334 z = torch.empty(2)
8335 z.grad = x
8336
8337 del x
8338 del y
8339
8340 gc.collect()
8341 self.assertFalse(m1[0])
8342 self.assertFalse(m2[0])
8343
8344 with disable_gc():
8345 del z
8346 self.assertFalse(m1[0])
8347 self.assertFalse(m2[0])
8348
8349 gc.collect()
8350 self.assertTrue(m1[0])
8351 self.assertTrue(m2[0])
8352
8353 def test_tensor_cycle_via_slots(self):
8354 m1 = [False]
8355 m2 = [False]
8356
8357 class SlotTensor1(torch._C._TensorBase):
Shen Li10224432021-08-12 11:39:31 -07008358 __slots__ = ['slot1']
Edward Yangf05d5be2021-06-03 10:47:19 -07008359
8360 def __del__(self):
8361 m1[0] = True
8362
8363 class SlotTensor2(SlotTensor1):
Shen Li10224432021-08-12 11:39:31 -07008364 __slots__ = ['slot2']
Edward Yangf05d5be2021-06-03 10:47:19 -07008365
8366 def __del__(self):
8367 m2[0] = True
8368
8369 x = SlotTensor1(torch.empty(2))
8370 y = SlotTensor2(torch.empty(2))
8371
8372 x.slot1 = y
8373 y.slot2 = x
8374
8375 del x
8376 with disable_gc():
8377 del y
8378 self.assertFalse(m1[0])
8379 self.assertFalse(m2[0])
8380
8381 gc.collect()
8382 self.assertTrue(m1[0])
8383 self.assertTrue(m2[0])
8384
Mike Ruberrye0d829a2022-01-24 01:28:07 -08008385 # FIXME: move to test_autograd?
Animesh Jain1d90d6e2022-07-07 18:57:31 +00008386 @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
Edward Yangf05d5be2021-06-03 10:47:19 -07008387 def test_backward_hooks_traverse(self):
8388 m1, t1 = Tracker.make()
8389 m2, t2 = Tracker.make()
8390 x = torch.empty(2, requires_grad=True)
8391 x._tracker = t1
8392 y = torch.empty(2, requires_grad=True)
8393 y._tracker = t2
8394 del t1
8395 del t2
8396
8397 # this hits a special setter, it's not just a __dict__ entry
8398 x._backward_hooks = y
8399 y._backward_hooks = x
8400
8401 del x
8402 with disable_gc():
8403 del y
8404 self.assertFalse(m1[0])
8405 self.assertFalse(m2[0])
8406
8407 gc.collect()
8408
8409 self.assertTrue(m1[0])
8410 self.assertTrue(m2[0])
8411
Animesh Jain1d90d6e2022-07-07 18:57:31 +00008412 @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
Edward Yangf05d5be2021-06-03 10:47:19 -07008413 def test_dead_weak_ref(self):
8414 x = torch.empty(2)
8415 w_x = weakref.ref(x)
8416 y = torch.empty(2)
8417 y.grad = x
8418 del x
8419
8420 x = w_x()
8421 # Ideally, x would keep the tensor live. But CPython doesn't
8422 # provide enough hooks to do this. So it will go dead and x
8423 # will transmute into an undefined tensor. Not great, but the
8424 # best we can do.
8425 del y
8426
8427 self.assertRaises(RuntimeError, lambda: x.sigmoid())
8428
8429 def test_resurrected_weak_ref(self):
8430 x = torch.empty(2)
8431 w_x = weakref.ref(x)
8432 y = torch.empty(2)
8433 y.grad = x
8434 del x
8435
8436 x = w_x()
8437 # Use this to manually fix weak references after dereferencing them
8438 x._fix_weakref()
8439 del y
8440 x.sigmoid()
8441
Mike Ruberrye0d829a2022-01-24 01:28:07 -08008442 # FIXME: move to test_linalg
Kimish Patel4f792702021-06-10 08:23:10 -07008443 @torch.inference_mode()
8444 def test_bmm_multithreaded(self):
Shen Li10224432021-08-12 11:39:31 -07008445 device = 'cpu'
Kimish Patel4f792702021-06-10 08:23:10 -07008446 num_threads = torch.get_num_threads()
8447
8448 torch.set_num_threads(4)
8449 batch_sizes = [1, 10]
8450 M, N, O = 23, 8, 12
8451 dtype = torch.float32
8452 numpy_dtype = dtype
8453
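        # Returns the permutation that undoes a permutation `p` of (0, 1, 2).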
8454 def invert_perm(p):
8455 d = {x: i for i, x in enumerate(p)}
8456 return (d[0], d[1], d[2])
8457
8458 def generate_inputs(num_batches):
8459 # transposed tensors
Shen Li10224432021-08-12 11:39:31 -07008460 for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
Philip Meier0973c5a2022-02-24 21:47:38 -08008461 b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
8462 b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
Kimish Patel4f792702021-06-10 08:23:10 -07008463 b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
8464 b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
8465 yield b1, b2
8466 # broadcasting tensors
8467 for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
8468 shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
8469 shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
Philip Meier0973c5a2022-02-24 21:47:38 -08008470 b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N)
8471 b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O)
Kimish Patel4f792702021-06-10 08:23:10 -07008472 yield b1, b2
8473 # zero-sized tensors
8474 for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
8475 shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
8476 shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
8477 b1 = torch.randn(shape1, dtype=dtype, device=device)
8478 b2 = torch.randn(shape2, dtype=dtype, device=device)
8479 yield b1, b2
8480
8481 try:
8482 for num_batches in batch_sizes:
Shen Li10224432021-08-12 11:39:31 -07008483 for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))):
Kimish Patel4f792702021-06-10 08:23:10 -07008484 res1 = torch.bmm(b1, b2)
Shen Li10224432021-08-12 11:39:31 -07008485 res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \
8486 .permute(perm3).contiguous().permute(invert_perm(perm3))
Kimish Patel4f792702021-06-10 08:23:10 -07008487 torch.bmm(b1, b2, out=res2)
8488 expect = torch.from_numpy(
Shen Li10224432021-08-12 11:39:31 -07008489 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
Kimish Patel4f792702021-06-10 08:23:10 -07008490 self.assertEqual(expect, res1)
8491 self.assertEqual(expect, res2)
8492 finally:
8493 torch.set_num_threads(num_threads)
8494
anjali411a82fcd32021-10-13 13:49:31 -07008495 def test_conj_neg_tolist(self):
8496 x = torch.randn(2, dtype=torch.cfloat)
8497 y1 = x.conj()
8498 y1_expect = x.conj_physical()
8499 y2 = y1.imag
8500 self.assertEqual(y1, y1_expect.tolist())
8501 self.assertEqual(y2, y1_expect.imag.tolist())
Edward Yangf05d5be2021-06-03 10:47:19 -07008502
Mike Ruberrye0d829a2022-01-24 01:28:07 -08008503# The following block extends TestTorch with negative dim wrapping tests
8504# FIXME: replace these with OpInfo sample inputs or systematic OpInfo tests
8505# Functions to test negative dimension wrapping
8506METHOD = 1
8507INPLACE_METHOD = 2
8508FUNCTIONAL = 4
8509DIM_ARG = None
8510
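# Each generated test calls an operation once with non-negative dim arguments and
# once with the equivalent negative dims (dim - ndim) and checks that the results
# match; e.g. for t = torch.randn(10, 20, 30), t.sum(2) must equal t.sum(-1).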
8511def make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0):
8512 def neg_dim_test(self):
8513 if isinstance(tensor_arg, list):
8514 assert METHOD not in types and INPLACE_METHOD not in types
8515 x = [torch.randn(arg) for arg in tensor_arg]
8516 ndim = len(tensor_arg[-1])
8517 else:
8518 x = torch.randn(*tensor_arg)
8519 ndim = len(tensor_arg)
8520 ndim += extra_dim
8521
8522 n_dim_to_test = sum(e is DIM_ARG for e in arg_constr())
8523
8524 for dims_val in combinations(range(ndim), n_dim_to_test):
8525 arg = arg_constr()
8526 arg_neg = copy.deepcopy(arg)
8527 idx = 0
8528 for i, v in enumerate(arg):
8529 if v is DIM_ARG:
8530 arg[i] = dims_val[idx]
8531 arg_neg[i] = dims_val[idx] - ndim
8532 idx += 1
8533
8534 if METHOD in types:
8535 a = getattr(x, name)(*arg)
8536 b = getattr(x, name)(*arg_neg)
8537 self.assertEqual(a, b)
8538
8539 if INPLACE_METHOD in types:
8540 a = x.clone()
8541 getattr(a, name + '_')(*arg)
8542 b = x.clone()
8543 getattr(b, name + '_')(*arg_neg)
8544 self.assertEqual(a, b)
8545
8546 if FUNCTIONAL in types:
8547 a = getattr(torch, name)(x, *arg)
8548 b = getattr(torch, name)(x, *arg_neg)
8549 self.assertEqual(a, b)
8550
8551 return neg_dim_test
8552
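# Helper for the table below: a LongTensor of shape `size` filled with random
# indices that are valid for a dimension of length `max_val`.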
8553def idx_tensor(size, max_val):
8554 return torch.LongTensor(*size).random_(0, max_val - 1)
8555
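# Attaches one test_<name>_neg_dim method to TestTorch for every entry below; each
# entry is (op name, input shape(s), argument constructor, call variants[, extra dims]).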
8556def add_neg_dim_tests():
8557 neg_dim_tests = [
8558 ('narrow', (10, 20, 30), lambda: [DIM_ARG, 0, 5], [METHOD]),
8559 ('transpose', (10, 20, 30), lambda: [DIM_ARG, DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
8560 ('size', (10, 20, 30), lambda: [DIM_ARG], [METHOD]),
8561 ('cat', [(2, 3, 4), (2, 3, 4)], lambda: [DIM_ARG], [FUNCTIONAL]),
8562 ('chunk', (10, 20, 30), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
8563 ('gather', (10, 20), lambda: [DIM_ARG, idx_tensor((10, 20), 10)], [METHOD, FUNCTIONAL]),
8564 ('index_select', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10)], [METHOD, FUNCTIONAL]),
8565 ('split', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
8566 ('squeeze', (10, 1, 20, 1), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
8567 ('unbind', (2, 3, 4), lambda: [DIM_ARG], [FUNCTIONAL]),
8568 ('unsqueeze', (10, 20), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL], 1),
8569 ('logcumsumexp', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8570 ('cumprod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8571 ('cumsum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8572 ('cummax', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8573 ('cummin', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8574 ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8575 ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8576 ('nanmedian', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8577 ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8578 ('norm', (10, 20), lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]),
8579 ('prod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8580 ('std', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8581 ('sum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8582 ('var', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8583 ('kthvalue', (10, 20), lambda: [3, DIM_ARG], [METHOD, FUNCTIONAL]),
8584 ('max', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8585 ('min', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8586 ('sort', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
8587 ('topk', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
8588 ('renorm', (10, 20), lambda: [2, DIM_ARG, 1], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
8589 ('index_add', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
8590 ('index_copy', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
8591 ('index_fill', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), 12], [INPLACE_METHOD]),
8592 ('scatter', (10, 10), lambda: [DIM_ARG, idx_tensor((10, 10), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
8593 ('select', (10, 20), lambda: [DIM_ARG, 3], [METHOD]),
8594 ('unfold', (10, 20), lambda: [DIM_ARG, 5, 2], [METHOD]),
8595 ]
8596
8597 for decl in neg_dim_tests:
8598 if len(decl) == 4:
8599 name, tensor_arg, arg_constr, types = decl
8600 extra_dim = 0
8601 elif len(decl) == 5:
8602 name, tensor_arg, arg_constr, types, extra_dim = decl
8603
8604 test_name = 'test_' + name + '_neg_dim'
8605
8606 assert not hasattr(TestTorch, test_name), "Duplicated test name: " + test_name
8607 setattr(TestTorch, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim))
8608
Mike Ruberryde40c8e2021-06-06 14:51:26 -07008609# TODO: these empty classes are temporarily instantiated for XLA compatibility
Mike Ruberry36c87f12020-11-28 20:09:52 -08008610# once XLA updates its test suite they should be removed
8611class TestViewOps(TestCase):
8612 pass
Mike Ruberrya7de5452019-10-04 02:39:26 -07008613
Mike Ruberryde40c8e2021-06-06 14:51:26 -07008614class TestTensorDeviceOps(TestCase):
8615 pass
8616
Mike Ruberrya7de5452019-10-04 02:39:26 -07008617# Generates tests
8618# Note: test generation must be done at file scope, not within main, or
8619# pytest will fail.
8620add_neg_dim_tests()
Mike Ruberryaa3c8712020-02-04 11:08:23 -08008621instantiate_device_type_tests(TestViewOps, globals())
Victor Bittorf8b6487c2021-06-25 16:27:45 -07008622instantiate_device_type_tests(TestVitalSignsCuda, globals())
xiaobingsuper07dbf0d2020-03-31 14:04:00 -07008623instantiate_device_type_tests(TestTensorDeviceOps, globals())
Mike Ruberry36c87f12020-11-28 20:09:52 -08008624instantiate_device_type_tests(TestTorchDeviceType, globals())
Shen Li10224432021-08-12 11:39:31 -07008625instantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu')
Mike Ruberrya7de5452019-10-04 02:39:26 -07008626
Shen Li10224432021-08-12 11:39:31 -07008627if __name__ == '__main__':
Adam Paszkea1fa9952017-01-26 04:21:49 +01008628 run_tests()