# Owner(s): ["module: scatter & gather ops"]

from itertools import product
from functools import partial

import numpy as np
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,
)
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    gradcheck,
    parametrize,
)


reductions = ["max", "mean", "min", "sum", "prod"]


def get_default_value(initial_value, reduction):
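    """Return the fill value for zero-length segments: the explicit initial value if given, else the reduction's default."""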
    if initial_value is not None:
        return initial_value
    if reduction == "max":
        return -float("Inf")
    elif reduction == "mean":
        return float("nan")
    elif reduction == "min":
        return float("Inf")
    elif reduction == "sum":
        return 0.0
    elif reduction == "prod":
        return 1.0


class TestSegmentReductions(TestCase):
    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        # generate offsets from lengths
        zeros_shape = list(lengths.shape)
        zeros_shape[-1] = 1
        offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)
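        # e.g. lengths == [1, 2, 3, 0] yields offsets == [0, 1, 3, 6, 6]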

        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
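        # every case is exercised twice: once via segment lengths, once via the equivalent offsets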
        for mode in ['lengths', 'offsets']:
            segment_reduce_kwargs = dict(
                axis=axis,
                unsafe=unsafe,
                initial=initial_value)
            if mode == 'lengths':
                segment_reduce_kwargs['lengths'] = lengths
            else:
                segment_reduce_kwargs['offsets'] = offsets
            actual_result = torch.segment_reduce(
                data=data,
                reduce=reduction,
                **segment_reduce_kwargs
            )
            self.assertEqual(
                expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
            )

            if not check_backward:
                return

            # Test backward
            actual_result.sum().backward()
            self.assertEqual(
                expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
            )
            data = data.clone().detach().requires_grad_(True)

            # gradcheck does not work well with bfloat16 or fp16 cpu types;
            # there is also a small numerical difference with fp32
            if dtype not in [torch.half, torch.bfloat16, torch.float]:
                # gradcheck does not like "nan" input, so replace nans with 10
                d_non_nan = np.nan_to_num(data_arr, nan=10)
                new_data = torch.tensor(
                    d_non_nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                self.assertTrue(
                    gradcheck(
                        lambda x: torch.segment_reduce(
                            data=x,
                            reduce=reduction,
                            **segment_reduce_kwargs
                        ),
                        (new_data,),
                    )
                )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_1d(self, device, dtypes):
        val_dtype, length_type = dtypes
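        # lengths [1, 2, 3, 0] split data into segments [1], [nan, 3], [4, 5, 5] and one empty segment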
        lengths = [1, 2, 3, 0]
        data = [1, float("nan"), 3, 4, 5, 5]

        for reduction in reductions:
            for initial in [0, None]:
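                # backward is only verified when an explicit initial value is supplied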
                check_backward = True if initial is not None else False
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [1, float("nan"), 5, default_value]
                    expected_grad = [1, 1, 0, 0, 0.5, 0.5]
                elif reduction == "mean":
                    expected_result = [1, float("nan"), 4.666, default_value]
                    expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [1, float("nan"), 4, default_value]
                    expected_grad = [1.0, 1.0, 0, 1, 0, 0]
                elif reduction == "sum":
                    expected_result = [1, float("nan"), 14, default_value]
                    expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [2, float("nan"), 200, default_value]
                        expected_grad = [2.0, 6.0, float("nan"), 50.0, 40.0, 40.0]
                    else:
                        expected_result = [1, float("nan"), 100, default_value]
                        expected_grad = [1.0, 3.0, float("nan"), 25.0, 20.0, 20.0]
                for axis in [0, -1]:
                    for unsafe in [True, False]:
                        self._test_common(
                            reduction,
                            device,
                            val_dtype,
                            unsafe,
                            axis,
                            initial_value,
                            data,
                            lengths,
                            expected_result,
                            expected_grad,
                            check_backward,
                            length_type,
                        )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d_simple(self, device, dtypes):
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [1, 2, 3, 0]
        data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = True if initial is not None else False
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [4, 3],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1, 1],
                        [1, 0],
                        [0, 1],
                        [1, 0],
                        [0, 0],
                        [0, 1],
                    ]
                elif reduction == "mean":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [3, 2],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [0.5, 0.5],
                        [0.5, 0.5],
                        [0.333, 0.333],
                        [0.333, 0.333],
                        [0.333, 0.333],
                    ]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [2, 1],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1, 0],
                        [0, 1],
                        [0, 1],
                        [0, 0],
                        [1, 0],
                    ]
                elif reduction == "sum":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [9, 6],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                    ]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [
                            [2, 2],
                            [float("nan"), float("nan")],
                            [48, 12],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [2.0, 2.0],
                            [6.0, float("nan")],
                            [float("nan"), 2.0],
                            [12.0, 12.0],
                            [16.0, 6.0],
                            [24.0, 4.0],
                        ]
                    else:
                        expected_result = [
                            [1, 1],
                            [float("nan"), float("nan")],
                            [24, 6],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [1.0, 1.0],
                            [3.0, float("nan")],
                            [float("nan"), 1.0],
                            [6.0, 6.0],
                            [8.0, 3.0],
                            [12.0, 2.0],
                        ]
                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                    )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
    def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
        val_dtype, length_dtype = dtypes
        # unlike pytorch_scatter, zero-length segments are filled with the reduction's initial (identity) values.
        tests = [
            {
                'src': [1, 2, 3, 4, 5, 6],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [3, 12, 0, 6],
                'prod': [2, 60, 1, 6],
                'mean': [1.5, 4, float('nan'), 6],
                'min': [1, 3, float('inf'), 6],
                'max': [2, 5, -float('inf'), 6],
            },
            {
                'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
                'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
                'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
                'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
                'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
            },
            {
                'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
                'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
                'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
                'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
                'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
                'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
                'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
                'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
            },
            {
                'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
                'index': [[0, 0, 1], [0, 2, 2]],
                'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
                'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
                'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
                'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
                         [[7, 9], [float('nan'), float('nan')], [11, 12]]],
                'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
                        [[7, 9], [float('inf'), float('inf')], [10, 11]]],
                'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
                        [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
            },
            {
                'src': [[1, 3], [2, 4]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[4], [6]],
                'prod': [[3], [8]],
                'mean': [[2], [3]],
                'min': [[1], [2]],
                'max': [[3], [4]],
            },
            {
                'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[[4, 4]], [[6, 6]]],
                'prod': [[[3, 3]], [[8, 8]]],
                'mean': [[[2, 2]], [[3, 3]]],
                'min': [[[1, 1]], [[2, 2]]],
                'max': [[[3, 3]], [[4, 4]]],
            },
        ]
        for test in tests:
            data = torch.tensor(test['src'], dtype=val_dtype, device=device, requires_grad=True)
            indptr = torch.tensor(test['indptr'], dtype=length_dtype, device=device)
            dim = indptr.ndim - 1
            # calculate lengths from indptr
            lengths = torch.diff(indptr, dim=dim)
            expected = torch.tensor(test[reduce], dtype=val_dtype, device=device)

            actual_result = torch.segment_reduce(
                data=data,
                reduce=reduce,
                lengths=lengths,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            # test offsets
            actual_result = torch.segment_reduce(
                data=data,
                reduce=reduce,
                offsets=indptr,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            if val_dtype == torch.float64:
                def fn(x, mode='lengths'):
                    initial = 1
                    # supply initial values to prevent gradcheck from failing for 0-length segments,
                    # where nan/inf reduction identities produce nans in the numerical jacobian
                    if reduce == 'min':
                        initial = 1000
                    elif reduce == 'max':
                        initial = -1000
                    # positional args must be an ordered tuple, not a set, so data stays first
                    segment_reduce_args = (x, reduce)
                    segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
                    if mode == 'lengths':
                        segment_reduce_kwargs[mode] = lengths
                    elif mode == 'offsets':
                        segment_reduce_kwargs[mode] = indptr
                    return torch.segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)
                self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True),)))
                self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True),)))

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d(self, device, dtypes):
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [0, 2, 3, 0]
        data = np.arange(50).reshape(5, 2, 5).tolist()
        expected_grad = []

        # TODO: calculate grad and check correctness
        check_backward = False

        for reduction in reductions:
            initial_value = 0
            if reduction == "max":
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.max(data[:2], axis=0).tolist(),
                    np.max(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "mean":
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.mean(data[:2], axis=0).tolist(),
                    np.mean(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "min":
                initial_value = 1000  # some high number
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.min(data[:2], axis=0).tolist(),
                    np.min(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "sum":
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.sum(data[:2], axis=0).tolist(),
                    np.sum(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "prod":
                initial_value = 1
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.prod(data[:2], axis=0).tolist(),
                    np.prod(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            for unsafe in [True, False]:
                self._test_common(
                    reduction,
                    device,
                    val_dtype,
                    unsafe,
                    axis,
                    initial_value,
                    data,
                    lengths,
                    expected_result,
                    expected_grad,
                    check_backward,
                )

    @dtypes(torch.int, torch.int64)
    def test_unsafe_flag(self, device, dtype):
        length_type = dtype
        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
        data = torch.arange(6, dtype=torch.float, device=device)
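        # lengths sum to 5 but data has 6 elements along axis 0, so unsafe=False should raise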

        # test for error on 1-D lengths
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)

        # test for error on multi-D lengths
        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
        nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
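        # the second row of nd_lengths sums to 5, not 6, so the check should fire here as well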
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch.segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)


instantiate_device_type_tests(TestSegmentReductions, globals())

if __name__ == "__main__":
    run_tests()