[FSDP] Skip `_use_sharded_views()` for `SHARD_GRAD_OP` (#98250)
This PR makes `SHARD_GRAD_OP` (and `_HYBRID_SHARD_ZERO2`) skip `_use_sharded_views()` in the post-forward reshard, since these strategies do not free the unsharded flat parameter and can therefore preserve the unsharded views. This saves nontrivial CPU overhead in both the post-forward reshard (`_use_sharded_views()`) and the pre-backward unshard (`_use_unsharded_views()`).
<details>
<summary>(Before) Pre-backward hook: 4.356 ms</summary>
<img width="812" alt="Screenshot 2023-04-03 at 6 32 19 PM" src="https://user-images.githubusercontent.com/31054793/229641309-778cf1f9-4b5b-42ec-b2d8-0a1e6e7ce330.png">
</details>
<details>
<summary>(After) Pre-backward hook: 1.044 ms</summary>

</details>
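
A minimal sketch of the user-visible behavior (not part of this PR's diff; it assumes `torch.distributed` is already initialized, e.g. via `torchrun`, and uses a toy `nn.Linear` as the wrapped module): with `use_orig_params=True` and `SHARD_GRAD_OP`, the original parameters stay exposed as unsharded views between forward and backward, their gradients are not exposed in that window, and mutating a parameter there raises an assertion error at backward time.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Assumes the process group is already initialized and the current rank has
# its CUDA device set.
model = FSDP(
    nn.Linear(8, 8, device="cuda"),
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    use_orig_params=True,
)
inp = torch.randn(4, 8, device="cuda")
loss = model(inp).sum()

# Between forward and backward, `SHARD_GRAD_OP` now keeps the unsharded views,
# so the original parameters show their full (unsharded) shapes, and `.grad`
# is not exposed here (use `summon_full_params(with_grads=True)` to inspect
# gradients instead).
for name, param in model.named_parameters():
    print(name, tuple(param.shape), param.grad is None)

# Changing a parameter in this window is unsupported for these strategies and
# raises an AssertionError during backward.
loss.backward()
```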
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98250
Approved by: https://github.com/rohan-varma
diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py
index 10bf09b..3100135 100644
--- a/test/distributed/fsdp/test_fsdp_use_orig_params.py
+++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py
@@ -18,6 +18,7 @@
ShardingStrategy,
)
from torch.distributed.fsdp._common_utils import clean_tensor_name
+from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
@@ -500,6 +501,10 @@
# Check that FSDP correctly exposes gradients even after forward
# (namely, `None` for weights and non-`None` for biases)
+ if sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES:
+ # Skip the check since we do not expose the gradients after forward
+ # for these strategies
+ return
for (ddp_n, ddp_p), (fsdp_n, fsdp_p) in zip(
ddp_model.module.named_parameters(),
fsdp_model.named_parameters(),
@@ -759,7 +764,9 @@
def get_loss(self, inp, out):
return out.sum()
- def check_parameter_parity(ddp_model, fsdp_model):
+ def check_parameter_parity(
+ ddp_model: DDP, fsdp_model: FSDP, between_fwd_and_bwd: bool
+ ):
assert self.rank in (
0,
1,
@@ -773,6 +780,13 @@
# For `NO_SHARD`, do nothing since the original parameters
# are unflattened
pass
+ elif (
+ between_fwd_and_bwd
+ and sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES
+ ):
+ # For no reshard after forward strategies, do nothing since
+ # FSDP did not use sharded views after forward
+ pass
# Otherwise, case on the parameter (see the model definition)
elif n1 == "lin1.weight":
if self.rank == 0:
@@ -806,7 +820,7 @@
inp = fsdp_model.get_input(device)
ddp_out = ddp_model(*inp)
fsdp_out = fsdp_model(*inp)
- check_parameter_parity(ddp_model, fsdp_model)
+ check_parameter_parity(ddp_model, fsdp_model, True)
ddp_loss = ddp_model.module.get_loss(inp, ddp_out)
fsdp_loss = fsdp_model.get_loss(inp, fsdp_out)
@@ -814,23 +828,23 @@
fsdp_loss.backward()
ddp_optim.step()
fsdp_optim.step()
- check_parameter_parity(ddp_model, fsdp_model)
+ check_parameter_parity(ddp_model, fsdp_model, False)
inp = fsdp_model.get_input(device)
ddp_out = ddp_model(*inp)
fsdp_out = fsdp_model(*inp)
- check_parameter_parity(ddp_model, fsdp_model)
+ check_parameter_parity(ddp_model, fsdp_model, True)
class TestFSDPUseOrigParamsWriteback(FSDPTest):
"""Tests parameter and gradient writeback."""
class Model(nn.Module):
- def __init__(self):
+ def __init__(self, device: torch.device):
super().__init__()
torch.manual_seed(42)
- self.lin1 = nn.Linear(5, 5, bias=True)
- self.lin2 = nn.Linear(5, 7, bias=True)
+ self.lin1 = nn.Linear(5, 5, bias=True, device=device)
+ self.lin2 = nn.Linear(5, 7, bias=True, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
z = self.lin1(x)
@@ -876,10 +890,12 @@
# Check that the writeback propagates
ddp_model = DDP(
- TestFSDPUseOrigParamsWriteback.Model().cuda(), device_ids=[self.rank]
+ TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
+ device_ids=[self.rank],
)
fsdp_model = FSDP(
- TestFSDPUseOrigParamsWriteback.Model().cuda(), use_orig_params=True
+ TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
+ use_orig_params=True,
)
ddp = ddp_model.module # for brevity
fsdp = fsdp_model.module
@@ -927,10 +943,12 @@
return None if set_to_none else torch.ones_like(param) * 2
ddp_model = DDP(
- TestFSDPUseOrigParamsWriteback.Model().cuda(), device_ids=[self.rank]
+ TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
+ device_ids=[self.rank],
)
fsdp_model = FSDP(
- TestFSDPUseOrigParamsWriteback.Model().cuda(), use_orig_params=True
+ TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
+ use_orig_params=True,
)
LR = 1e-2
# TODO: If we add `summon_full_params(with_grads=True)`, then replace
@@ -982,7 +1000,8 @@
@skip_if_lt_x_gpu(2)
def test_writeback_shape_mismatch(self):
fsdp_model = FSDP(
- TestFSDPUseOrigParamsWriteback.Model().cuda(), use_orig_params=True
+ TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
+ use_orig_params=True,
)
# Check that writing back with mismatched shape errors
fsdp = fsdp_model.module # for brevity
@@ -1019,6 +1038,41 @@
with FSDP.summon_full_params(fsdp_model): # triggers a writeback
...
+ @skip_if_lt_x_gpu(2)
+ def test_writeback_between_fwd_and_bwd_for_no_reshard_raises(self):
+ fsdp_kwargs = {
+ "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
+ "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
+ "use_orig_params": True,
+ }
+ fsdp_wrapper = functools.partial(FSDP, **fsdp_kwargs)
+
+ # Test changing the parameter storage to no longer be a view into the
+ # flat parameter
+ fsdp_model = fsdp_wrapper(
+ TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
+ )
+ inp = fsdp_model.get_input(torch.device("cuda"))
+ loss = fsdp_model(*inp).sum()
+ fsdp_model.lin1.weight.data = fsdp_model.lin1.weight.clone()
+ assert_msg = (
+ "FSDP does not support changing the parameters between forward and backward"
+ )
+ with self.assertRaisesRegex(AssertionError, assert_msg):
+ loss.backward()
+
+ # Test changing the parameter variable itself
+ fsdp_model = fsdp_wrapper(
+ TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
+ )
+ inp = fsdp_model.get_input(torch.device("cuda"))
+ loss = fsdp_model(*inp).sum()
+ fsdp_model.lin1._fsdp_wrapped_module.weight = nn.Parameter(
+ fsdp_model.lin1.weight.clone()
+ )
+ with self.assertRaisesRegex(AssertionError, assert_msg):
+ loss.backward()
+
class TestFSDPUseOrigParamsFQNs(FSDPTest):
@skip_if_lt_x_gpu(2)
diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py
index 8f97bcf..3c97b30 100644
--- a/torch/distributed/fsdp/_init_utils.py
+++ b/torch/distributed/fsdp/_init_utils.py
@@ -75,11 +75,14 @@
ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
}
-
HYBRID_SHARDING_STRATEGIES = {
ShardingStrategy.HYBRID_SHARD,
ShardingStrategy._HYBRID_SHARD_ZERO2,
}
+NO_RESHARD_AFTER_FORWARD_STRATEGIES = (
+ ShardingStrategy.SHARD_GRAD_OP,
+ ShardingStrategy._HYBRID_SHARD_ZERO2,
+)
# NOTE: Since non-self attributes cannot be type annotated, several attributes
diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py
index 511d172..c89bcd4 100644
--- a/torch/distributed/fsdp/_runtime_utils.py
+++ b/torch/distributed/fsdp/_runtime_utils.py
@@ -35,13 +35,10 @@
FlatParamHandle,
HandleShardingStrategy,
HandleTrainingState,
+ RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES,
)
from torch.distributed.utils import _apply_to_tensors, _p_assert, _to_kwargs
-RESHARD_AFTER_FORWARD_STRATEGIES = {
-    HandleShardingStrategy.FULL_SHARD,
-    HandleShardingStrategy.HYBRID_SHARD,
-}
# Do not include "process_group" to enable hybrid shard and MoE cases
HOMOGENEOUS_ATTR_NAMES = (
@@ -501,7 +498,7 @@
# computation (though this may not be true)
free_unsharded_flat_params = [
not state._is_root
- and handle._sharding_strategy in RESHARD_AFTER_FORWARD_STRATEGIES
+ and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
for handle in handles
]
_reshard(state, handles, free_unsharded_flat_params)
@@ -830,7 +827,7 @@
# higher throughput.
return (
state._sync_gradients
- or handle._sharding_strategy in RESHARD_AFTER_FORWARD_STRATEGIES
+ or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
)
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 97f81f6..3cc1a6c 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -81,6 +81,25 @@
_FLAT_PARAM_PADDING_VALUE = 42
+# TODO: Define this for now to avoid circular imports. See if we can remove.
+class HandleShardingStrategy(Enum):
+ FULL_SHARD = auto()
+ SHARD_GRAD_OP = auto()
+ NO_SHARD = auto()
+ HYBRID_SHARD = auto()
+ _HYBRID_SHARD_ZERO2 = auto()
+
+
+RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
+ HandleShardingStrategy.FULL_SHARD,
+ HandleShardingStrategy.HYBRID_SHARD,
+)
+NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
+ HandleShardingStrategy.SHARD_GRAD_OP,
+ HandleShardingStrategy._HYBRID_SHARD_ZERO2,
+)
+
+
class ParamInfo(NamedTuple):
"""Information for an original parameter."""
@@ -144,17 +163,6 @@
param_offsets: Tuple[Tuple[int, int], ...]
-# TODO (awgu): Prefix these with "Handle" for now to avoid circular imports and
-# inadvertent misuses; coalesce with those in fully_sharded_data_parallel.py
-# later
-class HandleShardingStrategy(Enum):
-    FULL_SHARD = auto()
-    SHARD_GRAD_OP = auto()
-    NO_SHARD = auto()
-    HYBRID_SHARD = auto()
-    _HYBRID_SHARD_ZERO2 = auto()
-
-
class FlatParameter(nn.Parameter):
"""
This is the flat parameter used by :class:`FullyShardedDataParallel`. It is
@@ -417,6 +425,11 @@
self._training_state = HandleTrainingState.IDLE
self._debug_level = dist.get_debug_level()
self._fully_sharded_module = fully_sharded_module
+ # NOTE: For the code path using this flag, we only skip calling
+ # `_use_sharded_views()` and do not skip switching to the sharded flat
+ # parameter since whether `self.flat_param` uses the sharded or
+ # unsharded flat parameter parameterizes behavior.
+ self._skipped_use_sharded_views = False
# Optimistically assume a valid input `params` and set dtype attributes
# before `_init_flat_param()`, which performs the actual validation
self._orig_param_dtype = params[0].dtype
@@ -1199,6 +1212,11 @@
in_forward = self._training_state == HandleTrainingState.FORWARD
in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
if self._use_orig_params:
+ if self._skipped_use_sharded_views and in_pre_backward:
+ # This call corresponds to the complementary pre-backward
+ # `_use_unsharded_views()` to the skipped pre-forward
+ # `_use_sharded_views()`, so we should skip this one too.
+ return
# We use `Tensor` views in the forward so that they are tracked by
# autograd. We use them in the pre-backward as well to support
# reentrant activation checkpointing, which needs the views to be
@@ -1511,12 +1529,26 @@
)
flat_param.data = flat_param._local_shard # type: ignore[attr-defined]
if self._use_orig_params:
- self._use_sharded_views()
+ in_forward = self._training_state == HandleTrainingState.FORWARD
+ if (
+ in_forward
+ and self._sharding_strategy
+ in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
+ ):
+ self._skipped_use_sharded_views = True
+ else:
+ self._use_sharded_views()
# For the post-forward reshard, we may try to use sharded gradient
# views (or unsharded gradient views if a gradient was accumulated
# in `no_sync()`), but for the post-backward reshard, we delay the
# call to after the reduce-scatter.
- if self._training_state == HandleTrainingState.FORWARD:
+ if (
+ in_forward
+ # Skip using gradient views if skipped using sharded views
+ # since exposing unsharded parameters with sharded gradients
+ # may be confusing to the user
+ and not self._skipped_use_sharded_views
+ ):
# TODO: Change `_unpadded_unsharded_size` if we change the
# gradient to be computed directly with padding.
accumulated_grad_in_no_sync = (
@@ -1755,6 +1787,7 @@
printability. Parameters whose data is present must preserve their
variables to be passable to an optimizer.
"""
+ self._skipped_use_sharded_views = False
if not self.uses_sharded_strategy:
# For `NO_SHARD`, use the *unflattened* unsharded views since we
# have the unsharded parameter
@@ -1869,12 +1902,21 @@
but no longer has the expected flattened shape.
Returns: ``True`` if some writeback happened, and ``False`` otherwise.
"""
- if self.uses_sharded_strategy and not self.is_sharded(self.flat_param):
+ if (
+ self.uses_sharded_strategy
+ and not self.is_sharded(self.flat_param)
+ and not self._skipped_use_sharded_views
+ ):
# For `NO_SHARD`, we may still need to writeback
return False
flat_param = self.flat_param
wroteback = False
- flat_param_data_ptr = flat_param.untyped_storage().data_ptr()
+ if self._skipped_use_sharded_views and self.uses_sharded_strategy:
+ flat_param_data_ptr = (
+ self._get_padded_unsharded_flat_param().untyped_storage().data_ptr()
+ )
+ else:
+ flat_param_data_ptr = flat_param.untyped_storage().data_ptr()
# NOTE: Since this method is called in the pre-unshard, which is only
# called during computation in the pre-forward or pre-backward, the
# sharded gradient should be guaranteed to be in `.grad`, not in
@@ -1908,6 +1950,12 @@
continue
# Check for parameter writeback
+ if self._skipped_use_sharded_views:
+ param = flat_param._tensors[i]
+ _p_assert(
+ param is not None,
+ f"Expects to have saved tensor for {flat_param._fqns[i]}",
+ )
param_changed = getattr(module, param_name) is not param
needs_param_writeback = (
param_changed # changed parameter variable itself
@@ -1915,6 +1963,13 @@
param, flat_param_data_ptr
) # changed `.data`
)
+ if self._skipped_use_sharded_views and (
+ param_changed or needs_param_writeback
+ ):
+ raise AssertionError(
+ "FSDP does not support changing the parameters between "
+ f"forward and backward for {self._sharding_strategy}"
+ )
if param_changed:
# NOTE: The gradient is not preserved after a parameter change.
param = getattr(module, param_name)
@@ -1927,6 +1982,10 @@
wroteback = True
# Check for gradient writeback
+ if self._skipped_use_sharded_views:
+ # Skip the writeback check because we do not expose gradients
+ # when we skipped using sharded views
+ continue
if param.grad is None and flat_param.grad is not None:
expected_shape = torch.Size([numel_in_shard])
self._writeback_tensor(
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 2e90c84..c110c16 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -199,6 +199,16 @@
accessing the original parameters between forward and backward may
raise an illegal memory access.
+ .. warning::
+ For ``use_orig_params=True``, ``ShardingStrategy.SHARD_GRAD_OP``
+ exposes the unsharded parameters, not the sharded parameters, after
+ forward since it does not free the unsharded ones, unlike
+ ``ShardingStrategy.FULL_SHARD``. One caveat is that, since gradients
+ are always sharded or ``None``, ``ShardingStrategy.SHARD_GRAD_OP`` will
+ not expose the sharded gradients with the unsharded parameters after
+ forward. If you want to inspect the gradients, try
+ :meth:`summon_full_params` with ``with_grads=True``.
+
Args:
module (nn.Module):
This is the module to be wrapped with FSDP.
diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py
index 75e648f..ae85b49 100644
--- a/torch/testing/_internal/common_fsdp.py
+++ b/torch/testing/_internal/common_fsdp.py
@@ -14,6 +14,7 @@
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import TrainingState
+from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
MixedPrecision,
@@ -843,7 +844,14 @@
input = tuple(x.half() for x in input)
output = model(*input)
# Post-forward, if CPU offloading model param should be on CPU.
- if cpu_offload_params and isinstance(model, FSDP):
+ if (
+ cpu_offload_params
+ and isinstance(model, FSDP)
+ # If not resharding after forward, the parameters are still
+ # exposed as unsharded views into the GPU flat parameter
+ and model.sharding_strategy
+ not in NO_RESHARD_AFTER_FORWARD_STRATEGIES
+ ):
for p in model.parameters():
# Params should always be on CPU
self.assertEqual(p.device, torch.device("cpu"))