[optim] Fix UT and fused SGD kernel (#124904)

 - The original `test_grad_scaling_autocast_fused_optimizers` did not actually test the fused path, because there is no `"fused"` entry in `optim_inputs`; the test now sets `kwargs["fused"] = True` explicitly.
 - The control and scaling runs should use separate `GradScaler`s; they should not share one `scale`. The bug was masked because the default `_growth_interval` is 2000, so the scale never grows, and no inf is ever found, so it never backs off either. The version in `test_cuda.py` has the same issue (see the GradScaler sketch below this list).
 - I set a manual seed so that any numerical failure is reproducible.
 - I use `TensorTracker` here because this UT failed in the dynamo case: the generated C++ code is not exactly the same for the fused and non-fused kernels.
 - I make the test check both `cuda` and `cpu`.
 - I found an SGD numerical issue with `clang` and fixed it by using `fmadd` instead of separate `add`/`mul` in the fused SGD vec kernel (see the fmadd sketch below this list).
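
For the separate-scaler bullet, a small illustrative sketch of the setup the test now uses (not part of the patch; assumes a PyTorch recent enough that `torch.amp.GradScaler` takes a device string and CPU float16 autocast is available; the model and data here are made up):

```python
import torch

# One GradScaler per optimizer: each instance keeps its own scale,
# growth tracker, and found-inf bookkeeping, so the control (non-fused)
# and the fused run cannot influence each other's scale updates.
scaler_control = torch.amp.GradScaler("cpu", init_scale=128.0)
scaler_scaling = torch.amp.GradScaler("cpu", init_scale=128.0)

model = torch.nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(4, 8), torch.randn(4, 8)

with torch.autocast("cpu", dtype=torch.half):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler_control.scale(loss).backward()
scaler_control.step(opt)   # unscales grads, skips the step if inf/nan is found
scaler_control.update()    # only touches scaler_control's scale state
```

With a single shared scaler, `update()` would run twice per iteration across the two runs, so its growth/backoff bookkeeping would match neither optimizer on its own.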
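
For the fmadd bullet, a minimal standalone sketch (plain C++ with `std::fma`, not the actual ATen vec types) of why the explicit fused multiply-add matters: `fma` rounds once after the multiply-add, while a separate multiply and add rounds twice, and whether the compiler contracts the latter into an FMA depends on flags like `-ffp-contract`. Under `clang` this can make the fused and reference SGD paths bitwise-different:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  float grad = 0.1f, param = 0.3f, weight_decay = 0.7f;

  // Two roundings: the product is rounded to float, then the sum is rounded.
  // Depending on -ffp-contract, the compiler may or may not contract this
  // into a single FMA instruction.
  float mul_add = grad + param * weight_decay;

  // One rounding: std::fma computes param * weight_decay + grad exactly and
  // rounds once -- analogous to what vec::fmadd does in the kernel.
  float fused = std::fma(param, weight_decay, grad);

  // The two results may differ in the last ULP, which is enough for a
  // strict comparison of fused vs. non-fused SGD to fail.
  std::printf("mul+add: %.9g\nfma:     %.9g\n", mul_add, fused);
  return 0;
}
```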

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124904
Approved by: https://github.com/jgong5, https://github.com/janeyx99
diff --git a/aten/src/ATen/native/cpu/FusedSGDKernel.cpp b/aten/src/ATen/native/cpu/FusedSGDKernel.cpp
index 3383585..c19aa24 100644
--- a/aten/src/ATen/native/cpu/FusedSGDKernel.cpp
+++ b/aten/src/ATen/native/cpu/FusedSGDKernel.cpp
@@ -52,8 +52,8 @@
       grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
     }
     if (weight_decay != 0.0){
-      grad_vec1 += param_vec1 * fVec(scalar_t(weight_decay));
-      grad_vec2 += param_vec2 * fVec(scalar_t(weight_decay));
+      grad_vec1 = vec::fmadd(param_vec1, fVec(scalar_t(weight_decay)), grad_vec1);
+      grad_vec2 = vec::fmadd(param_vec2, fVec(scalar_t(weight_decay)), grad_vec2);
     }
     if (momentum != 0.0) {
       fVec momentum_vec1, momentum_vec2;
@@ -61,17 +61,16 @@
         momentum_vec1 = grad_vec1;
         momentum_vec2 = grad_vec2;
       } else {
-        momentum_vec1 =
-            fVec::loadu(momentum_buf_ptr + d) * fVec(scalar_t(momentum)) +
-            grad_vec1 * fVec(scalar_t(1 - dampening));
-        momentum_vec2 =
-            fVec::loadu(momentum_buf_ptr + d + fVec::size()) * fVec(scalar_t(momentum)) +
-            grad_vec2 * fVec(scalar_t(1 - dampening));
+
+        momentum_vec1 = fVec::loadu(momentum_buf_ptr + d) * fVec(scalar_t(momentum));
+        momentum_vec2 = fVec::loadu(momentum_buf_ptr + d + fVec::size()) * fVec(scalar_t(momentum));
+        momentum_vec1 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec1, momentum_vec1);
+        momentum_vec2 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec2, momentum_vec2);
       }
       vec::convert_from_float<scalar_t>(momentum_vec1, momentum_vec2).store(momentum_buf_ptr + d);;
       if (nesterov) {
-        grad_vec1 += momentum_vec1 * fVec(scalar_t(momentum));
-        grad_vec2 += momentum_vec2 * fVec(scalar_t(momentum));
+        grad_vec1 = vec::fmadd(momentum_vec1, fVec(scalar_t(momentum)), grad_vec1);
+        grad_vec2 = vec::fmadd(momentum_vec2, fVec(scalar_t(momentum)), grad_vec2);
       } else {
         grad_vec1 = momentum_vec1;
         grad_vec2 = momentum_vec2;
@@ -142,7 +141,7 @@
     }
     if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
     if (weight_decay != 0.0){
-      grad_vec += param_vec * Vec(scalar_t(weight_decay));
+      grad_vec = vec::fmadd(param_vec, Vec(scalar_t(weight_decay)), grad_vec);
     }
     if (momentum != 0.0) {
       Vec momentum_vec;
@@ -150,12 +149,12 @@
         momentum_vec = grad_vec;
       } else {
         momentum_vec =
-            Vec::loadu(momentum_buf_ptr + d) * Vec(scalar_t(momentum)) +
-            grad_vec * Vec(scalar_t(1 - dampening));
+            Vec::loadu(momentum_buf_ptr + d) * Vec(scalar_t(momentum));
+        momentum_vec = vec::fmadd(Vec(scalar_t(1 - dampening)), grad_vec, momentum_vec);
       }
       momentum_vec.store(momentum_buf_ptr + d);
       if (nesterov) {
-        grad_vec += momentum_vec * Vec(scalar_t(momentum));
+        grad_vec =  vec::fmadd(momentum_vec, Vec(scalar_t(momentum)), grad_vec);
       } else {
         grad_vec = momentum_vec;
       }
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 24acfb0..778bdd3 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -29,7 +29,6 @@
 )
 from torch.testing._internal.autocast_test_lists import AutocastTestLists
 from torch.testing._internal.common_cuda import (
-    _create_scaling_case,
     _get_torch_cuda_version,
     TEST_CUDNN,
     TEST_MULTIGPU,
@@ -1274,109 +1273,6 @@
                 )
                 self.assertTrue(r != 0)
 
-    # Compare non-fused optimizer vs fused one as the fused one unscales gradients
-    # inside its cuda kernel unlike the other.
-    def test_grad_scaling_autocast_fused_optimizers(self):
-        for optimizer_ctor, optimizer_kwargs, separate_unscale in list(
-            product(
-                (torch.optim.Adam, torch.optim.AdamW),
-                ({"fused": True, "amsgrad": False}, {"fused": True, "amsgrad": True}),
-                (False, True),
-            )
-        ) + list(
-            product(
-                (torch.optim.SGD,),
-                [
-                    {
-                        "momentum": 0.0,
-                        "dampening": d,
-                        "weight_decay": w,
-                        "nesterov": n,
-                        "fused": True,
-                    }
-                    for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
-                ]
-                + [
-                    {
-                        "momentum": 0.5,
-                        "dampening": d,
-                        "weight_decay": w,
-                        "nesterov": n,
-                        "fused": True,
-                    }
-                    for d, w, n in product((0.0,), (0.0, 0.5), (True, False))
-                ],
-                (False, True),
-            )
-        ):
-            with self.subTest(
-                optim=optimizer_ctor,
-                kwargs=optimizer_kwargs,
-                separate_unscale=separate_unscale,
-            ):
-                self._grad_scaling_autocast_fused_optimizers(
-                    optimizer_ctor=optimizer_ctor,
-                    optimizer_kwargs=optimizer_kwargs,
-                    separate_unscale=separate_unscale,
-                )
-
-    def _grad_scaling_autocast_fused_optimizers(
-        self, optimizer_ctor, optimizer_kwargs, separate_unscale
-    ):
-        (
-            mod_control,
-            mod_scaling,
-            opt_control,
-            opt_scaling,
-            data,
-            loss_fn,
-            _,
-        ) = _create_scaling_case(
-            optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs
-        )
-        kwargs = deepcopy(optimizer_kwargs)
-        kwargs["fused"] = False
-        opt_control = optimizer_ctor(mod_control.parameters(), lr=1.0, **kwargs)
-
-        scaler = torch.cuda.amp.GradScaler(init_scale=128.0)
-
-        for input, target in data:
-            opt_control.zero_grad()
-            with torch.autocast("cuda"):
-                output_control = mod_control(input)
-                loss_control = loss_fn(output_control, target)
-            scaler.scale(loss_control).backward()
-            scaler.step(opt_control)
-            scaler.update()
-
-            opt_scaling.zero_grad()
-            with torch.autocast("cuda"):
-                output_scaling = mod_scaling(input)
-                loss_scaling = loss_fn(output_scaling, target)
-            scaler.scale(loss_scaling).backward()
-            if separate_unscale:
-                scaler.unscale_(opt_scaling)
-            scaler.step(opt_scaling)
-            scaler.update()
-
-            self.assertEqual(loss_control, loss_scaling)
-            for param_control, param_scaling in zip(
-                mod_control.parameters(), mod_scaling.parameters()
-            ):
-                self.assertEqual(param_control.grad, param_scaling.grad)
-                self.assertEqual(param_control, param_scaling)
-
-                state_control, state_scaling = (
-                    opt_control.state[param_control],
-                    opt_scaling.state[param_scaling],
-                )
-
-                for k in state_control:
-                    actual = state_scaling[k]
-                    if k == "step":
-                        actual = actual.squeeze()
-                    self.assertEqual(state_control[k], actual)
-
     @unittest.skipIf(TEST_CUDAMALLOCASYNC, "FAIL")
     def test_cublas_multiple_threads_same_device(self):
         # Note, these parameters should be very carefully tuned
diff --git a/test/test_optim.py b/test/test_optim.py
index 031f2aa..709c28f 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -1677,7 +1677,7 @@
                 optimizers.append(optimizer)
         self._compare_between(inpts, models, optimizers)
 
-    @onlyCPU
+    @onlyNativeDeviceTypes
     @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32])
     def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
         # This ut is from test_cuda.py test_grad_scaling_autocast_fused_optimizers
@@ -1689,11 +1689,13 @@
         optim_cls = optim_info.optim_cls
         for optim_input in optim_inputs:
             kwargs = optim_input.kwargs
+            kwargs["fused"] = True
             for _separate_unscale in (True, False):
                 self._grad_scaling_autocast_fused_optimizers(
-                    optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, separate_unscale=_separate_unscale)
+                    device=device, optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, separate_unscale=_separate_unscale)
 
-    def _grad_scaling_autocast_fused_optimizers(self, optimizer_ctor, optimizer_kwargs, separate_unscale):
+    def _grad_scaling_autocast_fused_optimizers(self, device, optimizer_ctor, optimizer_kwargs, separate_unscale):
+        torch.manual_seed(20)
         (
             mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, _,
         ) = _create_scaling_case(optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, device='cpu')
@@ -1704,30 +1706,35 @@
             kwargs['lr'] = 1.0
         opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
 
-        scaler = torch.cpu.amp.GradScaler(init_scale=128.0)
+        scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0)
+        scaler_control = torch.amp.GradScaler(device, init_scale=128.0)
+        tracker = TensorTracker()
         for input, target in data:
             opt_control.zero_grad()
-            with torch.autocast('cpu', dtype=torch.half):
+            with torch.autocast(device_type=device, dtype=torch.half):
                 output_control = mod_control(input)
                 loss_control = loss_fn(output_control, target)
-            scaler.scale(loss_control).backward()
-            scaler.step(opt_control)
-            scaler.update()
+            scaler_control.scale(loss_control).backward()
+            scaler_control.step(opt_control)
+            scaler_control.update()
 
             opt_scaling.zero_grad()
-            with torch.autocast('cpu', dtype=torch.half):
+            with torch.autocast(device_type=device, dtype=torch.half):
                 output_scaling = mod_scaling(input)
                 loss_scaling = loss_fn(output_scaling, target)
-            scaler.scale(loss_scaling).backward()
+            scaler_scaling.scale(loss_scaling).backward()
             if separate_unscale:
-                scaler.unscale_(opt_scaling)
-            scaler.step(opt_scaling)
-            scaler.update()
+                scaler_scaling.unscale_(opt_scaling)
+            scaler_scaling.step(opt_scaling)
+            scaler_scaling.update()
 
-            self.assertEqual(loss_control, loss_scaling,)
+            tracker.add(loss_control)
+            tracker.pop_check_set(loss_scaling, self)
             for param_control, param_scaling in zip(mod_control.parameters(), mod_scaling.parameters()):
-                self.assertEqual(param_control.grad, param_scaling.grad,)
-                self.assertEqual(param_control, param_scaling,)
+                tracker.add(param_control.grad)
+                tracker.pop_check_set(param_scaling.grad, self)
+                tracker.add(param_control)
+                tracker.pop_check_set(param_scaling, self)
 
                 state_control, state_scaling = opt_control.state[param_control], opt_scaling.state[param_scaling]
 
@@ -1735,7 +1742,8 @@
                     actual = state_scaling[k]
                     if k == "step":
                         actual = actual.squeeze()
-                    self.assertEqual(state_control[k], actual,)
+                    tracker.add(state_control[k])
+                    tracker.pop_check_set(actual, self)
 
     @onlyCUDA
     @optims([o for o in optim_db if "foreach" in o.supported_impls], dtypes=[torch.float32])