[FSDP] Remove unneeded disable of tf32 (#100179)

I recall needing to disable TF32, but I can no longer reproduce the issue. Nowhere else in our unit tests do we disable TF32, so we can try removing this workaround.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100179
Approved by: https://github.com/rohan-varma
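
For context, the removed code scoped the flag with a try/finally. Should floating point drift ever resurface, a minimal sketch of a reusable context manager that restores the flag on exit (not part of this PR; the helper name is illustrative):

    import contextlib
    import torch

    @contextlib.contextmanager
    def tf32_matmul_disabled():
        # Save and clear only the matmul TF32 flag; restore it even if the
        # wrapped test body raises.
        old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False
        try:
            yield
        finally:
            torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32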
diff --git a/test/distributed/fsdp/test_fsdp_grad_acc.py b/test/distributed/fsdp/test_fsdp_grad_acc.py
index 42d1b7d..f5eb0df 100644
--- a/test/distributed/fsdp/test_fsdp_grad_acc.py
+++ b/test/distributed/fsdp/test_fsdp_grad_acc.py
@@ -123,109 +123,100 @@
             not config.use_no_sync for config in configs
         ):
             return
-        old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
-        try:
-            # Disable TF32 to prevent floating point drift
-            torch.backends.cuda.matmul.allow_tf32 = False
+        # Initialize the FSDP model and optimizer
+        fsdp_kwargs = {
+            "cpu_offload": cpu_offload,
+            "backward_prefetch": backward_prefetch,
+            "sharding_strategy": sharding_strategy,
+            "use_orig_params": use_orig_params,
+        }
+        fsdp_model: FSDP = TransformerWithSharedParams.init(
+            self.process_group,
+            FSDPInitMode.RECURSIVE,
+            CUDAInitMode.CUDA_BEFORE,
+            fsdp_kwargs,
+            deterministic=True,
+            add_bn=False,  # disable BN since the test uses varying batch sizes
+        )
+        device = torch.device("cuda")
+        optim = torch.optim.SGD(
+            fsdp_model.parameters(),
+            lr=0.01,
+            momentum=0.9,
+        )
 
-            # Initialize the FSDP model and optimizer
-            fsdp_kwargs = {
-                "cpu_offload": cpu_offload,
-                "backward_prefetch": backward_prefetch,
-                "sharding_strategy": sharding_strategy,
-                "use_orig_params": use_orig_params,
-            }
-            fsdp_model: FSDP = TransformerWithSharedParams.init(
-                self.process_group,
-                FSDPInitMode.RECURSIVE,
-                CUDAInitMode.CUDA_BEFORE,
-                fsdp_kwargs,
-                deterministic=True,
-                add_bn=False,  # disable BN since the test uses varying batch sizes
+        # Generate the sequence of batches, each containing the same data
+        # but permuted
+        def permute_tensor(x: torch.Tensor):
+            return x.view(-1)[torch.randperm(x.numel())].view_as(x)
+
+        batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
+        batches: List[Tuple[torch.Tensor, ...]] = [batch]
+        num_iters_to_acc = sum(config.num_iters for config in configs)
+        for _ in range(num_iters_to_acc - 1):
+            batches.append(tuple(permute_tensor(t) for t in batch))
+        for batch1, batch2 in itertools.combinations(batches, r=2):
+            for t1, t2 in zip(batch1, batch2):
+                assert not torch.all(
+                    t1 == t2
+                ), "Check the test to make sure that batches are distinct"
+
+        # Concatenate the batches along the given batch dimension
+        concat_batch: Tuple[torch.Tensor, ...] = tuple(
+            torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
+        )
+
+        # Establish reference gradients using the concatenated batch
+        fsdp_model.zero_grad()
+        output = fsdp_model(*concat_batch)
+        ref_loss = fsdp_model.module.get_loss(concat_batch, output)
+        ref_loss.backward()
+        ref_grads = [
+            p.grad.detach().clone()
+            for p in fsdp_model.parameters()
+            if p.grad is not None
+        ]
+
+        # Compute and accumulate the gradients
+        fsdp_model.zero_grad()
+        losses = []
+        batch_idx = 0
+        for config in configs:
+            sync_context = (
+                fsdp_model.no_sync() if config.use_no_sync else contextlib.suppress()
             )
-            device = torch.device("cuda")
-            optim = torch.optim.SGD(
-                fsdp_model.parameters(),
-                lr=0.01,
-                momentum=0.9,
-            )
+            with sync_context:
+                for _ in range(config.num_iters):
+                    if batch_idx == num_iters_to_acc - 1:
+                        break  # always sync on the last iteration
+                    batch = batches[batch_idx]
+                    batch_idx += 1
+                    output = fsdp_model(*batch)
+                    loss = fsdp_model.module.get_loss(batch, output)
+                    loss.backward()
+                    losses.append(loss)
+        output = fsdp_model(*batches[-1])
+        loss = fsdp_model.module.get_loss(batches[-1], output)
+        loss.backward()
+        losses.append(loss)
+        acc_loss = sum(losses)
+        acc_grads = [
+            p.grad.detach().clone()
+            for p in fsdp_model.parameters()
+            if p.grad is not None
+        ]
 
-            # Generate the sequence of batches, each containing the same data
-            # but permuted
-            def permute_tensor(x: torch.Tensor):
-                return x.view(-1)[torch.randperm(x.numel())].view_as(x)
+        # Compare the losses and gradients
+        torch.testing.assert_close(ref_loss, acc_loss)
+        self.assertEqual(len(ref_grads), len(acc_grads))
+        for ref_grad, acc_grad in zip(ref_grads, acc_grads):
+            self.assertEqual(ref_grad.device, acc_grad.device)
+            self.assertEqual(ref_grad.size(), acc_grad.size())
+            self.assertEqual(ref_grad.dtype, acc_grad.dtype)
+            torch.testing.assert_close(ref_grad, acc_grad)
 
-            batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
-            batches: List[Tuple[torch.Tensor, ...]] = [batch]
-            num_iters_to_acc = sum(config.num_iters for config in configs)
-            for _ in range(num_iters_to_acc - 1):
-                batches.append(tuple(permute_tensor(t) for t in batch))
-            for batch1, batch2 in itertools.combinations(batches, r=2):
-                for t1, t2 in zip(batch1, batch2):
-                    assert not torch.all(
-                        t1 == t2
-                    ), "Check the test to make sure that batches are distinct"
-
-            # Concatenate the batches along the given batch dimension
-            concat_batch: Tuple[torch.Tensor, ...] = tuple(
-                torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
-            )
-
-            # Establish reference gradients using the concatenated batch
-            fsdp_model.zero_grad()
-            output = fsdp_model(*concat_batch)
-            ref_loss = fsdp_model.module.get_loss(concat_batch, output)
-            ref_loss.backward()
-            ref_grads = [
-                p.grad.detach().clone()
-                for p in fsdp_model.parameters()
-                if p.grad is not None
-            ]
-
-            # Compute and accumulate the gradients
-            fsdp_model.zero_grad()
-            losses = []
-            batch_idx = 0
-            for config in configs:
-                sync_context = (
-                    fsdp_model.no_sync()
-                    if config.use_no_sync
-                    else contextlib.suppress()
-                )
-                with sync_context:
-                    for _ in range(config.num_iters):
-                        if batch_idx == num_iters_to_acc - 1:
-                            break  # always sync on the last iteration
-                        batch = batches[batch_idx]
-                        batch_idx += 1
-                        output = fsdp_model(*batch)
-                        loss = fsdp_model.module.get_loss(batch, output)
-                        loss.backward()
-                        losses.append(loss)
-            output = fsdp_model(*batches[-1])
-            loss = fsdp_model.module.get_loss(batches[-1], output)
-            loss.backward()
-            losses.append(loss)
-            acc_loss = sum(losses)
-            acc_grads = [
-                p.grad.detach().clone()
-                for p in fsdp_model.parameters()
-                if p.grad is not None
-            ]
-
-            # Compare the losses and gradients
-            torch.testing.assert_close(ref_loss, acc_loss)
-            self.assertEqual(len(ref_grads), len(acc_grads))
-            for ref_grad, acc_grad in zip(ref_grads, acc_grads):
-                self.assertEqual(ref_grad.device, acc_grad.device)
-                self.assertEqual(ref_grad.size(), acc_grad.size())
-                self.assertEqual(ref_grad.dtype, acc_grad.dtype)
-                torch.testing.assert_close(ref_grad, acc_grad)
-
-            # Check that the optimizer step does not error
-            optim.step()
-        finally:
-            torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32
+        # Check that the optimizer step does not error
+        optim.step()
 
     def _get_subtest_config(self) -> Dict[str, List[Any]]:
         """Returns a subtest configuration that subtests prefetching."""