[SPMD][EASY] Remove unnecessary torch.ops prefix (#99331)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99331
Approved by: https://github.com/dracifer
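
This is a pure spelling cleanup: the touched modules can refer to operators through a module-level `aten` alias (assumed here to be `aten = torch.ops.aten`, the usual pattern in these files) instead of the full `torch.ops.aten` path. A minimal sketch, under that assumption, showing that both spellings resolve to the same `OpOverload` objects:

```python
# Minimal sketch of the alias pattern this PR relies on (the exact alias line in
# distribute.py / graph_optimization.py is assumed here, not copied from them).
import torch

aten = torch.ops.aten  # torch.ops.aten is a namespace object; aliasing it is free

# The shorter spelling resolves to the very same cached OpOverload objects, so
# target comparisons and call_function nodes behave identically.
assert aten.view.default is torch.ops.aten.view.default
assert aten.reshape.default is torch.ops.aten.reshape.default
assert aten._foreach_add.Scalar is torch.ops.aten._foreach_add.Scalar
```
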
diff --git a/torch/distributed/_spmd/distribute.py b/torch/distributed/_spmd/distribute.py
index 2f26a4c..408d6cd 100644
--- a/torch/distributed/_spmd/distribute.py
+++ b/torch/distributed/_spmd/distribute.py
@@ -188,9 +188,9 @@
def is_sym_int_or_int(arg: Union[int, torch.fx.Node]) -> bool:
if isinstance(arg, torch.fx.Node):
return arg.target in [
- torch.ops.aten.sym_size,
- torch.ops.aten.sym_numel,
- torch.ops.aten.sym_stride,
+ aten.sym_size,
+ aten.sym_numel,
+ aten.sym_stride,
]
return isinstance(arg, int)
@@ -370,12 +370,12 @@
)
return None
- if node.target == torch.ops.aten.view.default:
+ if node.target == aten.view.default:
            # HACK: this works around the fact that some view operations on a
            # "global" tensor are invalid usage, but the view operation on the
            # batch input might still hit this path, so we convert the view op
            # to reshape before calling DTensor
- op_overload = torch.ops.aten.reshape.default
+ op_overload = aten.reshape.default
        # DSymInt args are not sharded on any dimension, so the local value
        # and the global value should be the same
diff --git a/torch/distributed/_spmd/graph_optimization.py b/torch/distributed/_spmd/graph_optimization.py
index f5ef2d1..caa9ac9 100644
--- a/torch/distributed/_spmd/graph_optimization.py
+++ b/torch/distributed/_spmd/graph_optimization.py
@@ -824,7 +824,7 @@
)
orig_step_outputs.append(orig_optim_block.step.outputs[idx])
step = gm.graph.call_function(
- torch.ops.aten._foreach_add.Scalar,
+ aten._foreach_add.Scalar,
(step_args, 1),
)
step_block = ForeachAddBlock(step, generate_output=True)
@@ -844,7 +844,7 @@
# topo sort order is the last.
with gm.graph.inserting_after(step_block.outputs[0]):
optim = gm.graph.call_function(
- torch.ops.aten._fused_adam.default,
+ aten._fused_adam.default,
optim_args[group_idx],
orig_optim_block.optim.optim_node.kwargs,
)
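
The same reasoning covers the FX graph construction in graph_optimization.py: `Graph.call_function` stores the callable it is given as `node.target`, so the shorter alias produces identical nodes. A small standalone sketch (the toy graph and the `aten.add.Tensor` op below are illustrative, not taken from the PR):

```python
# Illustrative only: build a tiny FX graph the same way graph_optimization.py
# does (graph.call_function with an aten OpOverload) and inspect the target.
import torch
import torch.fx as fx

aten = torch.ops.aten

graph = fx.Graph()
x = graph.placeholder("x")
# Same call_function pattern as the PR's optimizer-fusion code, on a toy op.
out = graph.call_function(aten.add.Tensor, (x, x))
graph.output(out)

gm = fx.GraphModule(torch.nn.Module(), graph)
# node.target is the OpOverload object itself, identical under either spelling.
assert out.target is torch.ops.aten.add.Tensor
print(gm(torch.ones(2)))  # tensor([2., 2.])
```
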