nvprims native batch norm patch (#88455)
Cherry-picking: https://github.com/csarofeen/pytorch/pull/2104
- [x] Added an explicit cast on the inputs to nvprims.native_batch_norm. This avoids relying on implicit dtype promotion inside the fusion definition, which was causing issues (a minimal repro sketch follows the links below).
- [x] Added a Python repro with dynamo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88455
Approved by: https://github.com/mruberry, https://github.com/IvanYashchuk
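For context, here is a minimal sketch (not part of the patch; assumes a CUDA device with autocast available) of the dtype mismatch being handled: under autocast the matmul produces a float16 tensor while the batch norm weight/bias stay float32, so the nvprims reference now casts explicitly around the prim call instead of leaving the promotion implicit in the fusion definition.

```python
import torch

mat1 = torch.randn(2, 3, 4, 5, device="cuda", dtype=torch.float32)
mat2 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
w = torch.randn(3, device="cuda", dtype=torch.float32)
b = torch.randn(3, device="cuda", dtype=torch.float32)

with torch.cuda.amp.autocast():
    # autocast runs the matmul in float16, so `o` is float16 while w/b stay float32
    o = torch.matmul(mat1, mat2)
    assert o.dtype == torch.float16
    # the nvprims reference below casts `o` to the promoted computation dtype
    # before calling the batch norm prim and casts the result back afterwards
    out = torch.batch_norm(o, w, b, None, None, True, 1e-2, 1e-5, True)
```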
diff --git a/test/test_nvfuser_dynamo.py b/test/test_nvfuser_dynamo.py
index b0c2838..749cae8 100644
--- a/test/test_nvfuser_dynamo.py
+++ b/test/test_nvfuser_dynamo.py
@@ -45,6 +45,25 @@
        eager_result = func.__wrapped__(input1, input2)
        self.assertEqual(eager_result, nvfuser_result)
+    def test_batch_norm_implicit_dtype_promotion(self):
+        input1 = make_tensor((2, 3, 4, 5), device="cuda", dtype=torch.float32)
+        input2 = make_tensor((5, 5), device="cuda", dtype=torch.float32)
+        w = make_tensor((3), device="cuda", dtype=torch.float32)
+        b = make_tensor((3), device="cuda", dtype=torch.float32)
+
+        @torchdynamo.optimize("nvprims_nvfuser")
+        def func(mat1, mat2, w, b):
+            o = torch.matmul(mat1, mat2)
+            return torch.batch_norm(o, w, b, None, None, True, 1e-2, 1e-5, True)
+
+        # No warnings and no errors
+        with torch.cuda.amp.autocast():
+            with warnings.catch_warnings(record=True) as warning:
+                nvfuser_result = func(input1, input2, w, b)
+                self.assertEqual(len(warning), 0)
+            eager_result = func.__wrapped__(input1, input2, w, b)
+            self.assertEqual(eager_result, nvfuser_result)
+
    def test_dtype_correctness(self):
        input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float16)
diff --git a/torch/_prims/nvfuser_prims.py b/torch/_prims/nvfuser_prims.py
index 391a7fe..59a8820 100644
--- a/torch/_prims/nvfuser_prims.py
+++ b/torch/_prims/nvfuser_prims.py
@@ -5,12 +5,13 @@
# can be added in the future for the corresponding higher-level torch/aten
# functions.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
import torch
from torch._prims_common import (
    DimsSequenceType,
+    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    getnvFuserDtype,
    make_contiguous_strides_for,
@@ -19,6 +20,7 @@
)
from torch._prims_common.wrappers import (
+    _maybe_convert_to_dtype,
    backwards_not_supported,
    elementwise_type_promotion_wrapper,
)
@@ -373,12 +375,60 @@
    )
    nvprim_impl.impl(name, _prim_impl)
-    nvprim_autograd_impl.impl(
-        name, backwards_not_supported(torch.ops.nvprims.native_batch_norm.default)
-    )
-
    prim_packet = torch.ops.nvprims.native_batch_norm
    prim = prim_packet.default
+
+    def _native_batch_norm_ref(
+        input: torch.Tensor,
+        weight: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor],
+        running_mean: Optional[torch.Tensor],
+        running_var: Optional[torch.Tensor],
+        training: bool,
+        momentum: float,
+        eps: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+        if torch._prims_common.is_complex_dtype(input.dtype):
+            raise NotImplementedError("Complex tensors are not supported")
+
+        # note: BN only promotes input to dtype of weight/bias, but keeps the same output dtype
+        result_dtype = input.dtype
+        computation_dtype, _ = elementwise_dtypes(
+            input,
+            weight,
+            bias,
+            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
+        )
+
+        input_ = _maybe_convert_to_dtype(input, computation_dtype)
+        output, mean, rstd = prim(
+            input_, weight, bias, running_mean, running_var, training, momentum, eps
+        )
+        output_ = _maybe_convert_to_dtype(output, result_dtype)  # type: ignore[arg-type]
+        return (output_, mean, rstd)  # type: ignore[return-value]
+
+    def _native_batch_norm_autograd(
+        input: torch.Tensor,
+        weight: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor],
+        running_mean: Optional[torch.Tensor],
+        running_var: Optional[torch.Tensor],
+        training: bool,
+        momentum: float,
+        eps: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # This wrapper is needed to convert prims calls inside
+        # _native_batch_norm_ref to nvprims calls
+        from torch._prims.context import NvfuserPrimsMode
+
+        with NvfuserPrimsMode():
+            return backwards_not_supported(_native_batch_norm_ref)(
+                input, weight, bias, running_mean, running_var, training, momentum, eps
+            )
+
+    nvprim_autograd_impl.impl(name, _native_batch_norm_autograd)
+
    for p in (prim_packet, prim):
        p.__doc__ = "Computes batch normalization."
        p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
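
As a side note on the reference's dtype logic (not part of the patch; a sketch using the same private torch._prims_common helpers the diff imports): NO_OPMATH promotion over (input, weight, bias) selects the computation dtype, while the output is cast back to the input's original dtype.

```python
import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes

# float16 input with float32 weight/bias, as in the autocast repro above
inp = torch.randn(2, 3, 4, 5, dtype=torch.float16)
w = torch.randn(3, dtype=torch.float32)
b = torch.randn(3, dtype=torch.float32)

computation_dtype, _ = elementwise_dtypes(
    inp, w, b, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
)
print(computation_dtype)  # torch.float32

# _native_batch_norm_ref casts the input to computation_dtype, calls the prim,
# and then casts the normalized output back to inp.dtype (float16 here), so the
# casts are explicit in the trace rather than implicit in the fusion definition.
```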