[FSDP] Remove unneeded disable of tf32 (#100179)
I recall needing to disable tf32, but I cannot repro the issue anymore. Nowhere else in our unit tests do we disable tf32, so we can try to get rid of this disabling.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100179
Approved by: https://github.com/rohan-varma
diff --git a/test/distributed/fsdp/test_fsdp_grad_acc.py b/test/distributed/fsdp/test_fsdp_grad_acc.py
index 42d1b7d..f5eb0df 100644
--- a/test/distributed/fsdp/test_fsdp_grad_acc.py
+++ b/test/distributed/fsdp/test_fsdp_grad_acc.py
@@ -123,109 +123,100 @@
not config.use_no_sync for config in configs
):
return
- old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
- try:
- # Disable TF32 to prevent floating point drift
- torch.backends.cuda.matmul.allow_tf32 = False
+ # Initialize the FSDP model and optimizer
+ fsdp_kwargs = {
+ "cpu_offload": cpu_offload,
+ "backward_prefetch": backward_prefetch,
+ "sharding_strategy": sharding_strategy,
+ "use_orig_params": use_orig_params,
+ }
+ fsdp_model: FSDP = TransformerWithSharedParams.init(
+ self.process_group,
+ FSDPInitMode.RECURSIVE,
+ CUDAInitMode.CUDA_BEFORE,
+ fsdp_kwargs,
+ deterministic=True,
+ add_bn=False, # disable BN since the test uses varying batch sizes
+ )
+ device = torch.device("cuda")
+ optim = torch.optim.SGD(
+ fsdp_model.parameters(),
+ lr=0.01,
+ momentum=0.9,
+ )
- # Initialize the FSDP model and optimizer
- fsdp_kwargs = {
- "cpu_offload": cpu_offload,
- "backward_prefetch": backward_prefetch,
- "sharding_strategy": sharding_strategy,
- "use_orig_params": use_orig_params,
- }
- fsdp_model: FSDP = TransformerWithSharedParams.init(
- self.process_group,
- FSDPInitMode.RECURSIVE,
- CUDAInitMode.CUDA_BEFORE,
- fsdp_kwargs,
- deterministic=True,
- add_bn=False, # disable BN since the test uses varying batch sizes
+ # Generate the sequence of batches, each containing the same data
+ # but permuted
+ def permute_tensor(x: torch.Tensor):
+ return x.view(-1)[torch.randperm(x.numel())].view_as(x)
+
+ batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
+ batches: List[Tuple[torch.Tensor, ...]] = [batch]
+ num_iters_to_acc = sum(config.num_iters for config in configs)
+ for _ in range(num_iters_to_acc - 1):
+ batches.append(tuple(permute_tensor(t) for t in batch))
+ for batch1, batch2 in itertools.combinations(batches, r=2):
+ for t1, t2 in zip(batch1, batch2):
+ assert not torch.all(
+ t1 == t2
+ ), "Check the test to make sure that batches are distinct"
+
+ # Concatenate the batches along the given batch dimension
+ concat_batch: Tuple[torch.Tensor, ...] = tuple(
+ torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
+ )
+
+ # Establish reference gradients using the concatenated batch
+ fsdp_model.zero_grad()
+ output = fsdp_model(*concat_batch)
+ ref_loss = fsdp_model.module.get_loss(concat_batch, output)
+ ref_loss.backward()
+ ref_grads = [
+ p.grad.detach().clone()
+ for p in fsdp_model.parameters()
+ if p.grad is not None
+ ]
+
+ # Compute and accumulate the gradients
+ fsdp_model.zero_grad()
+ losses = []
+ batch_idx = 0
+ for config in configs:
+ sync_context = (
+ fsdp_model.no_sync() if config.use_no_sync else contextlib.suppress()
)
- device = torch.device("cuda")
- optim = torch.optim.SGD(
- fsdp_model.parameters(),
- lr=0.01,
- momentum=0.9,
- )
+ with sync_context:
+ for _ in range(config.num_iters):
+ if batch_idx == num_iters_to_acc - 1:
+ break # always sync on the last iteration
+ batch = batches[batch_idx]
+ batch_idx += 1
+ output = fsdp_model(*batch)
+ loss = fsdp_model.module.get_loss(batch, output)
+ loss.backward()
+ losses.append(loss)
+ output = fsdp_model(*batches[-1])
+ loss = fsdp_model.module.get_loss(batches[-1], output)
+ loss.backward()
+ losses.append(loss)
+ acc_loss = sum(losses)
+ acc_grads = [
+ p.grad.detach().clone()
+ for p in fsdp_model.parameters()
+ if p.grad is not None
+ ]
- # Generate the sequence of batches, each containing the same data
- # but permuted
- def permute_tensor(x: torch.Tensor):
- return x.view(-1)[torch.randperm(x.numel())].view_as(x)
+ # Compare the losses and gradients
+ torch.testing.assert_close(ref_loss, acc_loss)
+ self.assertEqual(len(ref_grads), len(acc_grads))
+ for ref_grad, acc_grad in zip(ref_grads, acc_grads):
+ self.assertEqual(ref_grad.device, acc_grad.device)
+ self.assertEqual(ref_grad.size(), acc_grad.size())
+ self.assertEqual(ref_grad.dtype, acc_grad.dtype)
+ torch.testing.assert_close(ref_grad, acc_grad)
- batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
- batches: List[Tuple[torch.Tensor, ...]] = [batch]
- num_iters_to_acc = sum(config.num_iters for config in configs)
- for _ in range(num_iters_to_acc - 1):
- batches.append(tuple(permute_tensor(t) for t in batch))
- for batch1, batch2 in itertools.combinations(batches, r=2):
- for t1, t2 in zip(batch1, batch2):
- assert not torch.all(
- t1 == t2
- ), "Check the test to make sure that batches are distinct"
-
- # Concatenate the batches along the given batch dimension
- concat_batch: Tuple[torch.Tensor, ...] = tuple(
- torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
- )
-
- # Establish reference gradients using the concatenated batch
- fsdp_model.zero_grad()
- output = fsdp_model(*concat_batch)
- ref_loss = fsdp_model.module.get_loss(concat_batch, output)
- ref_loss.backward()
- ref_grads = [
- p.grad.detach().clone()
- for p in fsdp_model.parameters()
- if p.grad is not None
- ]
-
- # Compute and accumulate the gradients
- fsdp_model.zero_grad()
- losses = []
- batch_idx = 0
- for config in configs:
- sync_context = (
- fsdp_model.no_sync()
- if config.use_no_sync
- else contextlib.suppress()
- )
- with sync_context:
- for _ in range(config.num_iters):
- if batch_idx == num_iters_to_acc - 1:
- break # always sync on the last iteration
- batch = batches[batch_idx]
- batch_idx += 1
- output = fsdp_model(*batch)
- loss = fsdp_model.module.get_loss(batch, output)
- loss.backward()
- losses.append(loss)
- output = fsdp_model(*batches[-1])
- loss = fsdp_model.module.get_loss(batches[-1], output)
- loss.backward()
- losses.append(loss)
- acc_loss = sum(losses)
- acc_grads = [
- p.grad.detach().clone()
- for p in fsdp_model.parameters()
- if p.grad is not None
- ]
-
- # Compare the losses and gradients
- torch.testing.assert_close(ref_loss, acc_loss)
- self.assertEqual(len(ref_grads), len(acc_grads))
- for ref_grad, acc_grad in zip(ref_grads, acc_grads):
- self.assertEqual(ref_grad.device, acc_grad.device)
- self.assertEqual(ref_grad.size(), acc_grad.size())
- self.assertEqual(ref_grad.dtype, acc_grad.dtype)
- torch.testing.assert_close(ref_grad, acc_grad)
-
- # Check that the optimizer step does not error
- optim.step()
- finally:
- torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32
+ # Check that the optimizer step does not error
+ optim.step()
def _get_subtest_config(self) -> Dict[str, List[Any]]:
"""Returns a subtest configuration that subtests prefetching."""