[primTorch] Add refs for `softmax`, `softmin`, `log_softmax` (#84956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84956
Approved by: https://github.com/lezcano, https://github.com/mruberry
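This PR exports and cleans up the existing `_refs.softmax` and `_refs.log_softmax` references (adding them to `__all__`, making `dtype` an ordinary argument instead of keyword-only, and deriving the computation dtype from the result dtype), and adds thin forwarding aliases under `_refs.nn.functional` (`softmax`, `softmin`, `log_softmax`) and `_refs.special` (`softmax`, `log_softmax`). Test coverage is wired up through new `PythonRefInfo` entries bound to the ops' `with_dtype` OpInfo variants (the `log_softmax` variant is renamed from `dtype` to `with_dtype` here).

For orientation, the decompositions the refs implement reduce to the usual numerically stable formulations. The sketch below is illustrative only: it uses the public torch API rather than the `_refs`/`prims` machinery, the `*_sketch` helper names are made up, and the refs' computation-dtype handling is simplified to a plain cast.

    import torch
    from typing import Optional

    def log_softmax_sketch(a, dim, dtype: Optional[torch.dtype] = None):
        a_ = a.to(dtype or a.dtype)
        # log_softmax(x) = x - logsumexp(x); logsumexp is itself max-shifted for stability
        return a_ - torch.logsumexp(a_, dim, keepdim=True)

    def softmax_sketch(a, dim, dtype: Optional[torch.dtype] = None):
        a_ = a.to(dtype or a.dtype)
        a_exp = torch.exp(a_ - a_.amax(dim, keepdim=True))  # subtract the max for stability
        return a_exp / a_exp.sum(dim, keepdim=True)

    def softmin_sketch(a, dim, dtype: Optional[torch.dtype] = None):
        # softmin(x) == softmax(-x), the same identity the new softmin alias uses
        return softmax_sketch(-a, dim, dtype)

    x = torch.randn(2, 3)
    assert torch.allclose(softmax_sketch(x, dim=-1), torch.softmax(x, dim=-1))
    assert torch.allclose(log_softmax_sketch(x, dim=-1), torch.log_softmax(x, dim=-1))
    assert torch.allclose(softmin_sketch(x, dim=-1), torch.nn.functional.softmin(x, dim=-1))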
diff --git a/test/test_ops.py b/test/test_ops.py
index 3317e18..c63de0a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1714,14 +1714,21 @@
'_refs.isclose',
'_refs.isfinite',
'_refs.isreal',
+ '_refs.log_softmax',
'_refs.movedim',
'_refs.narrow',
'_refs.nn.functional.l1_loss',
+ '_refs.nn.functional.log_softmax',
'_refs.nn.functional.poisson_nll_loss',
+ '_refs.nn.functional.softmax',
+ '_refs.nn.functional.softmin',
'_refs.positive',
'_refs.ravel',
'_refs.reshape',
+ '_refs.softmax',
'_refs.special.expit',
+ '_refs.special.log_softmax',
+ '_refs.special.softmax',
'_refs.square',
'_refs.T',
'_refs.tensor_split',
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 9d86f91..a37673a 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -86,6 +86,7 @@
"log1p",
"log2",
"log10",
+ "log_softmax",
"nan_to_num",
"neg",
"positive",
@@ -98,6 +99,7 @@
"sin",
"sinc",
"sinh",
+ "softmax",
"sqrt",
"square",
"tan",
@@ -651,15 +653,15 @@
return prims.log10(a)
+# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def log_softmax(
a: TensorLikeType,
dim: int,
- *,
dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
result_dtype = dtype or a.dtype
- computation_dtype = utils.get_computation_dtype(a.dtype)
+ computation_dtype = utils.get_computation_dtype(result_dtype)
a_ = _maybe_convert_to_dtype(a, computation_dtype)
return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype) # type: ignore[return-value]
@@ -3117,17 +3119,16 @@
return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)
+# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def softmax(
a: TensorLikeType,
dim: int,
- *,
dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
result_dtype = dtype or a.dtype
- computation_dtype = utils.get_computation_dtype(a.dtype)
+ computation_dtype = utils.get_computation_dtype(result_dtype)
a_ = _maybe_convert_to_dtype(a, computation_dtype)
- assert isinstance(a_, TensorLike) # to avoid MyPy error for amax
a_max = amax(a_, dim, keepdim=True)
a_exp = exp(a_ - a_max)
return _maybe_convert_to_dtype(
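Note on the `torch/_refs/__init__.py` hunks above: besides dropping the keyword-only marker, `computation_dtype` is now derived from `result_dtype` (i.e. from an explicit `dtype=` argument when one is given) rather than from the input's dtype. This lines up with the documented eager behavior, where the input is cast to `dtype` before the operation runs, so a half-precision input with `dtype=torch.float64` is computed in float64 instead of float32. A small eager-mode illustration (public torch API only):

    import torch

    x = torch.randn(8, dtype=torch.float16)
    out = torch.softmax(x, dim=0, dtype=torch.float64)      # cast to float64, then softmax
    manual = torch.softmax(x.to(torch.float64), dim=0)
    assert out.dtype == torch.float64
    assert torch.allclose(out, manual)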
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index 58d3c82..3cde678 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -35,6 +35,7 @@
"hinge_embedding_loss",
"huber_loss",
"l1_loss",
+ "log_softmax",
"margin_ranking_loss",
"mish",
"nll_loss",
@@ -44,6 +45,8 @@
"relu",
"relu6",
"selu",
+ "softmax",
+ "softmin",
"softplus",
"softshrink",
"tanhshrink",
@@ -238,6 +241,37 @@
return scale * torch.where(a > 0, a, rhs)
+# Forwarding alias: the functional variant doesn't support the out kwarg
+# CompositeImplicitAutograd - don't register decomp
+def softmax(
+ a: TensorLikeType,
+ dim: Optional[int] = None,
+ _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True)
+ dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+ # The error is for compat with regular PyTorch, where this behavior is
+ # deprecated. It's fine for PrimTorch to drop the deprecated behavior because
+ # using the refs requires explicit opt-in. The error tells users how to
+ # update their calls.
+ check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+ return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
+
+
+# CompositeImplicitAutograd - don't register decomp
+def softmin(
+ a: TensorLikeType,
+ dim: Optional[int] = None,
+ _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True)
+ dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+ # The error is for compat with regular PyTorch, where this behavior is
+ # deprecated. It's fine for PrimTorch to drop the deprecated behavior because
+ # using the refs requires explicit opt-in. The error tells users how to
+ # update their calls.
+ check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+ return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload]
+
+
# softplus is implemented specially because it has beta and threshold arguments
@register_decomposition(torch.ops.aten.softplus)
@out_wrapper()
@@ -374,6 +408,22 @@
return _apply_loss_reduction(loss, reduction)
+# Forwarding alias: the functional variant doesn't support the out kwarg
+# CompositeImplicitAutograd - don't register decomp
+def log_softmax(
+ a: TensorLikeType,
+ dim: Optional[int] = None,
+ _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True)
+ dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+ # The error is for compat with regular PyTorch, where this behavior is
+ # deprecated. It's fine for PrimTorch to drop the deprecated behavior because
+ # using the refs requires explicit opt-in. The error tells users how to
+ # update their calls.
+ check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+ return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
+
+
@register_decomposition(torch.ops.aten.margin_ranking_loss)
def margin_ranking_loss(
input1: TensorLikeType,
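Note on the `_refs.nn.functional` additions above: `softmax`, `softmin`, and `log_softmax` are thin forwarding wrappers (`softmin` just negates its input and forwards to `softmax`), the unused `_stacklevel` parameter exists only so the signature matches under `TorchRefsMode(strict=True)`, and an implicit `dim` is rejected outright rather than triggering the deprecation warning eager PyTorch emits. For comparison, eager behavior (illustrative, public `torch.nn.functional` API):

    import warnings
    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3)

    # softmin is softmax of the negated input -- the identity the new ref relies on.
    assert torch.allclose(F.softmin(x, dim=-1), F.softmax(-x, dim=-1))

    # Eager PyTorch only warns on an implicit dim; the refs raise
    # "implicit dim not supported, use dim=X" instead.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        F.softmax(x)  # deprecated implicit-dim call
    assert len(caught) >= 1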
diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py
index c4aa33f..fae9f9d 100644
--- a/torch/_refs/special/__init__.py
+++ b/torch/_refs/special/__init__.py
@@ -27,9 +27,11 @@
"i1e",
"log_ndtr",
"logit",
+ "log_softmax",
"multigammaln",
"ndtr",
"ndtri",
+ "softmax",
"spherical_bessel_j0",
"zeta",
]
@@ -167,6 +169,26 @@
return prims.ndtri(a)
+# Forwarding alias: the special variant doesn't support the out kwarg
+# CompositeImplicitAutograd - don't register decomp
+def log_softmax(
+ a: TensorLikeType,
+ dim: int,
+ dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+ return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
+
+
+# Forwarding alias: the special variant doesn't support the out kwarg
+# CompositeImplicitAutograd - don't register decomp
+def softmax(
+ a: TensorLikeType,
+ dim: int,
+ dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+ return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
+
+
@_make_elementwise_unary_reference(
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
aten_op=torch.ops.aten.special_spherical_bessel_j0,
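Note on the `_refs.special` additions above: `special.softmax` and `special.log_softmax` are pure aliases of the base references, with no `out=` support, mirroring the `torch.special` entry points. In eager mode the aliased ops agree with the base ops (illustrative check, public torch API):

    import torch

    x = torch.randn(2, 3)
    assert torch.allclose(torch.special.softmax(x, dim=-1), torch.softmax(x, dim=-1))
    assert torch.allclose(torch.special.log_softmax(x, dim=-1), torch.log_softmax(x, dim=-1))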
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 0623a0b..550ef33 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15140,7 +15140,7 @@
assert_autodiffed=True),
OpInfo(
'log_softmax',
- variant_test_name='dtype',
+ variant_test_name='with_dtype',
aliases=('special.log_softmax', 'nn.functional.log_softmax'),
supports_out=True,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
@@ -16697,6 +16697,7 @@
PythonRefInfo(
"_refs.log_softmax",
torch_opinfo_name="log_softmax",
+ torch_opinfo_variant_name="with_dtype",
),
ElementwiseUnaryPythonRefInfo(
"_refs.nan_to_num",
@@ -16771,6 +16772,7 @@
PythonRefInfo(
"_refs.softmax",
torch_opinfo_name="softmax",
+ torch_opinfo_variant_name="with_dtype",
),
ElementwiseUnaryPythonRefInfo(
"_refs.sqrt",
@@ -16799,6 +16801,18 @@
# https://github.com/pytorch/pytorch/issues/85258
supports_nvfuser=False,
),
+ PythonRefInfo(
+ "_refs.special.log_softmax",
+ torch_opinfo_name="log_softmax", # alias
+ torch_opinfo_variant_name="with_dtype",
+ supports_out=False,
+ ),
+ PythonRefInfo(
+ "_refs.special.softmax",
+ torch_opinfo_name="softmax", # alias
+ torch_opinfo_variant_name="with_dtype",
+ supports_out=False,
+ ),
#
# Elementwise Unary Special OpInfos
#
@@ -16897,6 +16911,12 @@
torch_opinfo_name="nn.functional.leaky_relu",
),
PythonRefInfo(
+ "_refs.nn.functional.log_softmax",
+ torch_opinfo_name="log_softmax", # alias
+ torch_opinfo_variant_name="with_dtype",
+ supports_out=False,
+ ),
+ PythonRefInfo(
"_refs.nn.functional.poisson_nll_loss",
torch_opinfo_name="nn.functional.poisson_nll_loss",
),
@@ -16921,6 +16941,18 @@
"_refs.nn.functional.selu",
torch_opinfo_name="nn.functional.selu",
),
+ PythonRefInfo(
+ "_refs.nn.functional.softmax",
+ torch_opinfo_name="softmax", # alias
+ torch_opinfo_variant_name="with_dtype",
+ supports_out=False,
+ ),
+ PythonRefInfo(
+ "_refs.nn.functional.softmin",
+ torch_opinfo_name="nn.functional.softmin",
+ torch_opinfo_variant_name="with_dtype",
+ supports_out=False,
+ ),
ElementwiseUnaryPythonRefInfo(
"_refs.nn.functional.softplus",
torch_opinfo_name="nn.functional.softplus",