OpInfo for aten.exponential, Add check for dtype, parameter in decomp ref (#92709)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92709
Approved by: https://github.com/lezcano
diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
index 9131c1a..14f8f1b 100644
--- a/test/distributed/_tensor/test_dtensor_ops.py
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -150,6 +150,7 @@
xfail("einsum"),
xfail("empty"),
xfail("empty_like"),
+ xfail("exponential"),
xfail("eye"),
xfail("fft.fft2"),
xfail("fft.fft"),
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 215dcfa..ad3661d 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -259,6 +259,7 @@
"to_sparse": {f32, f64},
# AssertionError: Tensor-likes are not close!
"cauchy": {f16},
+ "exponential": {f16},
"geometric": {f16},
"log_normal": {f16},
"uniform": {f16},
@@ -333,6 +334,7 @@
"to_sparse": {f16, f32, f64},
# AssertionError: Tensor-likes are not close!
"cauchy": {f16, f32, f64},
+ "exponential": {f16, f32, f64},
"geometric": {f16, f32, f64, i32, i64},
"log_normal": {f16, f32, f64},
"uniform": {f16, f32, f64},
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index ac11f1a..f5e6bd7 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -5274,6 +5274,17 @@
)
def exponential(self, rate=1, generator=None):
assert generator is None
+ utils.check(
+ not utils.is_complex_dtype(self.dtype)
+ and not utils.is_integer_dtype(self.dtype)
+ and not utils.is_boolean_dtype(self.dtype),
+ lambda: f"Exponential distribution is a continuous probability distribution. \
+ dtype must be a floating point but you specified {self.dtype}",
+ )
+ utils.check(
+ rate > 0.0,
+ lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
+ )
return -1 / rate * torch.log1p(-torch.rand_like(self))
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8460741..9fb8800 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -818,6 +818,28 @@
)
+def sample_inputs_exponential(op, device, dtype, requires_grad, **kwargs):
+
+ make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
+ samples = (
+ ((M,), 0.5),
+ ((S, S), 1),
+ ((S, S, S), 1.5),
+ )
+ for shape, rate in samples:
+ yield SampleInput(make_arg(shape), args=(rate,))
+
+
+def error_inputs_exponential(op, device, **kwargs):
+ t = torch.zeros([10], device=device)
+ invalid_rate = 0
+ yield ErrorInput(
+ SampleInput(t, args=(invalid_rate,)),
+ error_type=RuntimeError,
+ error_regex=r"exponential_ expects lambda > 0.0, but found lambda={}".format(invalid_rate),
+ )
+
+
def sample_inputs_geometric(op, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
@@ -8960,6 +8982,36 @@
DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
)),
+ OpInfo('exponential',
+ op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.exponential_, inp, *args, **kwargs),
+ inplace_variant=torch.Tensor.exponential_,
+ dtypes=floating_types_and(torch.float16, torch.bfloat16),
+ supports_out=False,
+ supports_autograd=False,
+ sample_inputs_func=sample_inputs_exponential,
+ error_inputs_func=error_inputs_exponential,
+ skips=(
+ # Tests that assume input tensor has a meaningful effect on output tensor
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+
+ # AssertionError: JIT Test does not execute any logic
+ DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+
+ # AssertionError: Tensor-likes are not close!
+ DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace'),
+ DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+
+ # FX failed to normalize op - add the op to the op_skip list.
+ DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+
+ # vmap: calling random operator not supported
+ DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+ DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+
+ DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
+ DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+ )),
OpInfo('geometric',
op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.geometric_, inp, *args, **kwargs),
inplace_variant=torch.Tensor.geometric_,
@@ -17844,6 +17896,36 @@
)
),
PythonRefInfo(
+ "_refs.exponential",
+ torch_opinfo_name="exponential",
+ supports_out=True,
+ decorators=(
+ # dtypes that do not support check_uniform_bounds of rand_like
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',
+ dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)),
+ DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
+
+ # TODO: RuntimeError: no _refs support for torch.rand_like
+ DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
+ 'TestCommon',
+ 'test_python_ref'),
+
+ # AssertionError: Tensor-likes are not close!
+ DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
+ 'TestCommon',
+ 'test_out'),
+ DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
+ 'TestCommon',
+ 'test_out_warning'),
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+ DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
+ 'TestCommon',
+ 'test_python_ref_torch_fallback'),
+ DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+ )
+ ),
+ PythonRefInfo(
"_refs.geometric",
torch_opinfo_name="geometric",
supports_out=True,