Revert D19859905: [pytorch][PR] Gradient scaling API

Test Plan: revert-hammer

Differential Revision:
D19859905

Original commit changeset: bb8ae6966214

fbshipit-source-id: 28f1c93e8a00e3a4bbe8cc981499b15468f0b970
diff --git a/aten/src/ATen/native/cuda/AmpKernels.cu b/aten/src/ATen/native/cuda/AmpKernels.cu
deleted file mode 100644
index e5f2294..0000000
--- a/aten/src/ATen/native/cuda/AmpKernels.cu
+++ /dev/null
@@ -1,130 +0,0 @@
-#define _USE_MATH_DEFINES
-
-#include <math.h>
-
-#include <ATen/ATen.h>
-#include <ATen/Dispatch.h>
-#include <ATen/native/TensorIterator.h>
-#include <ATen/native/cuda/Loops.cuh>
-
-namespace {
-// Thin wrapper around https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE_1g57a3c8313f570282a1a7bcc78743b08e,
-// to ensure the Cuda math library's isfinite is actually what gets called in
-// _amp_non_finite_check_and_unscale_cuda_'s gpu_kernel lambda.
-//
-// isfinite_ensure_cuda_math is defined outside at::native because:
-// - A bare call to "isfinite(val)" inside at::native causes nvcc to prefer the unrelated
-//   Tensor at::native::isfinite(const Tensor&), resulting in an error:
-//   "no suitable constructor exists to convert from "float" to "at::Tensor""
-// - Unfortunately, the Cuda math library documentation doesn't say how (or if) you can provide a full namespace path
-//   to ensure that its version of a particular function is invoked.  It only shows bare (not-namespaced)
-//   calls to its routines inside kernel or device functions.
-// - "std::isfinite(val)" in the gpu_kernel lambda causes an "unspecified launch failure" at runtime with cuda 9 on Windows.
-//
-// isfinite_ensure_cuda_math, declared at file scope outside the at::native region, uses isfinite as math library docs
-// suggest and allows disambiguated usage in the lambda within the at::native region.
-// GPU_LAMBDA is defined as __host__ __device__ (see Loops.cuh), so I need the __host__ keyword or else nvcc complains that
-// "calling a __device__ function("isfinite_ensure_cuda_math") from a __host__ __device__ function("operator()") is not allowed."
-static __host__ __device__ __forceinline__ int isfinite_ensure_cuda_math(float val) {
-  return isfinite(val);
-}
-}
-
-namespace at {
-namespace native {
-
-// Multiplies scaled_grad in-place by inv_scale.  If any element of scaled_grad is inf or NaN, sets found_inf to 1.0.
-//
-// Args:
-// scaled_grad:  A (scaled) gradient tensor.  May contain infs or NaNs.
-// found_inf:  A single-element float tensor to which 1.0 will be written if any gradients contain infs/nans.
-//             Pre-zeroing found_inf, if appropriate, is the responsibility of the caller.
-// inv_scale:  The inverse of the scale factor by which scaled_grad is currently multiplied.
-//
-// Returns:
-// A tuple with references to scaled_grad, which is now unscaled in place, and found_inf,
-// which is now guaranteed to contain 1.0 if an inf or NaN was found in scaled_grad.
-void _amp_non_finite_check_and_unscale_cuda_(Tensor& scaled_grad,
-                                             Tensor& found_inf,
-                                             const Tensor& inv_scale)
-{
-  TORCH_CHECK(scaled_grad.is_cuda(), "scaled_grad must be a CUDA tensor.");
-  TORCH_CHECK(inv_scale.is_cuda(), "inv_scale must be a CUDA tensor.");
-  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
-  TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
-  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
-  TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor.");
-  TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor.");
-  TORCH_CHECK(scaled_grad.layout() == at::kStrided, "scaled_grad must be a strided (not sparse) Tensor.");
-
-  // Act on scaled_grad in place.
-  auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad);
-
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    iter.dtype(),
-    "_amp_non_finite_check_and_unscale_cuda",
-    [&iter, &found_inf, &inv_scale] {
-      auto* found_inf_ptr = found_inf.data_ptr<float>();
-      auto* inv_scale_ptr = inv_scale.data_ptr<float>();
-
-      gpu_kernel(iter, [found_inf_ptr, inv_scale_ptr]GPU_LAMBDA(scalar_t val) -> scalar_t {
-          float fval = static_cast<float>(val);
-          // See isfinite_ensure_cuda_math above.
-          if (!isfinite_ensure_cuda_math(fval)) {
-            *found_inf_ptr = 1.f;
-          }
-          const auto inv_scale_val = *inv_scale_ptr; // Every thread accesses inv_scale, but it will hit in cache.
-          return static_cast<scalar_t>(inv_scale_val == 1.f ? fval : fval*inv_scale_val);
-        });
-    });
-}
-
-
-// amp_update_scale_cuda_kernel is launched with a single thread to compute the new scale.
-// The scale factor is maintained and updated on the GPU to avoid synchronization.
-__global__ void amp_update_scale_cuda_kernel(float* current_scale,
-                                             float* found_inf,
-                                             float* new_scale,
-                                             double scale_growth_factor,
-                                             double scale_backoff_factor)
-{
-  *new_scale = (*found_inf) ? (*current_scale)*scale_backoff_factor : (*current_scale)*scale_growth_factor;
-}
-
-
-// _amp_update_scale_cuda asynchronously updates the scale factor.
-//
-// Args:
-// current_scale:  A one-element torch.cuda.FloatTensor containing the current scale value.
-// found_inf:  A one-element torch.cuda.FloatTensor. If > 0, indicates that infs/nans were found by the relevant
-//             prior _amp_non_finite_check_and_unscale_cuda call, and 0 if no infs/nans were found.
-// scale_growth_factor:  Multiplier if no infs/NaNs were found (typically slightly > 1).
-// scale_backoff_factor:  Multiplier if infs/NaNs were found (typically 0.5).
-//
-// Returns:
-// new_scale:  A new one-element torch.cuda.FloatTensor containing the new recommended scale value.
-Tensor _amp_update_scale_cuda(const Tensor& current_scale,
-                              const Tensor& found_inf,
-                              double scale_growth_factor,
-                              double scale_backoff_factor)
-{
-  TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
-  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
-  TORCH_CHECK(current_scale.numel() == 1, "current_scale must be a 1-element tensor.");
-  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
-  TORCH_CHECK(current_scale.scalar_type() == at::ScalarType::Float, "current_scale must be a float tensor.");
-  TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor.");
-
-  auto new_scale = at::empty_like(current_scale);
-
-  amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
-    current_scale.data_ptr<float>(),
-    found_inf.data_ptr<float>(),
-    new_scale.data_ptr<float>(),
-    scale_growth_factor,
-    scale_backoff_factor);
-
-  return new_scale;
-}
-
-}} // namespace at::native
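
For orientation while reviewing the revert: the two CUDA kernels removed above are the device-side halves of the scaler. A rough Python sketch of their semantics (a reference model only, not the dispatcher bindings; tensor names follow the argument comments above)::

    import torch

    def non_finite_check_and_unscale_(scaled_grad, found_inf, inv_scale):
        # Mark found_inf if any element is inf/NaN, then unscale in place.
        if not torch.isfinite(scaled_grad).all():
            found_inf.fill_(1.0)
        scaled_grad.mul_(inv_scale)

    def update_scale(current_scale, found_inf, growth_factor, backoff_factor):
        # Back off if non-finite gradients were seen, otherwise grow the scale.
        # (The CUDA kernel does this on the GPU to avoid the sync .item() causes here.)
        factor = backoff_factor if found_inf.item() > 0.0 else growth_factor
        return current_scale * factor
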
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9c65adf..5eadc65 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5269,16 +5269,6 @@
     CUDA: legacy::cuda::_th_std
   supports_named_tensor: True
 
-- func: _amp_non_finite_check_and_unscale_(Tensor(a!) self, Tensor(b!) found_inf, Tensor inv_scale) -> ()
-  variants: function
-  dispatch:
-    CUDA: _amp_non_finite_check_and_unscale_cuda_
-
-- func: _amp_update_scale(Tensor current_scale, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor) -> Tensor
-  variants: function
-  dispatch:
-    CUDA: _amp_update_scale_cuda
-
 - func: _cat(Tensor[] tensors, int dim=0) -> Tensor
   dispatch:
     CPU: _cat_cpu
diff --git a/docs/source/amp.rst b/docs/source/amp.rst
deleted file mode 100644
index 9ae8783..0000000
--- a/docs/source/amp.rst
+++ /dev/null
@@ -1,37 +0,0 @@
-.. role:: hidden
-    :class: hidden-section
-
-Automatic Mixed Precision package - torch.cuda.amp
-==================================================
-
-.. automodule:: torch.cuda.amp
-.. currentmodule:: torch.cuda.amp
-
-``torch.cuda.amp`` provides convenience methods for running networks with mixed precision,
-where some operations use the ``torch.float32`` (``float``) datatype and other operations
-use ``torch.float16`` (``half``). Some operations, like linear layers and convolutions,
-are much faster in ``float16``. Other operations, like reductions, often require the dynamic
-range of ``float32``. Networks running in mixed precision try to match each operation to its appropriate datatype.
-
-.. contents:: :local:
-
-.. _gradient-scaling:
-
-Gradient Scaling
-^^^^^^^^^^^^^^^^
-
-When training a network with mixed precision, if the forward pass for a particular op has
-``torch.float16`` inputs, the backward pass for that op will produce ``torch.float16`` gradients.
-Gradient values with small magnitudes may not be representable in ``torch.float16``.
-These values will flush to zero ("underflow"), so the update for the corresponding parameters will be lost.
-
-To prevent underflow, "gradient scaling" multiplies the network's loss(es) by a scale factor and
-invokes a backward pass on the scaled loss(es).  Gradients flowing backward through the network are
-then scaled by the same factor.  In other words, gradient values have a larger magnitude,
-so they don't flush to zero.
-
-The parameters' gradients (``.grad`` attributes) should be unscaled before the optimizer uses them
-to update the parameters, so the scale factor does not interfere with the learning rate.
-
-.. autoclass:: GradScaler
-    :members:
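
The removed ``amp.rst`` above explains gradient scaling in prose. For reference, a minimal hand-rolled sketch of the same idea, without ``GradScaler`` and with a fixed scale chosen purely for illustration (``model``, ``data``, ``loss_fn``, and ``optimizer`` are assumed to be defined as in the removed examples that follow)::

    scale = 2.0 ** 16  # illustrative fixed scale factor

    for input, target in data:
        optimizer.zero_grad()
        loss = loss_fn(model(input), target)

        # Scale the loss so small float16 gradients do not underflow to zero.
        (loss * scale).backward()

        # Unscale the .grad attributes before the optimizer consumes them,
        # so the scale factor does not interact with the learning rate.
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(scale)

        optimizer.step()
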
diff --git a/docs/source/index.rst b/docs/source/index.rst
index e8f0395..3196086 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -36,7 +36,6 @@
    tensor_attributes
    torch.autograd <autograd>
    cuda
-   torch.cuda.amp <amp>
    torch.distributed <distributed>
    torch.distributions <distributions>
    torch.hub <hub>
diff --git a/docs/source/notes/amp_examples.rst b/docs/source/notes/amp_examples.rst
deleted file mode 100644
index 7823fdc..0000000
--- a/docs/source/notes/amp_examples.rst
+++ /dev/null
@@ -1,210 +0,0 @@
-.. _amp-examples:
-
-Automatic Mixed Precision examples
-==================================
-
-.. currentmodule:: torch.cuda.amp
-
-.. contents:: :local:
-
-.. _gradient-scaling-examples:
-
-Gradient Scaling
-^^^^^^^^^^^^^^^^
-
-Gradient scaling helps prevent gradient underflow when training with mixed precision,
-as explained :ref:`here<gradient-scaling>`.
-
-Instances of :class:`torch.cuda.amp.GradScaler` help perform the steps of
-gradient scaling conveniently, as shown in the following code snippets.
-
-
-Typical Use
------------
-
-::
-
-    # Creates a GradScaler once at the beginning of training.
-    scaler = GradScaler()
-
-    for epoch in epochs:
-        for input, target in data:
-            optimizer.zero_grad()
-            output = model(input)
-            loss = loss_fn(output, target)
-
-            # Scales the loss, and calls backward() on the scaled loss to create scaled gradients.
-            scaler.scale(loss).backward()
-
-            # scaler.step() first unscales the gradients of the optimizer's assigned params.
-            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
-            # otherwise, optimizer.step() is skipped.
-            scaler.step(optimizer)
-
-            # Updates the scale for next iteration.
-            scaler.update()
-
-.. _working-with-unscaled-gradients:
-
-Working with Unscaled Gradients
--------------------------------
-
-All gradients produced by ``scaler.scale(loss).backward()`` are scaled.  If you wish to modify or inspect
-the parameters' ``.grad`` attributes between ``backward()`` and ``scaler.step(optimizer)``,  you should
-unscale them first.  For example, gradient clipping manipulates a set of gradients such that their global norm
-(see :func:`torch.nn.utils.clip_grad_norm_`) or maximum magnitude (see :func:`torch.nn.utils.clip_grad_value_`)
-is :math:`<=` some user-imposed threshold.  If you attempted to clip *without* unscaling, the gradients' norm/maximum
-magnitude would also be scaled, so your requested threshold (which was meant to be the threshold for *unscaled*
-gradients) would be invalid.
-
-``scaler.unscale_(optimizer)`` unscales gradients held by ``optimizer``'s assigned parameters.
-If your model or models contain other parameters that were assigned to another optimizer
-(say ``optimizer2``), you may call ``scaler.unscale_(optimizer2)`` separately to unscale those
-parameters' gradients as well.
-
-Gradient clipping
-"""""""""""""""""
-
-Calling ``scaler.unscale_(optimizer)`` before clipping enables you to clip unscaled gradients as usual::
-
-    scaler = GradScaler()
-
-    for epoch in epochs:
-        for input, target in data:
-            optimizer.zero_grad()
-            output = model(input)
-            loss = loss_fn(output, target)
-            scaler.scale(loss).backward()
-
-            # Unscales the gradients of optimizer's assigned params in-place
-            scaler.unscale_(optimizer)
-
-            # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
-            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
-
-            # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
-            # although it still skips optimizer.step() if the gradients contain infs or NaNs.
-            scaler.step(optimizer)
-
-            # Updates the scale for next iteration.
-            scaler.update()
-
-``scaler`` records that ``scaler.unscale_(optimizer)`` was already called for this optimizer
-this iteration, so ``scaler.step(optimizer)`` knows not to redundantly unscale gradients before
-(internally) calling ``optimizer.step()``.
-
-.. warning::
-    :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
-    and only after all gradients for that optimizer's assigned parameters have been accumulated.
-    Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
-
-Working with Scaled Gradients
------------------------------
-
-For some operations, you may need to work with scaled gradients in a setting where
-``scaler.unscale_`` is unsuitable.
-
-Gradient penalty
-""""""""""""""""
-
-A gradient penalty implementation typically creates gradients out-of-place using
-:func:`torch.autograd.grad`, combines them to create the penalty value,
-and adds the penalty value to the loss.
-
-Here's an ordinary example of an L2 penalty without gradient scaling::
-
-    for epoch in epochs:
-        for input, target in data:
-            optimizer.zero_grad()
-            output = model(input)
-            loss = loss_fn(output, target)
-
-            # Creates some gradients out-of-place
-            grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True)
-
-            # Computes the penalty term and adds it to the loss
-            grad_norm = 0
-            for grad in grad_params:
-                grad_norm += grad.pow(2).sum()
-            grad_norm = grad_norm.sqrt()
-            loss = loss + grad_norm
-
-            loss.backward()
-            optimizer.step()
-
-To implement a gradient penalty *with* gradient scaling, the loss passed to
-:func:`torch.autograd.grad` should be scaled.  The resulting out-of-place gradients
-will therefore be scaled, and should be unscaled before being combined to create the
-penalty value.
-
-Here's how that looks for the same L2 penalty::
-
-    scaler = GradScaler()
-
-    for epoch in epochs:
-        for input, target in data:
-            optimizer.zero_grad()
-            output = model(input)
-            loss = loss_fn(output, target)
-
-            # Scales the loss for the out-of-place backward pass, resulting in scaled grad_params
-            scaled_grad_params = torch.autograd.grad(scaler.scale(loss), model.parameters(), create_graph=True)
-
-            # Unscales grad_params before computing the penalty.  grad_params are not owned
-            # by any optimizer, so ordinary division is used instead of scaler.unscale_:
-            inv_scale = 1./scaler.get_scale()
-            grad_params = [p*inv_scale for p in scaled_grad_params]
-
-            # Computes the penalty term and adds it to the loss
-            grad_norm = 0
-            for grad in grad_params:
-                grad_norm += grad.pow(2).sum()
-            grad_norm = grad_norm.sqrt()
-            loss = loss + grad_norm
-
-            # Applies scaling to the backward call as usual.  Accumulates leaf gradients that are correctly scaled.
-            scaler.scale(loss).backward()
-
-            # step() and update() proceed as usual.
-            scaler.step(optimizer)
-            scaler.update()
-
-
-Working with Multiple Losses and Optimizers
--------------------------------------------
-
-If your network has multiple losses, you must call ``scaler.scale`` on each of them individually.
-If your network has multiple optimizers, you may call ``scaler.unscale_`` on any of them individually,
-and you must call ``scaler.step`` on each of them individually.
-
-However, ``scaler.update()`` should only be called once,
-after all optimizers used this iteration have been stepped::
-
-    scaler = torch.cuda.amp.GradScaler()
-
-    for epoch in epochs:
-        for input, target in data:
-            optimizer0.zero_grad()
-            optimizer1.zero_grad()
-            output0 = model0(input)
-            output1 = model1(input)
-            loss0 = loss_fn(2 * output0 + 3 * output1, target)
-            loss1 = loss_fn(3 * output0 - 5 * output1, target)
-
-            scaler.scale(loss0).backward(retain_graph=True)
-            scaler.scale(loss1).backward()
-
-            # You can choose which optimizers receive explicit unscaling, if you
-            # want to inspect or modify the gradients of the params they own.
-            scaler.unscale_(optimizer0)
-
-            scaler.step(optimizer0)
-            scaler.step(optimizer1)
-
-            scaler.update()
-
-Each optimizer independently checks its gradients for infs/NaNs, and therefore makes an independent decision
-whether or not to skip the step.  This may result in one optimizer skipping the step
-while the other one does not.  Since step skipping occurs rarely (every several hundred iterations),
-this should not impede convergence.  If you observe poor convergence after adding gradient scaling
-to a multiple-optimizer model, please file an issue.
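
One pattern the removed examples do not show, but the removed tests below exercise (``test_grad_scaling_accumulation``), is gradient accumulation. A hedged sketch, again assuming ``model``, ``data``, ``loss_fn``, and ``optimizer`` as above::

    scaler = torch.cuda.amp.GradScaler()
    iters_to_accumulate = 2

    for i, (input, target) in enumerate(data):
        loss = loss_fn(model(input), target) / iters_to_accumulate
        scaler.scale(loss).backward()
        if (i + 1) % iters_to_accumulate == 0:
            # Step and update only once per accumulation window.
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
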
diff --git a/test/test_cuda.py b/test/test_cuda.py
index dbc45be..4d24db0 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -3,7 +3,7 @@
 import tempfile
 import unittest
 import sys
-from itertools import repeat, chain
+from itertools import repeat
 import os
 from contextlib import contextmanager
 import threading
@@ -1975,359 +1975,6 @@
 t2.start()
 """])
 
-    def test_grad_scaling_builtins(self, device="cuda", dtype=torch.float):
-        inv_scale = torch.tensor([0.25], dtype=dtype, device=device)
-
-        found_inf = torch.tensor([0.0], dtype=dtype, device=device)
-        g = torch.tensor([4.0], dtype=dtype, device=device)
-        torch._amp_non_finite_check_and_unscale_(g, found_inf, inv_scale)
-        self.assertTrue(found_inf.item() == 0.0)
-        self.assertTrue(torch.allclose(g, torch.ones(10, dtype=torch.float32, device="cuda"), atol=1e-7))
-
-        found_inf.zero_()
-        g = torch.tensor([float('inf')], dtype=dtype, device=device)
-        torch._amp_non_finite_check_and_unscale_(g, found_inf, inv_scale)
-        self.assertTrue(found_inf.item() == 1.0)
-
-        found_inf.zero_()
-        g = torch.tensor([float('nan')], dtype=dtype, device=device)
-        torch._amp_non_finite_check_and_unscale_(g, found_inf, inv_scale)
-        self.assertTrue(found_inf.item() == 1.0)
-
-        growth_factor = 4.0
-        backoff_factor = 0.5
-        current_scale = torch.tensor([4.0], dtype=dtype, device=device)
-
-        found_inf.zero_()
-        new_scale = torch._amp_update_scale(current_scale, found_inf, growth_factor, backoff_factor)
-        self.assertEqual(new_scale.item(), 16.0)
-
-        found_inf.fill_(1.0)
-        new_scale = torch._amp_update_scale(current_scale, found_inf, growth_factor, backoff_factor)
-        self.assertEqual(new_scale.item(), 2.0)
-
-    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
-    def test_grad_scaling_device_as_key(self):
-        # Ensure that different instances of "device" objects that point to the same device
-        # are treated as identical keys by dicts.  GradScaler relies on this behavior, and may
-        # error otherwise in a way that's difficult to detect (a silent performance hit).
-        d = {}
-        dev0a = torch.device("cuda:0")
-        dev0b = torch.device("cuda:0")
-        dev1a = torch.device("cuda:1")
-        dev1b = torch.device("cuda:1")
-
-        self.assertTrue(hash(dev0a) == hash(dev0b))
-        self.assertTrue(hash(dev1a) == hash(dev1b))
-
-        d[dev0a] = "0a"
-        d[dev0b] = "0b"
-        self.assertTrue(len(d) == 1)
-        self.assertTrue(d[dev0a] == "0b")
-
-        d[dev1a] = "1a"
-        d[dev1b] = "1b"
-        self.assertTrue(len(d) == 2)
-        self.assertTrue(d[dev1a] == "1b")
-
-    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
-    def test_grad_scaling_scale(self):
-        scaler = torch.cuda.amp.GradScaler(init_scale=2.)
-        t0 = torch.tensor([4.0], dtype=torch.float32, device="cuda:0")
-        t1 = torch.tensor([4.0], dtype=torch.float32, device="cuda:1")
-        # Create some nested iterables of tensors on different devices.
-        outputs = (t1.clone(), (t0.clone(), t1.clone()), [t0.clone(), (t1.clone(), t0.clone())])
-        outputs = scaler.scale(outputs)
-        self.assertTrue(outputs[0] == 8.0 and outputs[1][0] == 8.0 and outputs[1][1] == 8.0 and
-                        outputs[2][0] == 8.0 and outputs[2][1][0] == 8.0 and outputs[2][1][1] == 8.0)
-        self.assertTrue(scaler._scale.device == t1.device)
-
-    def test_grad_scaling_state_dict(self):
-        for lazy_init_scale in True, False:
-            s0 = torch.cuda.amp.GradScaler(init_scale=3., growth_factor=4., backoff_factor=.5)
-            s1 = torch.cuda.amp.GradScaler(init_scale=6., growth_factor=7., backoff_factor=.8)
-
-            if lazy_init_scale:
-                # Dummy scale() call to ensure the scale tensor is lazily initialized.
-                s1.scale(torch.tensor([4.0], dtype=torch.float32, device="cuda:0"))
-                self.assertTrue(isinstance(s1._scale, torch.cuda.FloatTensor))
-
-            s1.load_state_dict(s0.state_dict())
-
-            self.assertTrue(s1.get_scale() == 3.)
-            self.assertTrue(s1.get_growth_factor() == 4.)
-            self.assertTrue(s1.get_backoff_factor() == .5)
-
-    def _create_scaling_models_optimizers(self, device="cuda"):
-        # Create a module+optimizer that will use scaling, and a control module+optimizer
-        # that will not use scaling, against which the scaling-enabled module+optimizer can be compared.
-        mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
-        mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
-        for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
-            s.data.copy_(c.data)
-
-        opt_control = torch.optim.SGD(mod_control.parameters(), lr=1.0)
-        opt_scaling = torch.optim.SGD(mod_scaling.parameters(), lr=1.0)
-
-        return mod_control, mod_scaling, opt_control, opt_scaling
-
-    def _create_scaling_case(self, device="cuda", dtype=torch.float):
-        data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
-                (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
-                (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
-                (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))]
-
-        loss_fn = torch.nn.MSELoss().cuda()
-
-        skip_iter = 2
-
-        return self._create_scaling_models_optimizers(device=device) + (data, loss_fn, skip_iter)
-
-    # _run_scaling_case generalizes some single-optimizer test logic to avoid too much copy-pasting below.
-    def _run_scaling_case(self, run, unskipped, skipped):
-        # Ensure scaling can be disabled without changing user control flow.
-        for enabled in True, False:
-            mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter = self._create_scaling_case()
-
-            # For functionality, test with a modest initial scale, and an unrealistically-large growth factor
-            # so any potential errors with the growth factor handling will be magnified.
-            scaler = torch.cuda.amp.GradScaler(init_scale=128., growth_factor=2.0, enabled=enabled)
-
-            run(data, mod_control, opt_control, scaler, loss_fn, skip_iter, False)
-            run(data, mod_scaling, opt_scaling, scaler, loss_fn, skip_iter, True)
-
-            # If scaling was enabled, the scale factor should have been multiplied by the growth factor
-            # len(data) - skipped times and the backoff factor "skipped" times.
-            if enabled:
-                net_growth = scaler.get_growth_factor()**unskipped if unskipped > 0 else 1.0
-                net_backoff = scaler.get_backoff_factor()**skipped if skipped > 0 else 1.0
-                self.assertTrue(scaler.get_scale() == (128. * net_growth * net_backoff))
-            else:
-                self.assertTrue(scaler.get_scale() == 1.0)
-
-            for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
-                self.assertTrue(torch.allclose(c, s, atol=1e-7))
-
-    def test_grad_scaling_clipping(self):
-        def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
-            max_norm = 0.2  # A reasonable value that actually has an effect, based on printouts of grads
-            for i, (input, target) in enumerate(data):
-                optimizer.zero_grad()
-                output = model(input)
-                loss = loss_fn(output, target)
-                if try_scaling_api:
-                    scaler.scale(loss).backward()
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm * scaler.get_scale())
-                    if i == skip_iter and scaler.is_enabled():
-                        model[1].weight.grad.data.fill_(float('inf'))
-                    scaler.step(optimizer)
-                    scaler.update()
-                else:
-                    loss.backward()
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
-                    if (not scaler.is_enabled()) or (i != skip_iter):
-                        optimizer.step()
-
-        self._run_scaling_case(run, unskipped=3, skipped=1)
-
-    def test_grad_scaling_clipping_separate_unscale(self):
-        def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
-            max_norm = 0.2  # A reasonable value that actually has an effect, based on printouts of grads
-            for i, (input, target) in enumerate(data):
-                optimizer.zero_grad()
-                output = model(input)
-                loss = loss_fn(output, target)
-                if try_scaling_api:
-                    scaler.scale(loss).backward()
-                    if i == skip_iter and scaler.is_enabled():
-                        model[1].weight.grad.data.fill_(float('inf'))
-                    scaler.unscale_(optimizer)
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
-                    scaler.step(optimizer)
-                    scaler.update()
-                else:
-                    loss.backward()
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
-                    if (not scaler.is_enabled()) or (i != skip_iter):
-                        optimizer.step()
-
-        self._run_scaling_case(run, unskipped=3, skipped=1)
-
-    def test_grad_scaling_penalty(self):
-        def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
-            for i, (input, target) in enumerate(data):
-                optimizer.zero_grad()
-                output = model(input)
-                loss = loss_fn(output, target)
-
-                if try_scaling_api:
-                    grad_params = torch.autograd.grad(scaler.scale(loss),
-                                                      model.parameters(), create_graph=True)
-                    inv_scale = 1. / scaler.get_scale()
-                    grad_params = [p * inv_scale for p in grad_params]
-                else:
-                    grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True)
-
-                grad_norm = 0
-                for grad in grad_params:
-                    grad_norm += grad.pow(2).sum()
-                grad_norm = grad_norm.sqrt()
-                loss = loss + grad_norm
-
-                if try_scaling_api:
-                    scaler.scale(loss).backward()
-                    if i == skip_iter and scaler.is_enabled():
-                        model[1].weight.grad.data.fill_(float('inf'))
-                    scaler.step(optimizer)
-                    scaler.update()
-                else:
-                    loss.backward()
-                    if (not scaler.is_enabled()) or (i != skip_iter):
-                        optimizer.step()
-
-        self._run_scaling_case(run, unskipped=3, skipped=1)
-
-    def test_grad_scaling_accumulation(self):
-        def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
-            iters_to_accumulate = 2
-            for i, (input, target) in enumerate(data):
-                output = model(input)
-                loss = loss_fn(output, target)
-                loss = loss / iters_to_accumulate
-                if try_scaling_api:
-                    scaler.scale(loss).backward()
-                else:
-                    loss.backward()
-                if (i + 1) % iters_to_accumulate == 0:
-                    if try_scaling_api:
-                        scaler.step(optimizer)
-                        scaler.update()
-                        optimizer.zero_grad()
-                    else:
-                        optimizer.step()
-                        optimizer.zero_grad()
-
-        self._run_scaling_case(run, unskipped=2, skipped=0)
-
-    def test_grad_scaling_multiple(self):
-        # Tests gradient scaling with 2 models and 2 optimizers that both receive gradients from 2 losses.
-        # Some of the logic here cannot reuse the generic helper functions created for the 1-optimizer cases.
-        for enabled in True, False:
-            mod_control0, mod_scaling0, opt_control0, opt_scaling0, data, loss_fn, skip_iter = \
-                self._create_scaling_case()
-            mod_control1, mod_scaling1, opt_control1, opt_scaling1 = \
-                self._create_scaling_models_optimizers()
-
-            scaler = torch.cuda.amp.GradScaler(init_scale=128., growth_factor=2.0, enabled=enabled)
-
-            def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
-                for i, (input, target) in enumerate(data):
-                    optimizer0.zero_grad()
-                    optimizer1.zero_grad()
-                    output0 = model0(input)
-                    output1 = model1(input)
-                    loss0 = loss_fn(0.3 * output0 + 0.7 * output1, target)
-                    loss1 = loss_fn(0.6 * output0 - 0.4 * output1, target)
-
-                    if try_scaling_api:
-                        scaler.scale(loss0).backward(retain_graph=True)
-                        scaler.scale(loss1).backward()
-                        if i == skip_iter and scaler.is_enabled():
-                            model1[1].weight.grad.data.fill_(float('inf'))
-
-                        # As an additional stress test, separately unscale for one of the optimizers.
-                        scaler.unscale_(optimizer0)
-
-                        scaler.step(optimizer0)
-                        scaler.step(optimizer1)
-                        scaler.update()
-                    else:
-                        loss0.backward(retain_graph=True)
-                        loss1.backward()
-                        optimizer0.step()
-                        if (not scaler.is_enabled()) or (i != skip_iter):
-                            optimizer1.step()
-
-            run(mod_control0, mod_control1, opt_control0, opt_control1, False)
-            run(mod_scaling0, mod_scaling1, opt_scaling0, opt_scaling1, True)
-
-            # The loss scale should have been multiplied by the growth factor 3 times and the backoff factor once.
-            self.assertTrue(scaler.get_scale() == ((128. * scaler.get_growth_factor()**3 *
-                                                    scaler.get_backoff_factor()**1) if enabled else 1.0))
-
-            for c, s in zip(chain(mod_control0.parameters(), mod_control1.parameters()),
-                            chain(mod_scaling0.parameters(), mod_scaling1.parameters())):
-                self.assertTrue(torch.allclose(c, s, atol=1e-7))
-
-    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
-    def test_grad_scaling_multigpu(self):
-        # Same as above, but runs some of the models on device 1.
-        # GradScaler should transparently handle losses and gradients on multiple devices.
-        # This test could be combined with the test above, but I think it makes sense to treat
-        # multi-GPU operations separately.
-        dev0 = torch.device("cuda:0")
-        dev1 = torch.device("cuda:1")
-
-        for enabled in True, False:
-            mod_control0, mod_scaling0, opt_control0, opt_scaling0, data, loss_fn, skip_iter = \
-                self._create_scaling_case()
-            mod_control1, mod_scaling1, opt_control1, opt_scaling1 = \
-                self._create_scaling_models_optimizers(device=dev1)
-
-            scaler = torch.cuda.amp.GradScaler(init_scale=128., growth_factor=2.0, enabled=enabled)
-
-            def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
-                for i, (input, target) in enumerate(data):
-                    optimizer0.zero_grad()
-                    optimizer1.zero_grad()
-                    output0 = model0(input)
-                    output1 = model1(input.to(dev1))
-                    loss0 = loss_fn(0.3 * output0 + 0.7 * output1.to(dev0), target)
-                    loss1 = loss_fn(0.6 * output0.to(dev1) - 0.4 * output1, target.to(dev1))
-
-                    if try_scaling_api:
-                        scaler.scale(loss0).backward(retain_graph=True)
-                        scaler.scale(loss1).backward()
-                        if i == skip_iter and scaler.is_enabled():
-                            model1[1].weight.grad.data.fill_(float('inf'))
-
-                        # As an additional stress test, separately unscale for one of the optimizers.
-                        scaler.unscale_(optimizer0)
-
-                        scaler.step(optimizer0)
-                        scaler.step(optimizer1)
-
-                        # Make sure the found_infs were collected properly across optimizers and devices.
-                        if scaler.is_enabled():
-                            self.assertTrue(len(scaler._found_inf_per_device(optimizer0)) == 1)
-                            self.assertTrue(len(scaler._found_inf_per_device(optimizer1)) == 1)
-                            self.assertTrue(scaler._found_inf_per_device(optimizer0)[dev0].item() == 0.)
-                            self.assertTrue(scaler._found_inf_per_device(optimizer1)[dev1].item() ==
-                                            float(i == skip_iter))
-
-                        scaler.update()
-                    else:
-                        loss0.backward(retain_graph=True)
-                        loss1.backward()
-                        optimizer0.step()
-                        if (not scaler.is_enabled()) or (i != skip_iter):
-                            optimizer1.step()
-
-            run(mod_control0, mod_control1, opt_control0, opt_control1, False)
-            run(mod_scaling0, mod_scaling1, opt_scaling0, opt_scaling1, True)
-
-            # The loss scale should have been multiplied by the growth factor 3 times and the backoff factor once.
-            self.assertTrue(scaler.get_scale() == ((128. * scaler.get_growth_factor()**3 *
-                                                    scaler.get_backoff_factor()**1) if enabled else 1.0))
-
-            # Copy mod_control1 and mod_scaling1 back to device 0 for comparison
-            mod_control1.to(dev0)
-            mod_scaling1.to(dev0)
-
-            for c, s in zip(chain(mod_control0.parameters(), mod_control1.parameters()),
-                            chain(mod_scaling0.parameters(), mod_scaling1.parameters())):
-                self.assertTrue(torch.allclose(c, s, atol=1e-7))
-
     @skipIfRocm
     @unittest.skipIf(not PY3, "Barrier is unavailable before Python3")
     def test_cublas_multiple_threads_same_device(self):
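
The removed tests above share one structure: run identical training steps through a control model/optimizer without scaling and through a second model/optimizer with the GradScaler API, then require the resulting parameters to match. A stripped-down sketch of that parity check (helper and variable names are illustrative, not the deleted test helpers)::

    def check_parity(mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn):
        scaler = torch.cuda.amp.GradScaler(init_scale=128., growth_factor=2.0)
        for input, target in data:
            # Control path: a plain training step.
            opt_control.zero_grad()
            loss_fn(mod_control(input), target).backward()
            opt_control.step()
            # Scaling path: the same step through the GradScaler API.
            opt_scaling.zero_grad()
            scaler.scale(loss_fn(mod_scaling(input), target)).backward()
            scaler.step(opt_scaling)
            scaler.update()
        for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
            assert torch.allclose(c, s, atol=1e-7)
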
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 06a2d64..3ff253a 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -543,4 +543,3 @@
 from . import profiler
 from . import nvtx
 from .streams import Stream, Event
-from . import amp
diff --git a/torch/cuda/amp/__init__.py b/torch/cuda/amp/__init__.py
deleted file mode 100644
index 0251eaf..0000000
--- a/torch/cuda/amp/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .grad_scaler import GradScaler  # noqa: F401
diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py
deleted file mode 100644
index 797d9b3..0000000
--- a/torch/cuda/amp/grad_scaler.py
+++ /dev/null
@@ -1,419 +0,0 @@
-import torch
-from collections import defaultdict
-from torch._six import container_abcs
-
-
-class _MultiDeviceReplicator(object):
-    """
-    Lazily serves copies of a tensor to requested devices.  Copies are cached per-device.
-    """
-    def __init__(self, master_tensor):
-        assert master_tensor.is_cuda
-        self.master = master_tensor
-        self._per_device_tensors = {}
-
-    def get(self, device):
-        retval = self._per_device_tensors.get(device, None)
-        if retval is None:
-            retval = self.master.to(device=device, non_blocking=True, copy=True)
-            self._per_device_tensors[device] = retval
-        return retval
-
-
-class GradScaler(object):
-    """
-    An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
-    conveniently.
-
-    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
-    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
-    * ``scaler.update()`` updates ``scaler``'s scale factor.
-
-    Typical use::
-
-        # Creates a GradScaler once at the beginning of training.
-        scaler = GradScaler()
-
-        for epoch in epochs:
-            for input, target in data:
-                optimizer.zero_grad()
-                output = model(input)
-                loss = loss_fn(output, target)
-
-                # Scales the loss, and calls backward() on the scaled loss to create scaled gradients.
-                scaler.scale(loss).backward()
-
-                # scaler.step() first unscales the gradients of the optimizer's assigned params.
-                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
-                # otherwise, optimizer.step() is skipped.
-                scaler.step(optimizer)
-
-                # Updates the scale for next iteration.
-                scaler.update()
-
-    See the :ref:`Gradient Scaling Examples<gradient-scaling-examples>` for usage in more complex cases like
-    gradient clipping, gradient penalty, and multiple losses/optimizers.
-
-    ``scaler`` dynamically estimates the scale factor each iteration.  To minimize gradient underflow,
-    a large scale factor should be used.  However, ``torch.float16`` values can "overflow" (become inf or NaN) if
-    the scale factor is too large.  Therefore, the optimal scale factor is the largest factor that can be used
-    without incurring inf or NaN gradient values.
-    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
-    ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
-    If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual and
-    ``scaler.update()`` multiplies the scale factor by ``growth_factor``.  If infs/NaNs are found,
-    ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params themselves remain uncorrupted)
-    and multiplies the scale factor by ``backoff_factor``.
-
-    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
-    value calibrates.  ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
-    iterations.  After that, step skipping should occur rarely (once every few hundred iterations).
-
-    Arguments:
-        init_scale (float, optional, default=2.**24):  Initial scale factor.
-        growth_factor (float, optional, default=1.001):  Factor by which the scale is multiplied during
-            :meth:`update` if no inf/NaN gradients were found this iteration.  The default value is recommended.
-        backoff_factor (float, optional, default=0.5):  Factor by which the scale is multiplied during
-            :meth:`update` if inf/NaN gradients were found this iteration.  The default value is recommended.
-        enabled (bool, optional, default=True):  If ``False``, disables gradient scaling. :meth:`step` simply
-            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
-    """
-    # Python 2 doesn't support enums.
-    READY = 0
-    UNSCALED = 1
-    STEPPED = 2
-
-    def __init__(self,
-                 init_scale=2.**24,
-                 growth_factor=1.001,
-                 backoff_factor=0.5,
-                 enabled=True):
-        self._enabled = enabled
-        if enabled:
-            assert growth_factor > 1.0, "The growth factor must be > 1.0.  Using the default value is recommended."
-            assert backoff_factor < 1.0, "The backoff factor must be < 1.0.  Using the default value is recommended."
-
-            self._init_scale = init_scale
-            # self._scale will be lazily initialized during the first call to scaler.scale(loss or outputs)
-            self._scale = None
-            self._growth_factor = growth_factor
-            self._backoff_factor = backoff_factor
-            self._per_optimizer_states = defaultdict(lambda: {"stage": self.READY, "found_inf_per_device": {}})
-
-    @staticmethod
-    def _scale_not_initialized_error(funcname):
-        return "Attempted to call {} but the scale tensor is None. This may indicate your ".format(funcname) + \
-               "script did not use scaler.scale(loss or outputs) earlier in the iteration."
-
-    def scale(self, outputs):
-        """
-        Multiplies ('scales') a tensor or list of tensors by the scale factor.
-
-        Arguments:
-            outputs (Tensor or iterable of Tensors):  Outputs to scale.
-
-        Returns:
-            Scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned unmodified.
-        """
-        if not self._enabled:
-            return outputs
-
-        # Short-circuit for the common case.
-        if isinstance(outputs, torch.Tensor):
-            assert outputs.is_cuda
-            if self._scale is None:
-                self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=outputs.device)
-            return outputs * self._scale.to(device=outputs.device, non_blocking=True)
-
-        # Invoke the more complex machinery only if we're treating multiple outputs.
-        stash = [None]  # trick to hold a reference that can be overwritten at any level of the recursion below.
-
-        def apply_scale(val):
-            if isinstance(val, torch.Tensor):
-                assert val.is_cuda
-                if self._scale is None:
-                    self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=val.device)
-                if stash[0] is None:
-                    stash[0] = _MultiDeviceReplicator(self._scale)
-                return val * stash[0].get(val.device)
-            elif isinstance(val, container_abcs.Iterable):
-                return type(val)(apply_scale(v) for v in val)
-            else:
-                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
-
-        return apply_scale(outputs)
-
-    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
-        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
-        per_device_found_inf = _MultiDeviceReplicator(found_inf)
-
-        for group in optimizer.param_groups:
-            for param in group["params"]:
-                if param.grad is not None:
-                    if (not allow_fp16) and param.grad.dtype == torch.float16:
-                        raise ValueError("Attempting to unscale FP16 gradients.")
-                    else:
-                        torch._amp_non_finite_check_and_unscale_(param.grad,
-                                                                 per_device_found_inf.get(param.grad.device),
-                                                                 per_device_inv_scale.get(param.grad.device))
-
-        return per_device_found_inf._per_device_tensors
-
-    def unscale_(self, optimizer):
-        """
-        Divides ("unscales") the optimizer's gradient tensors by the scale factor.
-
-        :meth:`unscale_` is optional, serving cases where you need to
-        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
-        between the backward pass(es) and :meth:`step`.
-        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
-
-        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
-
-            ...
-            scaler.scale(loss).backward()
-            scaler.unscale_(optimizer)
-            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
-            scaler.step(optimizer)
-            scaler.update()
-
-        Arguments:
-            optimizer (torch.optim.Optimizer):  Optimizer that owns the gradients to be unscaled.
-
-        .. note::
-            :meth:`unscale_` does not incur a CPU-GPU sync.
-
-        .. warning::
-            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
-            and only after all gradients for that optimizer's assigned parameters have been accumulated.
-            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
-        """
-        if not self._enabled:
-            return
-
-        assert self._scale is not None, self._scale_not_initialized_error("unscale")
-
-        optimizer_state = self._per_optimizer_states[id(optimizer)]
-
-        if optimizer_state["stage"] == self.UNSCALED:
-            raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
-        elif optimizer_state["stage"] == self.STEPPED:
-            raise RuntimeError("unscale_() is being called after step().")
-
-        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
-        inv_scale = self._scale.double().reciprocal().float()
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
-
-        optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
-        optimizer_state["stage"] = self.UNSCALED
-
-    def step(self, optimizer, *args, **kwargs):
-        """
-        :meth:`step` carries out the following two operations:
-
-        1.  Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
-            earlier in the iteration).  As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
-        2.  If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
-            gradients.  Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
-
-        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
-
-        Arguments:
-            optimizer (torch.optim.Optimizer):  Optimizer that applies the gradients.
-            args:  Any arguments.
-            kwargs:  Any keyword arguments.
-
-        Returns:
-            The return value of ``optimizer.step(*args, **kwargs)``.
-
-        .. warning::
-            Closure use is not currently supported.
-        """
-        if (not self._enabled):
-            return optimizer.step(*args, **kwargs)
-
-        if "closure" in kwargs:
-            raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
-
-        assert self._scale is not None, self._scale_not_initialized_error("step")
-
-        optimizer_state = self._per_optimizer_states[id(optimizer)]
-
-        if optimizer_state["stage"] == self.STEPPED:
-            raise RuntimeError("step() has already been called since the last update().")
-
-        retval = None
-
-        if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
-            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
-            # The contract with custom optimizers is that their step() should accept an additional,
-            # optional grad_scaler kwarg.  We append self to the kwargs so the custom optimizer has full information:
-            # it can query its own state, invoke unscale_ on itself, etc
-            retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self))
-            optimizer_state["stage"] == self.STEPPED
-            return retval
-
-        if optimizer_state["stage"] == self.READY:
-            self.unscale_(optimizer)
-
-        assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
-
-        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
-            retval = optimizer.step(*args, **kwargs)
-
-        optimizer_state["stage"] == self.STEPPED
-
-        return retval
-
-    def update(self, new_scale=None):
-        """
-        Updates the scale factor.
-
-        If any optimizer steps were skipped, the scale factor is multiplied by
-        ``backoff_factor`` to reduce it.  If all optimizer steps were taken,
-        it is multiplied by ``growth_factor`` to increase it.
-
-        Passing ``new_scale`` sets the scale factor directly.
-
-        Arguments:
-            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None):  New scale factor.
-
-        .. warning::
-            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
-            been invoked for all optimizers used this iteration.
-        """
-        if not self._enabled:
-            return
-
-        assert self._scale is not None, self._scale_not_initialized_error("update")
-
-        if new_scale is not None:
-            # Accept a new user-defined scale.
-            if isinstance(new_scale, float):
-                self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=self._scale.device)
-            else:
-                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
-                assert isinstance(new_scale, torch.cuda.FloatTensor), reason
-                assert new_scale.numel() == 1, reason
-                assert new_scale.requires_grad is False, reason
-                self._scale = new_scale
-        else:
-            # Consume shared inf/nan data collected from optimizers to update the scale.
-            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
-            found_infs = [found_inf.to(device=self._scale.device, non_blocking=True)
-                          for state in self._per_optimizer_states.values()
-                          for found_inf in state["found_inf_per_device"].values()]
-
-            assert len(found_infs) > 0, "No inf checks were recorded prior to update."
-
-            found_inf_combined = found_infs[0]
-            if len(found_infs) > 1:
-                for i in range(1, len(found_infs)):
-                    found_inf_combined += found_infs[i]
-
-            self._scale = torch._amp_update_scale(self._scale,
-                                                  found_inf_combined,
-                                                  self._growth_factor,
-                                                  self._backoff_factor)
-
-        # To prepare for next iteration, clear the data collected from optimizers this iteration.
-        self._per_optimizer_states = defaultdict(lambda: {"stage": self.READY, "found_inf_per_device": {}})
-
-    def _get_scale_async(self):
-        return self._scale
-
-    def get_scale(self):
-        """
-        Returns:
-            A Python float containing the current scale, or 1.0 if scaling is disabled.
-
-        .. warning::
-            :meth:`get_scale` incurs a CPU-GPU sync.
-        """
-        if self._enabled:
-            return self._init_scale if self._scale is None else self._get_scale_async().item()
-        else:
-            return 1.0
-
-    def get_growth_factor(self):
-        r"""
-        Returns:
-            A Python float containing the scale growth factor.
-        """
-        return self._growth_factor
-
-    def set_growth_factor(self, new_factor):
-        r"""
-        Arguments:
-            new_factor (float):  Value to use as the new scale growth factor.
-        """
-        self._growth_factor = new_factor
-
-    def get_backoff_factor(self):
-        r"""
-        Returns:
-            A Python float containing the scale backoff factor.
-        """
-        return self._backoff_factor
-
-    def set_backoff_factor(self, new_factor):
-        r"""
-        Arguments:
-            new_factor (float):  Value to use as the new scale backoff factor.
-        """
-        self._backoff_factor = new_factor
-
-    def is_enabled(self):
-        r"""
-        Returns:
-            A bool indicating whether this instance is enabled.
-        """
-        return self._enabled
-
-    def state_dict(self):
-        r"""
-        Returns the state of the scaler as a :class:`dict`.  It contains three entries:
-
-        * ``"scale"`` - a Python float containing the current scale
-        * ``"growth_factor"`` - a Python float containing the current growth factor
-        * ``"backoff_factor"`` - a Python float containing the current backoff factor
-
-        If this instance is not enabled, returns an empty dict.
-        """
-        return {"scale": self.get_scale(),
-                "growth_factor": self._growth_factor,
-                "backoff_factor": self._backoff_factor} if self._enabled else {}
-
-    def load_state_dict(self, state_dict):
-        r"""
-        Loads the scaler state.  If this instance is disabled, :meth:`load_state_dict` is a no-op.
-
-        Arguments:
-           state_dict(dict): scaler state.  Should be an object returned from a call to :meth:`state_dict`.
-        """
-        if not self._enabled:
-            return
-
-        if len(state_dict) == 0:
-            raise RuntimeError("The source state dict is empty, possibly because it was saved "
-                               "from a disabled instance of GradScaler.")
-
-        self._init_scale = state_dict["scale"]
-        if self._scale is not None:
-            self._scale.fill_(state_dict["scale"])
-        self._growth_factor = state_dict["growth_factor"]
-        self._backoff_factor = state_dict["backoff_factor"]
-
-    def _check_inf_per_device(self, optimizer):
-        assert self._scale is not None, self._scale_not_initialized_error("_check_inf_per_device")
-
-        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=self._scale.device)
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
-
-        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
-            self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
-
-        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
-
-    def _found_inf_per_device(self, optimizer):
-        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
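
For completeness: ``step()`` above documents a contract for optimizers that want to manage scaling themselves (an ``_step_supports_amp_scaling`` attribute and an optional ``grad_scaler`` keyword argument). A hedged sketch of what such an optimizer could look like; this is an assumption-laden illustration, not code from the reverted PR::

    import torch

    class ScalerAwareSGD(torch.optim.SGD):
        # Tells GradScaler.step() to delegate scale handling to this optimizer.
        _step_supports_amp_scaling = True

        def step(self, closure=None, grad_scaler=None):
            if grad_scaler is not None:
                # Unscale our own gradients and consult the scaler's inf/NaN findings.
                grad_scaler.unscale_(self)
                found_inf = sum(v.item()
                                for v in grad_scaler._found_inf_per_device(self).values())
                if found_inf:
                    return None  # skip the update on inf/NaN gradients
            return super(ScalerAwareSGD, self).step(closure)
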