# blob: eeac91ef88c9f5209e160a5d1c9d6be427340ba2 [file] [log] [blame]
# Owner(s): ["module: scatter & gather ops"]
from itertools import product
import numpy as np
import torch
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
dtypes,
)
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
gradcheck,
)
# Reduction modes that the tests below iterate over.
reductions = [
    "max",
    "mean",
    "min",
    "sum",
    "prod",
]
def get_default_value(initial_value, reduction):
    """Return the value the tests expect for an empty segment.

    Args:
        initial_value: Explicit initial value handed to the reduction, or
            ``None`` when no initial value is supplied.
        reduction: Name of the reduction: ``"max"``, ``"mean"``, ``"min"``,
            ``"sum"`` or ``"prod"``.

    Returns:
        ``initial_value`` when it is not ``None`` (including ``0``);
        otherwise the per-reduction default: ``-inf`` for max, ``nan`` for
        mean, ``+inf`` for min, ``0.0`` for sum and ``1.0`` for prod.

    Raises:
        ValueError: If ``initial_value`` is ``None`` and ``reduction`` is not
            one of the supported names. Previously this case silently
            returned ``None``.
    """
    # Compare against None explicitly: 0 is a legitimate initial value.
    if initial_value is not None:
        return initial_value

    defaults = {
        "max": -float("Inf"),
        "mean": float("nan"),
        "min": float("Inf"),
        "sum": 0.0,
        "prod": 1.0,
    }
    try:
        return defaults[reduction]
    except KeyError:
        raise ValueError(f"unsupported reduction: {reduction!r}") from None
class TestSegmentReductions(TestCase):
    """Forward and backward checks for ``torch.segment_reduce``."""

    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        """Run one ``segment_reduce`` case and compare against expectations.

        Args:
            reduction: Reduction name passed as ``reduce``.
            device: Device on which all tensors are created.
            dtype: dtype of the data and expected tensors.
            unsafe: Forwarded as the ``unsafe`` argument.
            axis: Forwarded as the ``axis`` argument.
            initial_value: Forwarded as the ``initial`` argument.
            data_arr: Nested list with the input data.
            lengths_arr: List with the segment lengths.
            expected_arr: Nested list with the expected forward result.
            expected_grad_arr: Nested list with the expected gradient of
                ``result.sum()`` with respect to the data.
            check_backward: When false, only the forward result is checked.
            lengths_dtype: dtype used for the lengths tensor.
        """
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)

        def reduce_segments(x):
            # Single definition shared by the forward check and gradcheck so
            # both exercise exactly the same call.
            return torch.segment_reduce(
                data=x,
                reduce=reduction,
                lengths=lengths,
                axis=axis,
                unsafe=unsafe,
                initial=initial_value,
            )

        actual_result = reduce_segments(data)
        self.assertEqual(
            expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
        )

        if not check_backward:
            return

        # Test backward
        actual_result.sum().backward()
        self.assertEqual(
            expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
        )

        # gradcheck does not work well with bfloat16 or fp16 cpu types
        # also there is small numerical difference with fp32
        if dtype not in [torch.half, torch.bfloat16, torch.float]:
            # gradcheck does not like "nan" input, setting to random 10
            d_non_nan = np.nan_to_num(data_arr, nan=10)
            data = torch.tensor(
                d_non_nan,
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            self.assertTrue(gradcheck(reduce_segments, (data,)))

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_1d(self, device, dtypes):
        """Check every reduction on a 1-D input containing a NaN and an empty segment."""
        val_dtype, length_type = dtypes
        lengths = [1, 2, 3, 0]
        data = [1, float("nan"), 3, 4, 5, 5]

        for reduction in reductions:
            for initial in [0, None]:
                # Backward is only checked when an explicit initial value is given.
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [1, float("nan"), 5, default_value]
                    expected_grad = [1, 1, 0, 0, 0.5, 0.5]
                elif reduction == "mean":
                    expected_result = [1, float("nan"), 4.666, default_value]
                    expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [1, float("nan"), 4, default_value]
                    expected_grad = [1.0, 1.0, 0, 1, 0, 0]
                elif reduction == "sum":
                    expected_result = [1, float("nan"), 14, default_value]
                    expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [2, float("nan"), 200, default_value]
                        expected_grad = [2.0, 6.0, float("nan"), 50.0, 40.0, 40.0]
                    else:
                        expected_result = [1, float("nan"), 100, default_value]
                        expected_grad = [1.0, 3.0, float("nan"), 25.0, 20.0, 20.0]

                # For a 1-D input, axis 0 and axis -1 address the same dimension.
                for axis in [0, -1]:
                    for unsafe in [True, False]:
                        self._test_common(
                            reduction,
                            device,
                            val_dtype,
                            unsafe,
                            axis,
                            initial_value,
                            data,
                            lengths,
                            expected_result,
                            expected_grad,
                            check_backward,
                            length_type,
                        )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d_simple(self, device, dtypes):
        """Check every reduction along axis 0 of a small 2-D input with NaNs."""
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [1, 2, 3, 0]
        data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]

        for reduction in reductions:
            for initial in [0, None]:
                # Backward is only checked when an explicit initial value is given.
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [4, 3],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1, 1],
                        [1, 0],
                        [0, 1],
                        [1, 0],
                        [0, 0],
                        [0, 1],
                    ]
                elif reduction == "mean":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [3, 2],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [0.5, 0.5],
                        [0.5, 0.5],
                        [0.333, 0.333],
                        [0.333, 0.333],
                        [0.333, 0.333],
                    ]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [2, 1],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1, 0],
                        [0, 1],
                        [0, 1],
                        [0, 0],
                        [1, 0],
                    ]
                elif reduction == "sum":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [9, 6],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                    ]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [
                            [2, 2],
                            [float("nan"), float("nan")],
                            [48, 12],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [2.0, 2.0],
                            [6.0, float("nan")],
                            [float("nan"), 2.0],
                            [12.0, 12.0],
                            [16.0, 6.0],
                            [24.0, 4.0],
                        ]
                    else:
                        expected_result = [
                            [1, 1],
                            [float("nan"), float("nan")],
                            [24, 6],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [1.0, 1.0],
                            [3.0, float("nan")],
                            [float("nan"), 1.0],
                            [6.0, 6.0],
                            [8.0, 3.0],
                            [12.0, 2.0],
                        ]

                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        # Forward the parametrized lengths dtype; without it
                        # the int64 variants silently re-run with torch.int.
                        length_type,
                    )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d(self, device, dtypes):
        """Check the forward result on a (5, 2, 5) input against NumPy reductions."""
        val_dtype, length_type = dtypes
        axis = 0
        # Four segments over the five rows: empty, rows 0-1, rows 2-4, empty.
        lengths = [0, 2, 3, 0]
        data = np.arange(50).reshape(5, 2, 5).tolist()
        expected_grad = []
        # TODO: calculate grad and check correctness
        check_backward = False

        numpy_reducers = {
            "max": np.max,
            "mean": np.mean,
            "min": np.min,
            "sum": np.sum,
            "prod": np.prod,
        }
        # Reductions whose explicit initial value differs from the default 0.
        initial_overrides = {
            "min": 1000,  # some high number
            "prod": 1,
        }

        for reduction in reductions:
            initial_value = initial_overrides.get(reduction, 0)
            numpy_reduce = numpy_reducers[reduction]
            # Empty segments are expected to be filled with the initial value.
            expected_result = [
                np.full((2, 5), initial_value).tolist(),
                numpy_reduce(data[:2], axis=0).tolist(),
                numpy_reduce(data[2:], axis=0).tolist(),
                np.full((2, 5), initial_value).tolist(),
            ]
            for unsafe in [True, False]:
                self._test_common(
                    reduction,
                    device,
                    val_dtype,
                    unsafe,
                    axis,
                    initial_value,
                    data,
                    lengths,
                    expected_result,
                    expected_grad,
                    check_backward,
                    # Forward the parametrized lengths dtype; without it the
                    # int64 variants silently re-run with torch.int.
                    length_type,
                )
# Register the device-specific variants of TestSegmentReductions in this
# module's namespace.
instantiate_device_type_tests(TestSegmentReductions, globals())


# Run the tests only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()