OpInfo: `fmod` and `remainder` (#57941)
Summary:
See https://github.com/pytorch/pytorch/issues/54261
cc: mruberry Lezcano kshitij12345
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57941
Reviewed By: mrshenli
Differential Revision: D28744464
Pulled By: mruberry
fbshipit-source-id: 19847277d4f8d3a39a706c2b3c9eddf0dedcb20c
diff --git a/aten/src/THC/generic/THCTensorMathPairwise.cu b/aten/src/THC/generic/THCTensorMathPairwise.cu
index 262a4f0..aba731c 100644
--- a/aten/src/THC/generic/THCTensorMathPairwise.cu
+++ b/aten/src/THC/generic/THCTensorMathPairwise.cu
@@ -24,24 +24,6 @@
THCudaCheck(cudaGetLastError());
}
-void THCTensor_(fmod)(THCState *state, THCTensor *self_, THCTensor *src_, scalar_t value)
-{
- THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
- if (self_ == src_) {
- if (!THC_pointwiseApply1<scalar_t>(state, self_, TensorFmodOp<scalar_t>(value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- THCTensor_(resizeAs)(state, self_, src_);
-
- if (!THC_pointwiseApply2<scalar_t, scalar_t>(state, self_, src_, TensorFmodOp<scalar_t>(value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
-
- THCudaCheck(cudaGetLastError());
-}
-
#endif
#endif
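For context, this deletion is dead-code cleanup: `torch.fmod` on CUDA is served by the modern ATen/TensorIterator kernels, so the legacy THC pairwise path above is unreachable. A minimal sanity check (illustrative only, not part of this patch; assumes a CUDA device is available):

```python
import torch

# Illustrative check: fmod on GPU goes through the ATen kernel; the THC
# pairwise path removed above is no longer reachable.
if torch.cuda.is_available():
    a = torch.tensor([-3.0, -2.0, 2.0, 3.0], device="cuda")
    print(torch.fmod(a, 2))  # -> tensor([-1., -0., 0., 1.], device='cuda:0')
```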
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 26638db..0c08012 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -1853,6 +1853,8 @@
'expand',
'expm1',
'floor',
+ 'fmod',
+ 'fmod.autodiffed',
'ge',
'gt',
'le',
@@ -1878,6 +1880,8 @@
'nn.functional.relu6',
'pow',
'reciprocal',
+ 'remainder',
+ 'remainder.autodiffed',
'round',
'rsqrt',
'sigmoid',
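The new entries follow the OpInfo key convention used by this allowlist: the base name, plus `.<variant_test_name>` when a variant is registered. A sketch of that convention (the helper below is illustrative, not from the patch):

```python
# Illustrative: how the allowlist keys above are derived from an OpInfo's
# name and optional variant_test_name.
def opinfo_key(name: str, variant_test_name: str = "") -> str:
    return f"{name}.{variant_test_name}" if variant_test_name else name

assert opinfo_key("fmod") == "fmod"
assert opinfo_key("remainder", "autodiffed") == "remainder.autodiffed"
```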
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 0d9100f..b3e243d 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -3274,10 +3274,10 @@
r"""
fmod(input, other, *, out=None) -> Tensor
-Computes the element-wise remainder of division.
-
-The dividend and divisor may contain both for integer and floating point
-numbers. The remainder has the same sign as the dividend :attr:`input`.
+Applies C++'s `std::fmod <https://en.cppreference.com/w/cpp/numeric/math/fmod>`_
+for floating point tensors, and the modulus operation for integer tensors. The result
+has the same sign as the dividend :attr:`input` and its absolute value
+is less than that of :attr:`other`.
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
:ref:`type promotion <type-promotion-doc>`, and integer and float inputs.
@@ -3288,6 +3288,11 @@
on both CPU and GPU; raises ``RuntimeError`` for integer division by
zero on CPU; Integer division by zero on GPU may return any value.
+.. note::
+
+ Complex inputs are not supported. In some cases, it is not mathematically
+ possible to satisfy the definition of a modulo operation with complex numbers.
+
Args:
input (Tensor): the dividend
other (Tensor or Scalar): the divisor
@@ -3299,9 +3304,14 @@
>>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2)
tensor([-1., -0., -1., 1., 0., 1.])
- >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5)
+ >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5)
tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000])
+.. seealso::
+
+    :func:`torch.remainder`, which is similar to :func:`torch.fmod` except that if the sign
+    of the modulus differs from the sign of the divisor :attr:`other`, then the divisor
+    is added to the modulus.
""".format(**common_args))
add_docstr(torch.frac,
@@ -7639,10 +7649,10 @@
r"""
remainder(input, other, *, out=None) -> Tensor
-Computes the element-wise remainder of division.
-
-The dividend and divisor may contain both for integer and floating point
-numbers. The remainder has the same sign as the divisor :attr:`other`.
+Like :func:`torch.fmod`, this applies C++'s `std::fmod <https://en.cppreference.com/w/cpp/numeric/math/fmod>`_
+for floating point tensors and the modulus operation for integer tensors.
+Unlike :func:`torch.fmod`, however, if the sign of the modulus differs
+from the sign of the divisor :attr:`other`, then the divisor is added to the modulus.
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
:ref:`type promotion <type-promotion-doc>`, and integer and float inputs.
@@ -7663,13 +7673,14 @@
>>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2)
tensor([ 1., 0., 1., 1., 0., 1.])
- >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5)
- tensor([ 1.0000, 0.5000, 0.0000, 1.0000, 0.5000])
+ >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5)
+    tensor([-0.5000, -1.0000,  0.0000, -0.5000, -1.0000])
.. seealso::
- :func:`torch.fmod`, which computes the element-wise remainder of
- division equivalently to the C library function ``fmod()``.
+    :func:`torch.fmod`, which computes the modulus for integer inputs and
+    applies C++'s `std::fmod <https://en.cppreference.com/w/cpp/numeric/math/fmod>`_
+    for floating point inputs.
""".format(**common_args))
add_docstr(torch.renorm,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ef611de..3637dae 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2807,6 +2807,46 @@
)
return [SampleInput(tensor) for tensor in tensors]
+def sample_inputs_fmod_remainder(op_info, device, dtype, requires_grad, *, autodiffed=False, **kwargs):
+ make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+ if autodiffed:
+ samples = ( # type: ignore[assignment]
+ ((S, S, S), 1.5, False),
+ ((), 1.5, False),
+ )
+ else:
+ cases = ( # type: ignore[assignment]
+ ((S, S, S), (), False),
+ ((S, S, S), (S, S, S), False),
+ ((S, S, S), (S,), False),
+ )
+
+ # Sample inputs with scalars as torch tensors
+ cases_with_tensor_scalar = ( # type: ignore[assignment]
+ ((), torch.tensor(1, dtype=dtype, device=device, requires_grad=False), False),
+ )
+
+ # Sample inputs with broadcasting
+ cases_with_broadcasting = ( # type: ignore[assignment]
+ ((S,), (S, S, S), True),
+ ((S, 1, S), (S, S, S), True),
+ ((), (S, S, S), True),
+ )
+
+ samples = cases + cases_with_tensor_scalar + cases_with_broadcasting # type: ignore[assignment]
+
+ def generator():
+ for shape, arg_other, broadcasts_input in samples:
+ if isinstance(arg_other, tuple):
+ arg = make_arg(arg_other, requires_grad=False, exclude_zero=True)
+ else:
+                # arg_other is a Python scalar or a torch.Tensor divisor
+                arg = arg_other
+            yield SampleInput(make_arg(shape), args=(arg,), broadcasts_input=broadcasts_input)
+
+ return list(generator())
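A hedged sketch of how a generic test might consume these samples (the harness below is illustrative, not part of this patch); every sample should satisfy the fmod invariant that the result is strictly smaller in magnitude than the divisor:

```python
import torch

# Illustrative consumer: iterate the generated SampleInputs and check that
# each fmod result is smaller in magnitude than its (nonzero) divisor.
def check_fmod_samples(op_info, device, dtype):
    for sample in op_info.sample_inputs(device, dtype, requires_grad=False):
        other = sample.args[0]
        result = torch.fmod(sample.input, other)
        other = other if isinstance(other, torch.Tensor) else torch.as_tensor(other)
        assert bool((result.abs() < other.abs()).all())
```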
+
# TODO: clamp shares tensors among its sample inputs --- we should prohibit this!
def sample_inputs_clamp(op_info, device, dtype, requires_grad, **kwargs):
x = make_tensor((S, M, S), device, dtype, low=None, high=None, requires_grad=requires_grad)
@@ -4607,6 +4647,24 @@
op=torch.fmin,
dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
sample_inputs_func=sample_inputs_max_min_binary,),
+ OpInfo('fmod',
+ dtypes=all_types_and(torch.float16),
+ sample_inputs_func=sample_inputs_fmod_remainder),
+ OpInfo('fmod',
+ variant_test_name='autodiffed',
+ dtypes=all_types_and(torch.float16, torch.bool),
+ assert_autodiffed=True,
+ sample_inputs_func=partial(sample_inputs_fmod_remainder, autodiffed=True)),
+ OpInfo('remainder',
+ dtypesIfCPU=all_types_and(torch.float16),
+ dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+ sample_inputs_func=sample_inputs_fmod_remainder),
+ OpInfo('remainder',
+ variant_test_name='autodiffed',
+ dtypesIfCPU=all_types_and(torch.float16, torch.bool),
+ dtypesIfCUDA=all_types_and(torch.float16, torch.bool, torch.bfloat16),
+ assert_autodiffed=True,
+ sample_inputs_func=partial(sample_inputs_fmod_remainder, autodiffed=True)),
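Once registered in the op database, these entries are exercised automatically by tests decorated with `@ops`, which is why the ad-hoc `method_tests` tuples deleted below become redundant. A rough sketch of the pattern (illustrative; the real harness derives from `TestCase` and is instantiated via `instantiate_device_type_tests`):

```python
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_methods_invocations import op_db

# Illustrative: @ops instantiates the test once per matching OpInfo,
# device, and dtype; an OpInfo is callable as its function variant.
fmod_ops = [op for op in op_db if op.name in ("fmod", "remainder")]

class TestFmodRemainderSamples:
    @ops(fmod_ops)
    def test_samples_run(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype):
            op(sample.input, *sample.args, **sample.kwargs)
```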
UnaryUfuncInfo('frac',
ref=lambda x: np.modf(x)[0],
dtypes=floating_types_and(torch.bfloat16, torch.float16),
@@ -6656,22 +6714,6 @@
('div', torch.rand(S, S, S, dtype=torch.cdouble) + 1e-1, (3.14j,), 'complex_constant', (True,)),
('div', uniform_scalar(1e-1j, requires_grad=True), (3.14j,), 'complex_scalar_constant', (True,)),
('t', (1, 2), NO_ARGS, '', (False,)),
- ('fmod', (S, S, S), (1.5,), '', (True,)),
- ('fmod', (), (1.5,), 'scalar', (True,)),
- ('fmod', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'),
- ('fmod', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor_broadcast_lhs'),
- ('fmod', (S, S, S), (non_differentiable(torch.rand(S) + 1.5),), 'tensor_broadcast_rhs'),
- ('fmod', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),), 'tensor_broadcast_all'),
- ('fmod', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'),
- ('fmod', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'),
- ('fmod', (S, S, S), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor_broadcast_rhs'),
- ('remainder', (S, S, S), (1.5,), '', (True,)),
- ('remainder', (), (1.5,), 'scalar', (True,)),
- ('remainder', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'),
- ('remainder', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor_broadcast_lhs'),
- ('remainder', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),), 'tensor_broadcast_all'),
- ('remainder', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'),
- ('remainder', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'),
('median', (S, S, S), NO_ARGS),
('median', (S, S, S), (1,), 'dim', (), [0]),
('median', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index db827f1..d4b4a9f 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1726,7 +1726,8 @@
# Methods for matrix and tensor generation
def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, high=None,
- requires_grad: bool = False, noncontiguous: bool = False) -> torch.Tensor:
+ requires_grad: bool = False, noncontiguous: bool = False,
+ exclude_zero: bool = False) -> torch.Tensor:
""" Creates a random tensor with the given size, device and dtype.
By default, the tensor's values are in the range [-9, 9] for most dtypes. If low
@@ -1738,6 +1739,10 @@
If noncontiguous=True, a noncontiguous tensor with the given size will be returned unless the size
specifies a tensor with a 1 or 0 elements in which case the noncontiguous parameter is ignored because
it is not possible to create a noncontiguous Tensor with a single element.
+
+    If exclude_zero is True (default is False), all zero values in the created tensor
+    are replaced: with ``torch.finfo(dtype).eps`` for floating point types, with
+    ``eps + eps * j`` for complex types, and with 1 for integer/boolean types.
"""
assert low is None or low < 9, "low value too high!"
@@ -1776,6 +1781,19 @@
result = torch.repeat_interleave(result, 2, dim=-1)
result = result[..., ::2]
+ if exclude_zero:
+ if dtype in integral_types() or dtype is torch.bool:
+ replace_with = torch.tensor(1, device=device, dtype=dtype)
+ elif dtype in floating_types_and(torch.half, torch.bfloat16):
+ replace_with = torch.tensor(torch.finfo(dtype).eps, device=device, dtype=dtype)
+ else:
+ assert dtype in complex_types()
+ float_dtype = torch.float if dtype is torch.cfloat else torch.double
+            real = torch.tensor(torch.finfo(float_dtype).eps, device=device, dtype=float_dtype)
+            imag = torch.tensor(torch.finfo(float_dtype).eps, device=device, dtype=float_dtype)
+ replace_with = torch.complex(real, imag)
+ result[result == 0] = replace_with
+
if dtype in floating_types_and(torch.half, torch.bfloat16) or\
dtype in complex_types():
result.requires_grad = requires_grad
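A quick usage sketch for the new `exclude_zero` flag, which `sample_inputs_fmod_remainder` relies on to keep divisors nonzero:

```python
import torch
from torch.testing._internal.common_utils import make_tensor

# Zeros are replaced per dtype: eps for floating point, eps + eps*j for
# complex, and 1 for integer/bool tensors.
t = make_tensor((4, 4), 'cpu', torch.float32, exclude_zero=True)
assert not (t == 0).any()
```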