Revert "Fix jagged NT softmax semantics (#119459)"
This reverts commit 6adadbaf7943f760ea2375619b1783020b69d4e6.
Reverted https://github.com/pytorch/pytorch/pull/119459 on behalf of https://github.com/malfet because it broke dynamo; see https://github.com/pytorch/pytorch/actions/runs/7835402753/job/21386634602 ([comment](https://github.com/pytorch/pytorch/pull/119459#issuecomment-1935246413))
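For context on what is being reverted: the PR made `softmax` on a jagged NestedTensor interpret `dim` in terms of the NestedTensor's logical shape, so that softmax over the last dimension matches per-component softmax after `unbind()`. A minimal sketch of that behavior, using the public `torch.nested` API rather than the test helper `random_nt_from_dims`; whether the checks pass depends on which side of this revert you are on:

```python
import torch

# Jagged NestedTensor with logical shape [3, None, 5]: three components whose
# middle (ragged) dimension varies, mirroring the removed test_softmax test.
components = [torch.randn(n, 5) for n in (2, 4, 3)]
nt = torch.nested.nested_tensor(components, layout=torch.jagged)

# Under the reverted semantics, softmax over dim=2 of the NestedTensor is
# equivalent to softmax over dim=1 of each component after unbind()
# (dim=2 on the NT becomes dim=1 once the batch dimension is peeled off).
out = nt.softmax(dim=2)
for in_component, out_component in zip(nt.unbind(), out.unbind()):
    torch.testing.assert_close(in_component.softmax(dim=1), out_component)

# dim=-1 resolves to the same (last) dimension.
out_neg = nt.softmax(dim=-1)
for in_component, out_component in zip(nt.unbind(), out_neg.unbind()):
    torch.testing.assert_close(in_component.softmax(dim=-1), out_component)
```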
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 9ffb0c4..355ae71 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -3182,22 +3182,6 @@
):
torch.split(nt, [1, 2], 1)
- def test_softmax(self, device):
- nt = random_nt_from_dims(
- [3, None, 5], device=device, dtype=torch.float32, layout=torch.jagged)
-
- # operate on dim=2
- output = nt.softmax(dim=2)
- for in_component, out_component in zip(nt.unbind(), output.unbind()):
- # dim=2 -> dim=1 after unbind
- self.assertEqual(in_component.softmax(dim=1), out_component)
-
- # operate on dim=-1
- output2 = nt.softmax(dim=-1)
- self.assertEqual(output, output2)
- for in_component, out_component in zip(nt.unbind(), output2.unbind()):
- self.assertEqual(in_component.softmax(dim=-1), out_component)
-
def test_views_inherit_ragged_dim(self, device):
# view
nt = random_nt_from_dims(
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 83b9e2e..5a5acca 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -447,19 +447,9 @@
)(jagged_unary_pointwise)
-@register_jagged_func(
+register_jagged_func(
torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
-)
-def _softmax_default(func, *args, **kwargs):
- _, new_kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
-
- inp = new_kwargs.pop("input")
- dim = new_kwargs["dim"]
- new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "softmax")
-
- return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+)(jagged_unary_pointwise)
@register_jagged_func(
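The ops.py hunk restores the plain `jagged_unary_pointwise` registration for `aten._softmax` and drops the dedicated handler that remapped `dim` from NestedTensor space into the packed values space via `_wrap_jagged_dim`. A rough sketch of why that remapping yields per-component semantics, assuming the usual jagged packing in which a `[B, None, 5]` NestedTensor stores its data as a `[sum(lengths), 5]` values buffer (rebuilt here with `torch.cat`/`split` purely for illustration):

```python
import torch

components = [torch.randn(n, 5) for n in (2, 4, 3)]
nt = torch.nested.nested_tensor(components, layout=torch.jagged)

# Packed values buffer: the batch and ragged dims collapse into one, so the
# NestedTensor's dim=2 corresponds to dim=1 of the packed [9, 5] buffer.
values = torch.cat(list(nt.unbind()), dim=0)   # shape [2 + 4 + 3, 5]
packed_softmax = values.softmax(dim=1)         # softmax along each row

# Splitting back by component length reproduces per-component softmax over the
# last dimension, which is what the removed handler computed by calling
# softmax on inp._values with the wrapped dim.
for component, chunk in zip(components, packed_softmax.split([2, 4, 3], dim=0)):
    torch.testing.assert_close(component.softmax(dim=-1), chunk)
```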