[primTorch] Add refs for `softmax`, `softmin`, `log_softmax` (#84956)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84956
Approved by: https://github.com/lezcano, https://github.com/mruberry
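
This patch adds Python reference implementations (`_refs`) for `softmax`, `softmin`, and `log_softmax`, exposes the `nn.functional` and `special` variants as forwarding aliases, switches the computation dtype to be derived from the result dtype rather than the input dtype, renames the `log_softmax`/`softmax` OpInfo variant from `dtype` to `with_dtype`, and registers the corresponding `PythonRefInfo` entries.

Below is a minimal, self-contained sketch of the decompositions the refs implement, written against the public `torch` API instead of the internal `prims`/`utils` helpers used in `torch/_refs` (the `*_sketch` names and the inline upcast rule standing in for `utils.get_computation_dtype` are illustrative assumptions, not part of the patch):

```python
from typing import Optional

import torch


def softmax_sketch(a: torch.Tensor, dim: int, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    result_dtype = dtype or a.dtype
    # Mirror utils.get_computation_dtype(result_dtype): upcast low-precision
    # floats to float32 for the intermediate computation.
    computation_dtype = (
        torch.float32 if result_dtype in (torch.float16, torch.bfloat16) else result_dtype
    )
    a_ = a.to(computation_dtype)
    a_max = torch.amax(a_, dim, keepdim=True)  # subtract the max for numerical stability
    a_exp = torch.exp(a_ - a_max)
    return (a_exp / torch.sum(a_exp, dim, keepdim=True)).to(result_dtype)


def log_softmax_sketch(a: torch.Tensor, dim: int, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    result_dtype = dtype or a.dtype
    computation_dtype = (
        torch.float32 if result_dtype in (torch.float16, torch.bfloat16) else result_dtype
    )
    a_ = a.to(computation_dtype)
    # log_softmax(a) == a - logsumexp(a), as in the _refs.log_softmax decomposition.
    return (a_ - torch.logsumexp(a_, dim, keepdim=True)).to(result_dtype)


def softmin_sketch(a: torch.Tensor, dim: int, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    # softmin(a) == softmax(-a), which is how the _refs.nn.functional.softmin alias forwards.
    return softmax_sketch(-a, dim, dtype=dtype)


x = torch.randn(2, 3)
assert torch.allclose(softmax_sketch(x, dim=-1), torch.softmax(x, dim=-1), atol=1e-6)
assert torch.allclose(log_softmax_sketch(x, dim=-1), torch.log_softmax(x, dim=-1), atol=1e-6)
```

Note that the `nn.functional` aliases keep the deprecated implicit-`dim` signature only to raise an informative error (`"implicit dim not supported, use dim=X"`); the sketch above does not reproduce that compatibility shim.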
diff --git a/test/test_ops.py b/test/test_ops.py
index 3317e18..c63de0a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1714,14 +1714,21 @@
         '_refs.isclose',
         '_refs.isfinite',
         '_refs.isreal',
+        '_refs.log_softmax',
         '_refs.movedim',
         '_refs.narrow',
         '_refs.nn.functional.l1_loss',
+        '_refs.nn.functional.log_softmax',
         '_refs.nn.functional.poisson_nll_loss',
+        '_refs.nn.functional.softmax',
+        '_refs.nn.functional.softmin',
         '_refs.positive',
         '_refs.ravel',
         '_refs.reshape',
+        '_refs.softmax',
         '_refs.special.expit',
+        '_refs.special.log_softmax',
+        '_refs.special.softmax',
         '_refs.square',
         '_refs.T',
         '_refs.tensor_split',
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 9d86f91..a37673a 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -86,6 +86,7 @@
     "log1p",
     "log2",
     "log10",
+    "log_softmax",
     "nan_to_num",
     "neg",
     "positive",
@@ -98,6 +99,7 @@
     "sin",
     "sinc",
     "sinh",
+    "softmax",
     "sqrt",
     "square",
     "tan",
@@ -651,15 +653,15 @@
     return prims.log10(a)
 
 
+# CompositeImplicitAutograd - don't register decomp
 @out_wrapper()
 def log_softmax(
     a: TensorLikeType,
     dim: int,
-    *,
     dtype: Optional[torch.dtype] = None,
 ) -> TensorLikeType:
     result_dtype = dtype or a.dtype
-    computation_dtype = utils.get_computation_dtype(a.dtype)
+    computation_dtype = utils.get_computation_dtype(result_dtype)
     a_ = _maybe_convert_to_dtype(a, computation_dtype)
     return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype)  # type: ignore[return-value]
 
@@ -3117,17 +3119,16 @@
     return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)
 
 
+# CompositeImplicitAutograd - don't register decomp
 @out_wrapper()
 def softmax(
     a: TensorLikeType,
     dim: int,
-    *,
     dtype: Optional[torch.dtype] = None,
 ) -> TensorLikeType:
     result_dtype = dtype or a.dtype
-    computation_dtype = utils.get_computation_dtype(a.dtype)
+    computation_dtype = utils.get_computation_dtype(result_dtype)
     a_ = _maybe_convert_to_dtype(a, computation_dtype)
-    assert isinstance(a_, TensorLike)  # to avoid MyPy error for amax
     a_max = amax(a_, dim, keepdim=True)
     a_exp = exp(a_ - a_max)
     return _maybe_convert_to_dtype(
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index 58d3c82..3cde678 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -35,6 +35,7 @@
     "hinge_embedding_loss",
     "huber_loss",
     "l1_loss",
+    "log_softmax",
     "margin_ranking_loss",
     "mish",
     "nll_loss",
@@ -44,6 +45,8 @@
     "relu",
     "relu6",
     "selu",
+    "softmax",
+    "softmin",
     "softplus",
     "softshrink",
     "tanhshrink",
@@ -238,6 +241,37 @@
     return scale * torch.where(a > 0, a, rhs)
 
 
+# Forwarding alias: the functional variant doesn't support the out kwarg
+# CompositeImplicitAutograd - don't register decomp
+def softmax(
+    a: TensorLikeType,
+    dim: Optional[int] = None,
+    _stacklevel: int = 3,  # for compat when using TorchRefsMode(strict=True)
+    dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+    # The error is for compat with regular PyTorch, which has this behavior
+    # deprecated.  For PrimTorch, it's fine to drop support for deprecated
+    # behavior because it requires explicit opt in.  This error is to inform
+    # users how to update their calls.
+    check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+    return torch.softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]
+
+
+# CompositeImplicitAutograd - don't register decomp
+def softmin(
+    a: TensorLikeType,
+    dim: Optional[int] = None,
+    _stacklevel: int = 3,  # for compat when using TorchRefsMode(strict=True)
+    dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+    # The error is for compat with regular PyTorch, which has this behavior
+    # deprecated.  For PrimTorch, it's fine to drop support for deprecated
+    # behavior because it requires explicit opt in.  This error is to inform
+    # users how to update their calls.
+    check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+    return torch.softmax(a=-a, dim=dim, dtype=dtype)  # type: ignore[call-overload]
+
+
 # softplus is implemented specially because it has beta and threshold arguments
 @register_decomposition(torch.ops.aten.softplus)
 @out_wrapper()
@@ -374,6 +408,22 @@
     return _apply_loss_reduction(loss, reduction)
 
 
+# Forwarding alias: the functional variant doesn't support the out kwarg
+# CompositeImplicitAutograd - don't register decomp
+def log_softmax(
+    a: TensorLikeType,
+    dim: Optional[int] = None,
+    _stacklevel: int = 3,  # for compat when using TorchRefsMode(strict=True)
+    dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+    # The error is for compat with regular PyTorch, which has this behavior
+    # deprecated.  For PrimTorch, it's fine to drop support for deprecated
+    # behavior because it requires explicit opt in.  This error is to inform
+    # users how to update their calls.
+    check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+    return torch.log_softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]
+
+
 @register_decomposition(torch.ops.aten.margin_ranking_loss)
 def margin_ranking_loss(
     input1: TensorLikeType,
diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py
index c4aa33f..fae9f9d 100644
--- a/torch/_refs/special/__init__.py
+++ b/torch/_refs/special/__init__.py
@@ -27,9 +27,11 @@
     "i1e",
     "log_ndtr",
     "logit",
+    "log_softmax",
     "multigammaln",
     "ndtr",
     "ndtri",
+    "softmax",
     "spherical_bessel_j0",
     "zeta",
 ]
@@ -167,6 +169,26 @@
     return prims.ndtri(a)
 
 
+# Forwarding alias: the special variant doesn't support the out kwarg
+# CompositeImplicitAutograd - don't register decomp
+def log_softmax(
+    a: TensorLikeType,
+    dim: int,
+    dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+    return torch.log_softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]
+
+
+# Forwarding alias: the special variant doesn't support the out kwarg
+# CompositeImplicitAutograd - don't register decomp
+def softmax(
+    a: TensorLikeType,
+    dim: int,
+    dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+    return torch.softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]
+
+
 @_make_elementwise_unary_reference(
     ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
     aten_op=torch.ops.aten.special_spherical_bessel_j0,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 0623a0b..550ef33 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15140,7 +15140,7 @@
         assert_autodiffed=True),
     OpInfo(
         'log_softmax',
-        variant_test_name='dtype',
+        variant_test_name='with_dtype',
         aliases=('special.log_softmax', 'nn.functional.log_softmax'),
         supports_out=True,
         dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
@@ -16697,6 +16697,7 @@
     PythonRefInfo(
         "_refs.log_softmax",
         torch_opinfo_name="log_softmax",
+        torch_opinfo_variant_name="with_dtype",
     ),
     ElementwiseUnaryPythonRefInfo(
         "_refs.nan_to_num",
@@ -16771,6 +16772,7 @@
     PythonRefInfo(
         "_refs.softmax",
         torch_opinfo_name="softmax",
+        torch_opinfo_variant_name="with_dtype",
     ),
     ElementwiseUnaryPythonRefInfo(
         "_refs.sqrt",
@@ -16799,6 +16801,18 @@
         # https://github.com/pytorch/pytorch/issues/85258
         supports_nvfuser=False,
     ),
+    PythonRefInfo(
+        "_refs.special.log_softmax",
+        torch_opinfo_name="log_softmax",  # alias
+        torch_opinfo_variant_name="with_dtype",
+        supports_out=False,
+    ),
+    PythonRefInfo(
+        "_refs.special.softmax",
+        torch_opinfo_name="softmax",  # alias
+        torch_opinfo_variant_name="with_dtype",
+        supports_out=False,
+    ),
     #
     # Elementwise Unary Special OpInfos
     #
@@ -16897,6 +16911,12 @@
         torch_opinfo_name="nn.functional.leaky_relu",
     ),
     PythonRefInfo(
+        "_refs.nn.functional.log_softmax",
+        torch_opinfo_name="log_softmax",  # alias
+        torch_opinfo_variant_name="with_dtype",
+        supports_out=False,
+    ),
+    PythonRefInfo(
         "_refs.nn.functional.poisson_nll_loss",
         torch_opinfo_name="nn.functional.poisson_nll_loss",
     ),
@@ -16921,6 +16941,18 @@
         "_refs.nn.functional.selu",
         torch_opinfo_name="nn.functional.selu",
     ),
+    PythonRefInfo(
+        "_refs.nn.functional.softmax",
+        torch_opinfo_name="softmax",  # alias
+        torch_opinfo_variant_name="with_dtype",
+        supports_out=False,
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.softmin",
+        torch_opinfo_name="nn.functional.softmin",
+        torch_opinfo_variant_name="with_dtype",
+        supports_out=False,
+    ),
     ElementwiseUnaryPythonRefInfo(
         "_refs.nn.functional.softplus",
         torch_opinfo_name="nn.functional.softplus",