try to fix the warning in distribute_tensor (#125476)
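
Reorder the RNG-tracker checks so that is_rng_supported_mesh(device_mesh) is only consulted when no tracker has been constructed yet, detach the local tensor once at the start of distribute_tensor instead of at DTensor construction time, and fall back to contextlib.nullcontext() in the random-op dispatch path when no RNG tracker is available.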
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125476
Approved by: https://github.com/albanD, https://github.com/awgu
ghstack dependencies: #125475
diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py
index 16cd41a..42c01b6 100644
--- a/torch/distributed/_tensor/api.py
+++ b/torch/distributed/_tensor/api.py
@@ -573,7 +573,7 @@
# OffsetBasedRNGTracker to perform random operators.
# TODO: the value assignment to global variable is not the ideal solution
# we can replace it in future.
- if is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
+ if not random._rng_tracker and is_rng_supported_mesh(device_mesh):
random._rng_tracker = OffsetBasedRNGTracker(device_type)
if not tensor.is_leaf:
@@ -612,7 +612,7 @@
)
return tensor
- local_tensor = tensor
+ local_tensor = tensor.detach()
# distribute the tensor according to the placements.
placements = list(placements)
@@ -637,7 +637,7 @@
# detach the local tensor passed to DTensor since after the construction
# of DTensor, autograd would work on top of DTensor instead of local tensor
return DTensor(
- local_tensor.detach().requires_grad_(tensor.requires_grad),
+ local_tensor.requires_grad_(tensor.requires_grad),
device_mesh,
placements,
shape=tensor.size(),
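
The api.py hunks above detach the tensor once, at the top of the distribution path, so the later placement steps already operate on a detached tensor and only the final local tensor gets requires_grad_ re-applied before DTensor construction. A minimal sketch of that ordering, with a hypothetical shard_locally helper standing in for the real Shard/Replicate placement logic:

    import torch

    def shard_locally(t: torch.Tensor) -> torch.Tensor:
        # hypothetical stand-in for the Shard/Replicate placement handling
        return t.contiguous()

    tensor = torch.randn(4, 4, requires_grad=True)

    # detach once, up front; every later placement step sees a detached tensor
    local_tensor = tensor.detach()
    local_tensor = shard_locally(local_tensor)

    # re-attach autograd only on the final local tensor handed to DTensor
    local_tensor = local_tensor.requires_grad_(tensor.requires_grad)
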
diff --git a/torch/distributed/_tensor/dispatch.py b/torch/distributed/_tensor/dispatch.py
index 72c4802..0ce1c6b 100644
--- a/torch/distributed/_tensor/dispatch.py
+++ b/torch/distributed/_tensor/dispatch.py
@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
+import contextlib
import functools
import operator
from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
@@ -181,15 +182,15 @@
# run local op computation with potentially modified args/kwargs
local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
- if op_call in self._random_ops and is_rng_supported_mesh(mesh):
- if not random._rng_tracker:
+ if op_call in self._random_ops:
+ if not random._rng_tracker and is_rng_supported_mesh(mesh):
# Default to `OffsetBasedRNGTracker` if the parallelism API
# did not already construct one
random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)
# For DTensor random operator, run it within a distribute region
with random._rng_tracker._distribute_region(
cast(dtensor.DTensor, args[0])._spec
- ):
+ ) if random._rng_tracker else contextlib.nullcontext():
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
else:
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
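
The dispatch.py hunk makes the with-block tolerate a missing tracker: after the reordering, a random op dispatched on a mesh without RNG support can reach this point with random._rng_tracker still unset, so the region falls back to contextlib.nullcontext(). A rough sketch of the conditional-context pattern, with a hypothetical DummyTracker standing in for OffsetBasedRNGTracker:

    import contextlib
    import random as pyrandom

    class DummyTracker:
        # illustration-only stand-in for OffsetBasedRNGTracker
        @contextlib.contextmanager
        def _distribute_region(self, spec):
            print(f"entering distribute region for {spec}")
            yield
            print("leaving distribute region")

    def run_random_op(tracker, spec, op):
        # use the tracker's region when one exists, else a no-op context
        with tracker._distribute_region(spec) if tracker else contextlib.nullcontext():
            return op()

    print(run_random_op(None, "spec", pyrandom.random))           # no tracker: null context
    print(run_random_op(DummyTracker(), "spec", pyrandom.random)) # tracker: wrapped call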