# Owner(s): ["module: scatter & gather ops"]

from itertools import product
from functools import partial

import numpy as np
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,
)
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    gradcheck,
    parametrize,
)

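
# These tests exercise torch.segment_reduce, which reduces ``data`` over
# variable-length segments along ``axis``.  Segment boundaries are given either
# as per-segment ``lengths`` or as cumulative ``offsets``; empty segments are
# filled with ``initial`` or, when no initial value is supplied, with the
# reduction's default value (see get_default_value below).  A minimal sketch of
# the behavior these tests expect, with illustrative values:
#
#   torch.segment_reduce(torch.tensor([1., 2., 3., 4., 5., 5.]), "sum",
#                        lengths=torch.tensor([1, 2, 3, 0]))
#   -> tensor([ 1.,  5., 14.,  0.])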

reductions = ["max", "mean", "min", "sum", "prod"]


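# Fill value for an empty segment: the caller's ``initial`` if one was given,
# otherwise the reduction's identity (e.g. 0.0 for "sum", 1.0 for "prod",
# -inf for "max", +inf for "min", and nan for "mean").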
def get_default_value(initial_value, reduction):
    if initial_value is not None:
        return initial_value
    if reduction == "max":
        return -float("Inf")
    elif reduction == "mean":
        return float("nan")
    elif reduction == "min":
        return float("Inf")
    elif reduction == "sum":
        return 0.0
    elif reduction == "prod":
        return 1.0


class TestSegmentReductions(TestCase):
    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        # generate offsets from lengths
        zeros_shape = list(lengths.shape)
        zeros_shape[-1] = 1
        offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)
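        # e.g. lengths [1, 2, 3, 0] -> offsets [0, 1, 3, 6, 6]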

        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
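        # Check both ways of specifying segments: per-segment lengths and the
        # equivalent cumulative offsets; both forms should produce the same result.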
        for mode in ['lengths', 'offsets']:
            segment_reduce_kwargs = dict(
                axis=axis,
                unsafe=unsafe,
                initial=initial_value)
            if mode == 'lengths':
                segment_reduce_kwargs['lengths'] = lengths
            else:
                segment_reduce_kwargs['offsets'] = offsets
            actual_result = torch.segment_reduce(
                data=data,
                reduce=reduction,
                **segment_reduce_kwargs
            )
            self.assertEqual(
                expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
            )

            if not check_backward:
                return

            # Test backward
            actual_result.sum().backward()
            self.assertEqual(
                expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
            )
            data = data.clone().detach().requires_grad_(True)

            # gradcheck does not work well with bfloat16 or fp16 cpu types;
            # there is also a small numerical difference with fp32
            if dtype not in [torch.half, torch.bfloat16, torch.float]:
                # gradcheck does not like "nan" inputs, so replace nans with 10
                d_non_nan = np.nan_to_num(data_arr, nan=10)
                new_data = torch.tensor(
                    d_non_nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                self.assertTrue(
                    gradcheck(
                        lambda x: torch.segment_reduce(
                            data=x,
                            reduce=reduction,
                            **segment_reduce_kwargs
                        ),
                        (new_data,),
                    )
                )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_1d(self, device, dtypes):
        val_dtype, length_type = dtypes
        lengths = [1, 2, 3, 0]
        data = [1, float("nan"), 3, 4, 5, 5]

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [1, float("nan"), 5, default_value]
                    expected_grad = [1, 1, 0, 0, 0.5, 0.5]
                elif reduction == "mean":
                    expected_result = [1, float("nan"), 4.666, default_value]
                    expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [1, float("nan"), 4, default_value]
                    expected_grad = [1.0, 1.0, 0, 1, 0, 0]
                elif reduction == "sum":
                    expected_result = [1, float("nan"), 14, default_value]
                    expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [2, float("nan"), 200, default_value]
                        expected_grad = [2.0, 6.0, float("nan"), 50.0, 40.0, 40.0]
                    else:
                        expected_result = [1, float("nan"), 100, default_value]
                        expected_grad = [1.0, 3.0, float("nan"), 25.0, 20.0, 20.0]
                for axis in [0, -1]:
                    for unsafe in [True, False]:
                        self._test_common(
                            reduction,
                            device,
                            val_dtype,
                            unsafe,
                            axis,
                            initial_value,
                            data,
                            lengths,
                            expected_result,
                            expected_grad,
                            check_backward,
                            length_type,
                        )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d_simple(self, device, dtypes):
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [1, 2, 3, 0]
        data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [4, 3],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1, 1],
                        [1, 0],
                        [0, 1],
                        [1, 0],
                        [0, 0],
                        [0, 1],
                    ]
                elif reduction == "mean":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [3, 2],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [0.5, 0.5],
                        [0.5, 0.5],
                        [0.333, 0.333],
                        [0.333, 0.333],
                        [0.333, 0.333],
                    ]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [2, 1],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1, 0],
                        [0, 1],
                        [0, 1],
                        [0, 0],
                        [1, 0],
                    ]
                elif reduction == "sum":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [9, 6],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                    ]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [
                            [2, 2],
                            [float("nan"), float("nan")],
                            [48, 12],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [2.0, 2.0],
                            [6.0, float("nan")],
                            [float("nan"), 2.0],
                            [12.0, 12.0],
                            [16.0, 6.0],
                            [24.0, 4.0],
                        ]
                    else:
                        expected_result = [
                            [1, 1],
                            [float("nan"), float("nan")],
                            [24, 6],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [1.0, 1.0],
                            [3.0, float("nan")],
                            [float("nan"), 1.0],
                            [6.0, 6.0],
                            [8.0, 3.0],
                            [12.0, 2.0],
                        ]
                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                    )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
    def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
        val_dtype, length_dtype = dtypes
        # Note: zero-length segments are filled with the reduction's initial/default
        # value here, unlike pytorch_scatter.
        tests = [
            {
                'src': [1, 2, 3, 4, 5, 6],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [3, 12, 0, 6],
                'prod': [2, 60, 1, 6],
                'mean': [1.5, 4, float('nan'), 6],
                'min': [1, 3, float('inf'), 6],
                'max': [2, 5, -float('inf'), 6],
            },
            {
                'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
                'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
                'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
                'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
                'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
            },
            {
                'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
                'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
                'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
                'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
                'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
                'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
                'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
                'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
            },
            {
                'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
                'index': [[0, 0, 1], [0, 2, 2]],
                'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
                'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
                'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
                'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
                         [[7, 9], [float('nan'), float('nan')], [11, 12]]],
                'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
                        [[7, 9], [float('inf'), float('inf')], [10, 11]]],
                'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
                        [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
            },
            {
                'src': [[1, 3], [2, 4]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[4], [6]],
                'prod': [[3], [8]],
                'mean': [[2], [3]],
                'min': [[1], [2]],
                'max': [[3], [4]],
            },
            {
                'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[[4, 4]], [[6, 6]]],
                'prod': [[[3, 3]], [[8, 8]]],
                'mean': [[[2, 2]], [[3, 3]]],
                'min': [[[1, 1]], [[2, 2]]],
                'max': [[[3, 3]], [[4, 4]]],
            },
        ]
        for test in tests:
            data = torch.tensor(test['src'], dtype=val_dtype, device=device, requires_grad=True)
            indptr = torch.tensor(test['indptr'], dtype=length_dtype, device=device)
            dim = indptr.ndim - 1
            # calculate lengths from indptr
            lengths = torch.diff(indptr, dim=dim)
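            # e.g. indptr [0, 2, 5, 5, 6] -> lengths [2, 3, 0, 1]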
            expected = torch.tensor(test[reduce], dtype=val_dtype, device=device)

            actual_result = torch.segment_reduce(
                data=data,
                reduce=reduce,
                lengths=lengths,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            # test offsets
            actual_result = torch.segment_reduce(
                data=data,
                reduce=reduce,
                offsets=indptr,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            if val_dtype == torch.float64:
                def fn(x, mode='lengths'):
                    initial = 1
                    # supply initial values to prevent gradcheck from failing for 0-length
                    # segments, where nan/inf reduction identities produce nans in the
                    # numerical jacobian
                    if reduce == 'min':
                        initial = 1000
                    elif reduce == 'max':
                        initial = -1000
                    segment_reduce_args = (x, reduce)
                    segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
                    if mode == 'lengths':
                        segment_reduce_kwargs[mode] = lengths
                    elif mode == 'offsets':
                        segment_reduce_kwargs[mode] = indptr
                    return torch.segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)
                self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True),)))
                self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True),)))

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d(self, device, dtypes):
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [0, 2, 3, 0]
        data = np.arange(50).reshape(5, 2, 5).tolist()
        expected_grad = []

        # TODO: calculate grad and check correctness
        check_backward = False

        for reduction in reductions:
            initial_value = 0
            if reduction == "max":
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.max(data[:2], axis=0).tolist(),
                    np.max(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "mean":
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.mean(data[:2], axis=0).tolist(),
                    np.mean(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "min":
                initial_value = 1000  # some high number
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.min(data[:2], axis=0).tolist(),
                    np.min(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "sum":
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.sum(data[:2], axis=0).tolist(),
                    np.sum(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "prod":
                initial_value = 1
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.prod(data[:2], axis=0).tolist(),
                    np.prod(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            for unsafe in [True, False]:
                self._test_common(
                    reduction,
                    device,
                    val_dtype,
                    unsafe,
                    axis,
                    initial_value,
                    data,
                    lengths,
                    expected_result,
                    expected_grad,
                    check_backward,
                )

    @dtypes(torch.int, torch.int64)
    def test_unsafe_flag(self, device, dtype):
        length_type = dtype
        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
        data = torch.arange(6, dtype=torch.float, device=device)
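        # lengths sum to 5 but data has 6 elements along axis 0, so the
        # unsafe=False validation below is expected to raise.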
Mikayla Gawarecki | e289a18 | 2022-06-09 04:03:06 +0000 | [diff] [blame] | 502 | |
| 503 | # test for error on 1-D lenghts |
| 504 | with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"): |
| 505 | torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False) |
| 506 | |
| 507 | # test for error on multi-D lengths |
Mikayla Gawarecki | e727539 | 2022-06-10 22:33:06 +0000 | [diff] [blame] | 508 | nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device) |
| 509 | nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6) |
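        # the second row of nd_lengths sums to 5, not 6 (= nd_data.size(1)), so
        # the unsafe=False check should again raise.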
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch.segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)


instantiate_device_type_tests(TestSegmentReductions, globals())

if __name__ == "__main__":
    run_tests()