nvprims native batch norm patch (#88455)

Cherry-picking: https://github.com/csarofeen/pytorch/pull/2104

- [x] Added an explicit cast on the inputs to nvprims.native_batch_norm. This avoids relying on implicit dtype promotion, which caused issues in the fusion definition (see the sketch after this list).
- [x] Added a Python repro with dynamo.
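
For reference, a minimal sketch (not part of the patch) of the dtype mismatch that motivates the explicit cast: under autocast, `torch.matmul` produces a `float16` output while the batch norm weight/bias stay `float32`, so `elementwise_dtypes` with `NO_OPMATH` resolves the computation dtype to `float32` and the batch norm input has to be cast explicitly. Shapes and the `check_promotion` name are illustrative only; a CUDA device is assumed.

```python
import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes


def check_promotion():
    # Shapes mirror the repro test below; any CUDA tensors would do.
    mat1 = torch.randn(2, 3, 4, 5, device="cuda", dtype=torch.float32)
    mat2 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
    w = torch.randn(3, device="cuda", dtype=torch.float32)

    with torch.cuda.amp.autocast():
        o = torch.matmul(mat1, mat2)  # autocast downcasts matmul output to float16
        # NO_OPMATH: the computation dtype equals the promoted result dtype
        computation_dtype, _ = elementwise_dtypes(
            o, w, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
        )
        # o.dtype is torch.float16, computation_dtype is torch.float32,
        # so the batch norm input needs an explicit cast for the fusion.
        print(o.dtype, computation_dtype)
```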

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88455
Approved by: https://github.com/mruberry, https://github.com/IvanYashchuk
diff --git a/test/test_nvfuser_dynamo.py b/test/test_nvfuser_dynamo.py
index b0c2838..749cae8 100644
--- a/test/test_nvfuser_dynamo.py
+++ b/test/test_nvfuser_dynamo.py
@@ -45,6 +45,25 @@
         eager_result = func.__wrapped__(input1, input2)
         self.assertEqual(eager_result, nvfuser_result)
 
+    def test_batch_norm_implicit_dtype_promotion(self):
+        input1 = make_tensor((2, 3, 4, 5), device="cuda", dtype=torch.float32)
+        input2 = make_tensor((5, 5), device="cuda", dtype=torch.float32)
+        w = make_tensor((3), device="cuda", dtype=torch.float32)
+        b = make_tensor((3), device="cuda", dtype=torch.float32)
+
+        @torchdynamo.optimize("nvprims_nvfuser")
+        def func(mat1, mat2, w, b):
+            o = torch.matmul(mat1, mat2)
+            return torch.batch_norm(o, w, b, None, None, True, 1e-2, 1e-5, True)
+
+        # No warnings and no errors
+        with torch.cuda.amp.autocast():
+            with warnings.catch_warnings(record=True) as warning:
+                nvfuser_result = func(input1, input2, w, b)
+                self.assertEqual(len(warning), 0)
+            eager_result = func.__wrapped__(input1, input2, w, b)
+            self.assertEqual(eager_result, nvfuser_result)
+
     def test_dtype_correctness(self):
         input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float16)
 
diff --git a/torch/_prims/nvfuser_prims.py b/torch/_prims/nvfuser_prims.py
index 391a7fe..59a8820 100644
--- a/torch/_prims/nvfuser_prims.py
+++ b/torch/_prims/nvfuser_prims.py
@@ -5,12 +5,13 @@
 # can be added in the future for the corresponding higher-level torch/aten
 # functions.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 
 from torch._prims_common import (
     DimsSequenceType,
+    elementwise_dtypes,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
     getnvFuserDtype,
     make_contiguous_strides_for,
@@ -19,6 +20,7 @@
 )
 
 from torch._prims_common.wrappers import (
+    _maybe_convert_to_dtype,
     backwards_not_supported,
     elementwise_type_promotion_wrapper,
 )
@@ -373,12 +375,60 @@
         )
 
     nvprim_impl.impl(name, _prim_impl)
-    nvprim_autograd_impl.impl(
-        name, backwards_not_supported(torch.ops.nvprims.native_batch_norm.default)
-    )
-
     prim_packet = torch.ops.nvprims.native_batch_norm
     prim = prim_packet.default
+
+    def _native_batch_norm_ref(
+        input: torch.Tensor,
+        weight: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor],
+        running_mean: Optional[torch.Tensor],
+        running_var: Optional[torch.Tensor],
+        training: bool,
+        momentum: float,
+        eps: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+        if torch._prims_common.is_complex_dtype(input.dtype):
+            raise NotImplementedError("Complex tensors are not supported")
+
+        # note: BN only promotes input to dtype of weight/bias, but keeps the same output dtype
+        result_dtype = input.dtype
+        computation_dtype, _ = elementwise_dtypes(
+            input,
+            weight,
+            bias,
+            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
+        )
+
+        input_ = _maybe_convert_to_dtype(input, computation_dtype)
+        output, mean, rstd = prim(
+            input_, weight, bias, running_mean, running_var, training, momentum, eps
+        )
+        output_ = _maybe_convert_to_dtype(output, result_dtype)  # type: ignore[arg-type]
+        return (output_, mean, rstd)  # type: ignore[return-value]
+
+    def _native_batch_norm_autograd(
+        input: torch.Tensor,
+        weight: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor],
+        running_mean: Optional[torch.Tensor],
+        running_var: Optional[torch.Tensor],
+        training: bool,
+        momentum: float,
+        eps: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # This wrapper is needed to convert prims calls inside
+        # _native_batch_norm_ref to nvprims calls
+        from torch._prims.context import NvfuserPrimsMode
+
+        with NvfuserPrimsMode():
+            return backwards_not_supported(_native_batch_norm_ref)(
+                input, weight, bias, running_mean, running_var, training, momentum, eps
+            )
+
+    nvprim_autograd_impl.impl(name, _native_batch_norm_autograd)
+
     for p in (prim_packet, prim):
         p.__doc__ = "Computes batch normalization."
         p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
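
As an addendum (not part of the patch): a standalone sketch of the cast-in/cast-out pattern that `_native_batch_norm_ref` applies, with `torch.nn.functional.batch_norm` standing in for the nvprims call so the dtype behavior can be checked outside of nvFuser. The helper name `batch_norm_with_explicit_cast` is illustrative only.

```python
import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes
from torch._prims_common.wrappers import _maybe_convert_to_dtype


def batch_norm_with_explicit_cast(
    input, weight, bias, training=True, momentum=1e-2, eps=1e-5
):
    # Mirrors the promotion logic in _native_batch_norm_ref: compute in the
    # promoted dtype, but keep the original input dtype for the output.
    result_dtype = input.dtype
    computation_dtype, _ = elementwise_dtypes(
        input, weight, bias,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
    )
    input_ = _maybe_convert_to_dtype(input, computation_dtype)
    # torch.nn.functional.batch_norm stands in for the nvprims prim here;
    # no running stats, matching the training=True repro above.
    out = torch.nn.functional.batch_norm(
        input_, None, None, weight, bias, training, momentum, eps
    )
    return _maybe_convert_to_dtype(out, result_dtype)
```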