[optim] Fix UT and SGD kernel (#124904)
- The original `test_grad_scaling_autocast_fused_optimizers` did not actually test the fused path, because "fused" was never set in `optim_inputs`.
- The control and scaling runs should use separate `GradScaler`s; they should not share one `scale` (see the first sketch after this list). The sharing was never exposed as a bug because the default `_growth_interval` is 2000, so the scale never grows, and no inf is ever found, so it is never reduced. The version in `test_cuda.py` has the same issue.
- Set a manual seed so that any numerical failure is reproducible.
- Use `TensorTracker` for the comparisons because this UT failed in the dynamo case: the generated C++ code does not produce exactly the same results as the fused/non-fused kernels.
- Make the test cover both `cuda` and `cpu` (see the second sketch after this list).
- Found an SGD numerical issue when building with `clang`, and fixed it by using `fmadd` instead of separate `add`/`mul` in the fused SGD vectorized kernel.
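
A minimal sketch of the separate-scaler point above (the toy models, data, and hyperparameters are hypothetical; the real test builds them via `_create_scaling_case`): each run owns its own `GradScaler`, so one run's inf check or scale update can never rescale the other run's gradients.

```python
import torch

# Hypothetical toy setup standing in for _create_scaling_case: two identical
# models whose parameters should stay numerically in sync step by step.
mod_control = torch.nn.Linear(8, 8)
mod_scaling = torch.nn.Linear(8, 8)
mod_scaling.load_state_dict(mod_control.state_dict())

opt_control = torch.optim.SGD(mod_control.parameters(), lr=0.1)
opt_scaling = torch.optim.SGD(mod_scaling.parameters(), lr=0.1, fused=True)

# One scaler per run: if they shared a scaler, an inf (scale backoff) or a
# growth step triggered by one run would silently rescale the other run too.
scaler_control = torch.amp.GradScaler("cpu", init_scale=128.0)
scaler_scaling = torch.amp.GradScaler("cpu", init_scale=128.0)

loss_fn = torch.nn.MSELoss()
data = [(torch.randn(4, 8), torch.randn(4, 8)) for _ in range(3)]

for inp, target in data:
    for mod, opt, scaler in (
        (mod_control, opt_control, scaler_control),
        (mod_scaling, opt_scaling, scaler_scaling),
    ):
        opt.zero_grad()
        with torch.autocast(device_type="cpu", dtype=torch.half):
            out = mod(inp)
        loss = loss_fn(out.float(), target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()  # adjusts only this run's scale
```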
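
And a second sketch of the device-generic shape the test takes under `@onlyNativeDeviceTypes` (the helper name and toy model are made up for illustration): the same body runs on `cpu` and `cuda` because both the scaler and the autocast context take the device as a parameter.

```python
import torch

def one_scaled_step(device: str) -> None:
    """Run a single gradient-scaled optimizer step on the given device."""
    model = torch.nn.Linear(4, 4, device=device)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.amp.GradScaler(device, init_scale=128.0)
    inp = torch.randn(2, 4, device=device)
    with torch.autocast(device_type=device, dtype=torch.half):
        out = model(inp)
    loss = out.float().pow(2).mean()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

one_scaled_step("cpu")
if torch.cuda.is_available():
    one_scaled_step("cuda")
```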
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124904
Approved by: https://github.com/jgong5, https://github.com/janeyx99
diff --git a/aten/src/ATen/native/cpu/FusedSGDKernel.cpp b/aten/src/ATen/native/cpu/FusedSGDKernel.cpp
index 3383585..c19aa24 100644
--- a/aten/src/ATen/native/cpu/FusedSGDKernel.cpp
+++ b/aten/src/ATen/native/cpu/FusedSGDKernel.cpp
@@ -52,8 +52,8 @@
grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
}
if (weight_decay != 0.0){
- grad_vec1 += param_vec1 * fVec(scalar_t(weight_decay));
- grad_vec2 += param_vec2 * fVec(scalar_t(weight_decay));
+ grad_vec1 = vec::fmadd(param_vec1, fVec(scalar_t(weight_decay)), grad_vec1);
+ grad_vec2 = vec::fmadd(param_vec2, fVec(scalar_t(weight_decay)), grad_vec2);
}
if (momentum != 0.0) {
fVec momentum_vec1, momentum_vec2;
@@ -61,17 +61,16 @@
momentum_vec1 = grad_vec1;
momentum_vec2 = grad_vec2;
} else {
- momentum_vec1 =
- fVec::loadu(momentum_buf_ptr + d) * fVec(scalar_t(momentum)) +
- grad_vec1 * fVec(scalar_t(1 - dampening));
- momentum_vec2 =
- fVec::loadu(momentum_buf_ptr + d + fVec::size()) * fVec(scalar_t(momentum)) +
- grad_vec2 * fVec(scalar_t(1 - dampening));
+
+ momentum_vec1 = fVec::loadu(momentum_buf_ptr + d) * fVec(scalar_t(momentum));
+ momentum_vec2 = fVec::loadu(momentum_buf_ptr + d + fVec::size()) * fVec(scalar_t(momentum));
+ momentum_vec1 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec1, momentum_vec1);
+ momentum_vec2 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec2, momentum_vec2);
}
vec::convert_from_float<scalar_t>(momentum_vec1, momentum_vec2).store(momentum_buf_ptr + d);;
if (nesterov) {
- grad_vec1 += momentum_vec1 * fVec(scalar_t(momentum));
- grad_vec2 += momentum_vec2 * fVec(scalar_t(momentum));
+ grad_vec1 = vec::fmadd(momentum_vec1, fVec(scalar_t(momentum)), grad_vec1);
+ grad_vec2 = vec::fmadd(momentum_vec2, fVec(scalar_t(momentum)), grad_vec2);
} else {
grad_vec1 = momentum_vec1;
grad_vec2 = momentum_vec2;
@@ -142,7 +141,7 @@
}
if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
if (weight_decay != 0.0){
- grad_vec += param_vec * Vec(scalar_t(weight_decay));
+ grad_vec = vec::fmadd(param_vec, Vec(scalar_t(weight_decay)), grad_vec);
}
if (momentum != 0.0) {
Vec momentum_vec;
@@ -150,12 +149,12 @@
momentum_vec = grad_vec;
} else {
momentum_vec =
- Vec::loadu(momentum_buf_ptr + d) * Vec(scalar_t(momentum)) +
- grad_vec * Vec(scalar_t(1 - dampening));
+ Vec::loadu(momentum_buf_ptr + d) * Vec(scalar_t(momentum));
+ momentum_vec = vec::fmadd(Vec(scalar_t(1 - dampening)), grad_vec, momentum_vec);
}
momentum_vec.store(momentum_buf_ptr + d);
if (nesterov) {
- grad_vec += momentum_vec * Vec(scalar_t(momentum));
+ grad_vec = vec::fmadd(momentum_vec, Vec(scalar_t(momentum)), grad_vec);
} else {
grad_vec = momentum_vec;
}
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 24acfb0..778bdd3 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -29,7 +29,6 @@
)
from torch.testing._internal.autocast_test_lists import AutocastTestLists
from torch.testing._internal.common_cuda import (
- _create_scaling_case,
_get_torch_cuda_version,
TEST_CUDNN,
TEST_MULTIGPU,
@@ -1274,109 +1273,6 @@
)
self.assertTrue(r != 0)
- # Compare non-fused optimizer vs fused one as the fused one unscales gradients
- # inside its cuda kernel unlike the other.
- def test_grad_scaling_autocast_fused_optimizers(self):
- for optimizer_ctor, optimizer_kwargs, separate_unscale in list(
- product(
- (torch.optim.Adam, torch.optim.AdamW),
- ({"fused": True, "amsgrad": False}, {"fused": True, "amsgrad": True}),
- (False, True),
- )
- ) + list(
- product(
- (torch.optim.SGD,),
- [
- {
- "momentum": 0.0,
- "dampening": d,
- "weight_decay": w,
- "nesterov": n,
- "fused": True,
- }
- for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
- ]
- + [
- {
- "momentum": 0.5,
- "dampening": d,
- "weight_decay": w,
- "nesterov": n,
- "fused": True,
- }
- for d, w, n in product((0.0,), (0.0, 0.5), (True, False))
- ],
- (False, True),
- )
- ):
- with self.subTest(
- optim=optimizer_ctor,
- kwargs=optimizer_kwargs,
- separate_unscale=separate_unscale,
- ):
- self._grad_scaling_autocast_fused_optimizers(
- optimizer_ctor=optimizer_ctor,
- optimizer_kwargs=optimizer_kwargs,
- separate_unscale=separate_unscale,
- )
-
- def _grad_scaling_autocast_fused_optimizers(
- self, optimizer_ctor, optimizer_kwargs, separate_unscale
- ):
- (
- mod_control,
- mod_scaling,
- opt_control,
- opt_scaling,
- data,
- loss_fn,
- _,
- ) = _create_scaling_case(
- optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs
- )
- kwargs = deepcopy(optimizer_kwargs)
- kwargs["fused"] = False
- opt_control = optimizer_ctor(mod_control.parameters(), lr=1.0, **kwargs)
-
- scaler = torch.cuda.amp.GradScaler(init_scale=128.0)
-
- for input, target in data:
- opt_control.zero_grad()
- with torch.autocast("cuda"):
- output_control = mod_control(input)
- loss_control = loss_fn(output_control, target)
- scaler.scale(loss_control).backward()
- scaler.step(opt_control)
- scaler.update()
-
- opt_scaling.zero_grad()
- with torch.autocast("cuda"):
- output_scaling = mod_scaling(input)
- loss_scaling = loss_fn(output_scaling, target)
- scaler.scale(loss_scaling).backward()
- if separate_unscale:
- scaler.unscale_(opt_scaling)
- scaler.step(opt_scaling)
- scaler.update()
-
- self.assertEqual(loss_control, loss_scaling)
- for param_control, param_scaling in zip(
- mod_control.parameters(), mod_scaling.parameters()
- ):
- self.assertEqual(param_control.grad, param_scaling.grad)
- self.assertEqual(param_control, param_scaling)
-
- state_control, state_scaling = (
- opt_control.state[param_control],
- opt_scaling.state[param_scaling],
- )
-
- for k in state_control:
- actual = state_scaling[k]
- if k == "step":
- actual = actual.squeeze()
- self.assertEqual(state_control[k], actual)
-
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "FAIL")
def test_cublas_multiple_threads_same_device(self):
# Note, these parameters should be very carefully tuned
diff --git a/test/test_optim.py b/test/test_optim.py
index 031f2aa..709c28f 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -1677,7 +1677,7 @@
optimizers.append(optimizer)
self._compare_between(inpts, models, optimizers)
- @onlyCPU
+ @onlyNativeDeviceTypes
@optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32])
def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
# This ut is from test_cuda.py test_grad_scaling_autocast_fused_optimizers
@@ -1689,11 +1689,13 @@
optim_cls = optim_info.optim_cls
for optim_input in optim_inputs:
kwargs = optim_input.kwargs
+ kwargs["fused"] = True
for _separate_unscale in (True, False):
self._grad_scaling_autocast_fused_optimizers(
- optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, separate_unscale=_separate_unscale)
+ device=device, optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, separate_unscale=_separate_unscale)
- def _grad_scaling_autocast_fused_optimizers(self, optimizer_ctor, optimizer_kwargs, separate_unscale):
+ def _grad_scaling_autocast_fused_optimizers(self, device, optimizer_ctor, optimizer_kwargs, separate_unscale):
+ torch.manual_seed(20)
(
mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, _,
) = _create_scaling_case(optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, device='cpu')
@@ -1704,30 +1706,35 @@
kwargs['lr'] = 1.0
opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
- scaler = torch.cpu.amp.GradScaler(init_scale=128.0)
+ scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0)
+ scaler_control = torch.amp.GradScaler(device, init_scale=128.0)
+ tracker = TensorTracker()
for input, target in data:
opt_control.zero_grad()
- with torch.autocast('cpu', dtype=torch.half):
+ with torch.autocast(device_type=device, dtype=torch.half):
output_control = mod_control(input)
loss_control = loss_fn(output_control, target)
- scaler.scale(loss_control).backward()
- scaler.step(opt_control)
- scaler.update()
+ scaler_control.scale(loss_control).backward()
+ scaler_control.step(opt_control)
+ scaler_control.update()
opt_scaling.zero_grad()
- with torch.autocast('cpu', dtype=torch.half):
+ with torch.autocast(device_type=device, dtype=torch.half):
output_scaling = mod_scaling(input)
loss_scaling = loss_fn(output_scaling, target)
- scaler.scale(loss_scaling).backward()
+ scaler_scaling.scale(loss_scaling).backward()
if separate_unscale:
- scaler.unscale_(opt_scaling)
- scaler.step(opt_scaling)
- scaler.update()
+ scaler_scaling.unscale_(opt_scaling)
+ scaler_scaling.step(opt_scaling)
+ scaler_scaling.update()
- self.assertEqual(loss_control, loss_scaling,)
+ tracker.add(loss_control)
+ tracker.pop_check_set(loss_scaling, self)
for param_control, param_scaling in zip(mod_control.parameters(), mod_scaling.parameters()):
- self.assertEqual(param_control.grad, param_scaling.grad,)
- self.assertEqual(param_control, param_scaling,)
+ tracker.add(param_control.grad)
+ tracker.pop_check_set(param_scaling.grad, self)
+ tracker.add(param_control)
+ tracker.pop_check_set(param_scaling, self)
state_control, state_scaling = opt_control.state[param_control], opt_scaling.state[param_scaling]
@@ -1735,7 +1742,8 @@
actual = state_scaling[k]
if k == "step":
actual = actual.squeeze()
- self.assertEqual(state_control[k], actual,)
+ tracker.add(state_control[k])
+ tracker.pop_check_set(actual, self)
@onlyCUDA
@optims([o for o in optim_db if "foreach" in o.supported_impls], dtypes=[torch.float32])