Revert "Fix jagged NT softmax semantics (#119459)"

This reverts commit 6adadbaf7943f760ea2375619b1783020b69d4e6.

Reverted https://github.com/pytorch/pytorch/pull/119459 on behalf of https://github.com/malfet due to broke dynamo, see https://github.com/pytorch/pytorch/actions/runs/7835402753/job/21386634602 ([comment](https://github.com/pytorch/pytorch/pull/119459#issuecomment-1935246413))
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 9ffb0c4..355ae71 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -3182,22 +3182,6 @@
         ):
             torch.split(nt, [1, 2], 1)
 
-    def test_softmax(self, device):
-        nt = random_nt_from_dims(
-            [3, None, 5], device=device, dtype=torch.float32, layout=torch.jagged)
-
-        # operate on dim=2
-        output = nt.softmax(dim=2)
-        for in_component, out_component in zip(nt.unbind(), output.unbind()):
-            # dim=2 -> dim=1 after unbind
-            self.assertEqual(in_component.softmax(dim=1), out_component)
-
-        # operate on dim=-1
-        output2 = nt.softmax(dim=-1)
-        self.assertEqual(output, output2)
-        for in_component, out_component in zip(nt.unbind(), output2.unbind()):
-            self.assertEqual(in_component.softmax(dim=-1), out_component)
-
     def test_views_inherit_ragged_dim(self, device):
         # view
         nt = random_nt_from_dims(
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 83b9e2e..5a5acca 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -447,19 +447,9 @@
 )(jagged_unary_pointwise)
 
 
-@register_jagged_func(
+register_jagged_func(
     torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
-)
-def _softmax_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
-        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
-    )
-
-    inp = new_kwargs.pop("input")
-    dim = new_kwargs["dim"]
-    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "softmax")
-
-    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+)(jagged_unary_pointwise)
 
 
 @register_jagged_func(