Revert "Batch Norm Consolidation (#116092)"

This reverts commit 5680f565d5b7d4aa412a3988d3d91ca4c5679303.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/jeffdaily because it broke ROCm. The PR signal was clean but trunk was not; the merge should have been blocked but wasn't ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1981373237))
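
For context, #116092 consolidated the batch norm entry points behind a shared `_select_batch_norm_backend` helper and two new ops: `_batch_norm_with_update` (training; mutates the running stats in place) and `_batch_norm_no_update` (eval; stats are only read). A minimal sketch of that train/eval split, written against the public `nn.BatchNorm2d` API rather than the removed internal ops:

```python
import torch

x = torch.randn(8, 4, 16, 16)
bn = torch.nn.BatchNorm2d(4)

bn.train()
bn(x)  # running_mean / running_var are updated in place ("with update")

bn.eval()
before = bn.running_mean.clone()
bn(x)  # running stats are only read ("no update")
assert torch.equal(bn.running_mean, before)
```
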
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 64de958..2b00567 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -29,11 +29,6 @@
 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
-#include <ATen/ops/_batch_norm_with_update.h>
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/_batch_norm_no_update.h>
-#include <ATen/ops/_batch_norm_no_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/alias.h>
 #include <ATen/ops/batch_norm.h>
 #include <ATen/ops/batch_norm_native.h>
@@ -484,58 +479,10 @@
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
 
-BatchNormBackend _select_batch_norm_backend(
-    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
-    const Tensor& running_var, bool training, double eps) {
-
-  auto& ctx = at::globalContext();
-  bool cudnn_enabled = ctx.userEnabledCuDNN();
-
-  if (
-      input.is_cuda()
-      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
-      && (input.scalar_type() != at::kHalf
-        || weight.scalar_type() == at::kFloat)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && ((input.sym_size(0) <= 880801 && training) // spatial, training
-          ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
-      && detail::getCUDAHooks().compiledWithCuDNN()
-      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
-      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
-      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
-  ) {
-    return BatchNormBackend::Cudnn;
-  }
-
-  if (
-      input.is_cuda()
-      && input.dim() <= MIOPEN_DIM_MAX
-      && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
-      && (weight.scalar_type() != at::kHalf)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && detail::getCUDAHooks().compiledWithMIOpen()
-      && cudnn_enabled
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
-  ) {
-    return BatchNormBackend::Miopen;
-  }
-
-  return BatchNormBackend::Native;
-}
-
-
 // _batch_norm_impl_index(_backward) are used in the JIT to be able to keep the run-time selection
 // of backends, while recording which backend was used so that the corresponding
 // backward implementation can be chosen.
 // XXX: The indices of backends need to be kept synchronized between this function and its _backward.
-// TODO: remove cudnn_enabled arg
 std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -580,9 +527,24 @@
     check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
   }
 
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+  const bool use_cudnn = (
+      input.is_cuda()
+      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
+      && (input.scalar_type() != at::kHalf
+        || weight.scalar_type() == at::kFloat)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && (input.dim() >= 3)
+      && ((input.sym_size(0) <= 880801 && training) // spatial, training
+          ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
+      && detail::getCUDAHooks().compiledWithCuDNN()
+      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
+      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
+      );
 
-  if (backend == BatchNormBackend::Cudnn) {
+  if (use_cudnn) {
     auto input_c = input.contiguous(input.suggest_memory_format());
     auto weight_c = weight.contiguous();
     auto bias_c = bias.contiguous();
@@ -599,7 +561,19 @@
 
   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
 
-  if (backend == BatchNormBackend::Miopen) {
+  bool use_miopen = (input.is_cuda()
+               && input.dim() <= MIOPEN_DIM_MAX
+               && input.scalar_type() != at::kDouble
+               && input.scalar_type() != at::kBFloat16
+               && (weight.scalar_type() != at::kHalf)
+               && weight.defined() && bias.defined()
+               && ((running_mean.defined() && running_var.defined())
+                 || (!running_mean.defined() && !running_var.defined() && training))
+               && detail::getCUDAHooks().compiledWithMIOpen()
+               && cudnn_enabled
+               );
+
+  if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
     return std::tuple_cat(
              at::miopen_batch_norm(
                input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -663,7 +637,6 @@
   TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
 }
 
-// TODO: remove cudnn_enabled arg
 Tensor batch_norm(
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
     const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
@@ -674,30 +647,6 @@
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
   return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                 training, momentum, eps, cudnn_enabled));
-  // TODO: switch to the new stack after the 2 week FC window
-  // if (training) {
-  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
-  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
-  //     auto input_c = input;
-  //     if (backend == BatchNormBackend::Cudnn) {
-  //         input_c = input.contiguous(input.suggest_memory_format());
-  //     } else {
-  //         input_c = input.contiguous();
-  //     }
-  //     auto weight_c = weight.contiguous();
-  //     auto bias_c = bias.contiguous();
-  //     auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
-  //     auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
-  //     return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
-  //                                                   const_cast<Tensor&>(rvar_c), momentum, eps));
-  //   } else {
-  //     return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
-  //                                                   const_cast<Tensor&>(running_var), momentum, eps));
-  //   }
-  // } else {
-  //   return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
-  //                                               momentum, eps));
-  // }
 }
 
 Tensor instance_norm(
@@ -849,38 +798,6 @@
   return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
 }
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-    batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
-  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
-    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
-  std::tie(out, save_mean, save_var) =
-    batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
-  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
-
-
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    double momentum, double eps) {
-  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
-  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-    batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
-  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
 
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
     const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
@@ -909,13 +826,6 @@
   return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
 }
 
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
-}
 
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
                                                            bool train, double eps, std::array<bool,3> grad_input_mask) {
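
The inline `use_cudnn` condition restored in `_batch_norm_impl_index` above is dense; here is a hedged Python paraphrase (illustrative only — `min_eps` and `cudnn_version` stand in for the `detail::getCUDAHooks()` queries, and the constants are copied from the C++ condition):

```python
import torch

def roughly_use_cudnn(input, weight, bias, running_mean, running_var,
                      training, eps, cudnn_enabled, min_eps, cudnn_version):
    if not input.is_cuda or weight is None or bias is None:
        return False
    stats_ok = ((running_mean is not None and running_var is not None)
                or (running_mean is None and running_var is None and training))
    dtypes_ok = (input.dtype is not torch.bfloat16
                 and weight.dtype is not torch.bfloat16
                 and (input.dtype is not torch.float16
                      or weight.dtype is torch.float32))
    batch_ok = input.shape[0] <= (880801 if training else 65535)  # spatial limits
    return (dtypes_ok and stats_ok and input.dim() >= 3 and batch_ok
            and eps >= min_eps and cudnn_enabled and cudnn_version >= 5110
            and input.numel() < 2**31 - 1)  # some cuDNN kernels index with int32
```
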
diff --git a/aten/src/ATen/native/Normalization.h b/aten/src/ATen/native/Normalization.h
index 1ba99e7..6cd4dcd 100644
--- a/aten/src/ATen/native/Normalization.h
+++ b/aten/src/ATen/native/Normalization.h
@@ -8,12 +8,4 @@
 using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
 DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
 
-enum class BatchNormBackend {
-  Native,
-  Cudnn,
-  Miopen,
-};
-
-TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
-
 }  // namespace at::native
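
For reference, the `BatchNormBackend` enum removed above, sketched in Python:

```python
from enum import Enum

# Python sketch of the removed C++ enum; Native is the fallback when neither
# the cuDNN nor the MIOpen eligibility checks pass.
class BatchNormBackend(Enum):
    Native = 0
    Cudnn = 1
    Miopen = 2
```
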
diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu
index 1b4c159..655d32b 100644
--- a/aten/src/ATen/native/cuda/Normalization.cu
+++ b/aten/src/ATen/native/cuda/Normalization.cu
@@ -1,7 +1,5 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/cuda/detail/IndexUtils.cuh>
-#include <ATen/detail/CUDAHooksInterface.h>
-#include <ATen/native/Normalization.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/ReduceOps.h>
 #include <ATen/native/Resize.h>
@@ -14,8 +12,6 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/batch_norm_backward_elemt_native.h>
 #include <ATen/ops/batch_norm_backward_reduce_native.h>
 #include <ATen/ops/batch_norm_elemt_native.h>
@@ -23,12 +19,8 @@
 #include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
 #include <ATen/ops/batch_norm_stats_native.h>
 #include <ATen/ops/batch_norm_update_stats_native.h>
-#include <ATen/ops/cudnn_batch_norm.h>
-#include <ATen/ops/cudnn_batch_norm_backward.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/from_blob.h>
-#include <ATen/ops/miopen_batch_norm.h>
-#include <ATen/ops/miopen_batch_norm_backward.h>
 #include <ATen/ops/native_batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm_native.h>
 #include <ATen/ops/scalar_tensor.h>
@@ -481,54 +473,6 @@
   return std::make_tuple(output, save_mean, save_invstd);
 }
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-  Tensor output, save_mean, save_var, reserve;
-
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
-  if (backend == BatchNormBackend::Cudnn) {
-    std::tie(output, save_mean, save_var, reserve) =
-        at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else if (backend == BatchNormBackend::Miopen) {
-    reserve = at::empty({0}, input.options().dtype(kByte));
-    std::tie(output, save_mean, save_var) =
-        at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else {
-    reserve = at::empty({0}, input.options().dtype(kByte));
-    std::tie(output, save_mean, save_var) =
-        batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
-  }
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
-    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
-  if (backend == BatchNormBackend::Cudnn) {
-    std::tie(out, save_mean, save_var, reserve) =
-        at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else if (backend == BatchNormBackend::Miopen) {
-    std::tie(out, save_mean, save_var) =
-        at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else {
-    std::tie(out, save_mean, save_var) =
-      batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
-  }
-  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
-
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
   return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
 }
@@ -545,28 +489,6 @@
   return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
 }
 
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  const Tensor& dummy_bias = at::empty(1);
-  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
-  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
-  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
-  const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});
-
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);
-
-  if (backend == BatchNormBackend::Cudnn) {
-    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
-  } else if (backend == BatchNormBackend::Miopen) {
-    return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
-  } else {
-    return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
-  }
-}
-
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
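
All of the removed CUDA wrappers share one contract: forward returns `(output, save_mean, save_var, reserve)`, where `reserve` is an opaque byte tensor that only the cuDNN path actually populates. A minimal sketch of the non-cudnn case, mirroring the `at::empty({0}, input.options().dtype(kByte))` pattern above:

```python
import torch

def empty_reserve(input: torch.Tensor) -> torch.Tensor:
    # Non-cudnn backends hand back an empty byte tensor as the reserve.
    return torch.empty(0, dtype=torch.uint8, device=input.device)

x = torch.randn(2, 3, 4, 4)
reserve = empty_reserve(x)
assert reserve.numel() == 0 and reserve.dtype == torch.uint8
```
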
diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp
index 2553139..2efe7a7 100644
--- a/aten/src/ATen/native/cudnn/BatchNorm.cpp
+++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp
@@ -2,7 +2,6 @@
 #include <ATen/Config.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/cuda/CUDAConfig.h>
-#include <ATen/native/cudnn/BatchNorm.h>
 
 #if !AT_CUDNN_ENABLED()
 
@@ -36,24 +35,18 @@
   AT_ERROR("cudnn_batch_norm_backward: ATen not compiled with cuDNN support");
 }
 
-size_t _get_cudnn_batch_norm_reserve_space_size(
-    const Tensor& input_t,
-    bool training) {
-  AT_ERROR(
-      "_get_cudnn_batch_norm_reserve_space_size: ATen not compiled with cuDNN support");
-}
-
 } // namespace native
 } // namespace at
 
 #else // AT_CUDNN_ENABLED
 
-#include <ATen/TensorUtils.h>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/cudnn/Descriptors.h>
 #include <ATen/cudnn/Types.h>
 #include <ATen/cudnn/Utils.h>
 
+#include <ATen/TensorUtils.h>
+
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
@@ -105,21 +98,6 @@
 
 } // namespace
 
-size_t _get_cudnn_batch_norm_reserve_space_size(
-    const Tensor& input_t,
-    bool training) {
-  size_t reserve_size;
-  TensorArg input{input_t, "input", 1};
-  TensorDescriptor idesc{*input, 4};
-  auto handle = getCudnnHandle();
-  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
-      training, input->suggest_memory_format(), input->dim());
-  auto op = CUDNN_BATCHNORM_OPS_BN;
-  AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
-      handle, mode, op, nullptr, idesc.desc(), &reserve_size));
-  return reserve_size;
-}
-
 std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
     const Tensor& input_t,
     const Tensor& weight_t,
@@ -208,8 +186,9 @@
     Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));
 
     // get the reserved size and allocate as tensor
-    size_t reserve_size =
-        _get_cudnn_batch_norm_reserve_space_size(input_t, true /* training */);
+    size_t reserve_size;
+    AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
+        handle, mode, op, nullptr, idesc.desc(), &reserve_size));
     reserve = at::empty(reserve_size, input->options().dtype(kByte));
 
     AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
diff --git a/aten/src/ATen/native/cudnn/BatchNorm.h b/aten/src/ATen/native/cudnn/BatchNorm.h
deleted file mode 100644
index 3da76c0..0000000
--- a/aten/src/ATen/native/cudnn/BatchNorm.h
+++ /dev/null
@@ -1,6 +0,0 @@
-namespace at::native {
-
-TORCH_API size_t
-_get_cudnn_batch_norm_reserve_space_size(const Tensor& input_t, bool training);
-
-} // namespace at::native
diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp
index 0aced61..108ce35 100644
--- a/aten/src/ATen/native/mkldnn/Normalization.cpp
+++ b/aten/src/ATen/native/mkldnn/Normalization.cpp
@@ -6,8 +6,6 @@
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/NativeFunctions.h>
 #else
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/_to_dense_native.h>
 #include <ATen/ops/empty_native.h>
@@ -61,20 +59,6 @@
   TORCH_CHECK(false, "_mkldnn_batch_norm_legit_no_stats: ATen not compiled with MKLDNN support");
 }
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  TORCH_CHECK(false, "_batch_norm_with_update_mkldnn: ATen not compiled with MKLDNN support");
-}
-
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  TORCH_CHECK(false, "_new_batch_norm_backward_mkldnn: ATen not compiled with MKLDNN support");
-}
-
 } // namespace native
 } // namespace at
 
@@ -208,17 +192,6 @@
 }
 
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-    mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, /*train*/true, momentum, eps);
-  Tensor reserve = empty_mkldnn({0}, input.scalar_type());
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-
 std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
     bool train,
@@ -237,15 +210,6 @@
 }
 
 
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  return mkldnn_batch_norm_backward(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
-}
-
-
 std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(const Tensor& grad_output,
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
     bool train,
diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm
index bdca3b0..eb754ae 100644
--- a/aten/src/ATen/native/mps/operations/Normalization.mm
+++ b/aten/src/ATen/native/mps/operations/Normalization.mm
@@ -10,9 +10,7 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
-#include <ATen/ops/_batch_norm_with_update_native.h>
 #include <ATen/ops/_native_batch_norm_legit_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm.h>
 #include <ATen/ops/native_batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm_native.h>
@@ -408,36 +406,6 @@
   return std::make_tuple(output, save_mean, save_var);
 }
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mps(const Tensor& input,
-                                                                       const c10::optional<Tensor>& weight_opt,
-                                                                       const c10::optional<Tensor>& bias_opt,
-                                                                       Tensor& running_mean,
-                                                                       Tensor& running_var,
-                                                                       double momentum,
-                                                                       double eps) {
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-      batch_norm_mps(input, weight_opt, bias_opt, running_mean, running_var, /*train*/ true, momentum, eps);
-  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_mps_out(const Tensor& input,
-                                                                               const c10::optional<Tensor>& weight_opt,
-                                                                               const c10::optional<Tensor>& bias_opt,
-                                                                               Tensor& running_mean,
-                                                                               Tensor& running_var,
-                                                                               double momentum,
-                                                                               double eps,
-                                                                               Tensor& out,
-                                                                               Tensor& save_mean,
-                                                                               Tensor& save_var,
-                                                                               Tensor& reserve) {
-  std::tie(out, save_mean, save_var) = batch_norm_mps_out(
-      input, weight_opt, bias_opt, running_mean, running_var, /*update*/ true, momentum, eps, out, save_mean, save_var);
-  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
-
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_mps(const Tensor& self,
                                                          const c10::optional<Tensor>& weight_opt,
                                                          const c10::optional<Tensor>& bias_opt,
@@ -503,29 +471,6 @@
 }
 
 // Batch norm backward
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mps(const Tensor& grad_output,
-                                                                const Tensor& input,
-                                                                const Tensor& weight,
-                                                                const c10::optional<Tensor>& running_mean_opt,
-                                                                const c10::optional<Tensor>& running_var_opt,
-                                                                const c10::optional<Tensor>& save_mean_opt,
-                                                                const c10::optional<Tensor>& save_var_opt,
-                                                                bool update,
-                                                                double eps,
-                                                                std::array<bool, 3> grad_input_mask,
-                                                                const Tensor& reserve) {
-  return batch_norm_backward_mps(grad_output,
-                                 input,
-                                 weight,
-                                 running_mean_opt,
-                                 running_var_opt,
-                                 save_mean_opt,
-                                 save_var_opt,
-                                 update,
-                                 eps,
-                                 grad_input_mask);
-}
-
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps(const Tensor& grad_out,
                                                            const Tensor& input,
                                                            const c10::optional<Tensor>& weight_opt,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 809e9e9..2becda9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6486,32 +6486,6 @@
     SparseCPU, SparseCUDA: norm_sparse
   autogen: native_norm.ScalarOpt_dim_dtype_out
 
-- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  dispatch:
-    CPU: _batch_norm_with_update_cpu
-    CUDA: _batch_norm_with_update_cuda
-    MPS: _batch_norm_with_update_mps
-    MkldnnCPU: _batch_norm_with_update_mkldnn
-  autogen: _batch_norm_with_update_functional
-
-- func: _batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))
-  dispatch:
-    CPU: _batch_norm_with_update_cpu_out
-    CUDA: _batch_norm_with_update_cuda_out
-    MPS: _batch_norm_with_update_mps_out
-
-- func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  dispatch:
-    CompositeExplicitAutograd: _batch_norm_no_update
-  autogen: _batch_norm_no_update.out
-
-- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
-  dispatch:
-    CPU: _new_batch_norm_backward_cpu
-    CUDA: _new_batch_norm_backward_cuda
-    MPS: _new_batch_norm_backward_mps
-    MkldnnCPU: _new_batch_norm_backward_mkldnn
-
 # TODO: reduce signatures down to one when optional args is available
 - func: _sparse_sum(Tensor self) -> Tensor
 
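
With the consolidated schemas removed, batch norm keeps lowering to the pre-existing entries such as `native_batch_norm`. For reference, its 3-tuple return checked from Python (the removed ops appended a fourth `reserve` output to this tuple):

```python
import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

out, save_mean, save_invstd = torch.native_batch_norm(
    x, weight, bias, running_mean, running_var,
    True,   # training
    0.1,    # momentum
    1e-5,   # eps
)
assert out.shape == x.shape and save_mean.shape == (3,)
```
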
diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
index 66b78f0..3e88ba4 100644
--- a/test/distributed/_tensor/test_dtensor_ops.py
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -113,7 +113,6 @@
     xfail("as_strided", "partial_views"),
     xfail("as_strided_scatter"),
     xfail("bernoulli"),
-    xfail("_batch_norm_with_update"),
     xfail("block_diag"),
     xfail("broadcast_shapes"),
     xfail("cauchy"),
diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_batch_norm_with_update_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_batch_norm_with_update_cpu_float32
+++ /dev/null
diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_batch_norm_with_update_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_batch_norm_with_update_cpu_float32
+++ /dev/null
diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_batch_norm_with_update_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_batch_norm_with_update_cpu_float32
+++ /dev/null
diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_batch_norm_with_update_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_batch_norm_with_update_cpu_float32
+++ /dev/null
diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_batch_norm_with_update_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_batch_norm_with_update_cpu_float32
+++ /dev/null
diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index f23021e..c9f379e 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -6,9 +6,6 @@
 aten::_adaptive_avg_pool2d.out
 aten::_addmm_activation
 aten::_addmm_activation.out
-aten::_batch_norm_no_update
-aten::_batch_norm_with_update
-aten::_batch_norm_with_update_functional
 aten::_euclidean_dist.out
 aten::_fused_dropout
 aten::_fused_dropout.out
@@ -79,7 +76,6 @@
 aten::atanh.out
 aten::atanh_
 aten::baddbmm_
-aten::batch_norm_backward
 aten::bitwise_and.Scalar
 aten::bitwise_and.Scalar_Tensor
 aten::bitwise_and.Scalar_Tensor_out
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 917bcad..3b540d3 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -30,8 +30,6 @@
 aten::_amp_update_scale_
 aten::_assert_async
 aten::_assert_async.msg
-aten::_batch_norm_no_update.out
-aten::_batch_norm_with_update.out
 aten::_cdist_backward
 aten::_cdist_backward.out
 aten::_cdist_forward
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 138fafad..15d5ff0 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -368,10 +368,6 @@
     # 'tensor_split' not composite compliant, see vjp_fail
 }
 
-skip_noncontig = {
-    '_batch_norm_with_update',
-}
-
 
 @unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant")
 @unMarkDynamoStrictTest
@@ -437,10 +433,9 @@
             args = [sample.input] + list(sample.args)
             kwargs = sample.kwargs
 
-            if op.name not in skip_noncontig:
-                noncontig_sample = sample.noncontiguous()
-                noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
-                noncontig_kwargs = noncontig_sample.kwargs
+            noncontig_sample = sample.noncontiguous()
+            noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
+            noncontig_kwargs = noncontig_sample.kwargs
 
             diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg))
             assert len(diff_argnums) > 0
@@ -463,12 +458,11 @@
                 return result
 
             result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
+            result_noncontig = grad(wrapped_fn, diff_argnums)(*noncontig_args, **noncontig_kwargs)
             expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args)
-            self.assertEqual(result, expected)
 
-            if op.name not in skip_noncontig:
-                result_noncontig = grad(wrapped_fn, diff_argnums)(*noncontig_args, **noncontig_kwargs)
-                self.assertEqual(result_noncontig, expected)
+            self.assertEqual(result, expected)
+            self.assertEqual(result_noncontig, expected)
 
     @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@@ -482,8 +476,7 @@
         skip('nn.functional.max_unpool2d'),  # fails everywhere except on windows
         skip('nn.functional.max_unpool3d'),  # fails everywhere except on mac
         xfail("native_batch_norm"),          # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
-        xfail("_native_batch_norm_legit"),   # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
-        xfail("_batch_norm_with_update"),     # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
+        xfail("_native_batch_norm_legit"),    # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
 
         xfail('nn.functional.scaled_dot_product_attention'),
         xfail('torch.ops.aten._flash_attention_forward'),
@@ -552,17 +545,15 @@
                 self.jvp_opinfo_test(outplace_variant, sample,
                                      sample.output_process_fn_grad,
                                      clone_inputs=False,
-                                     fixme_ref_jvp_local=fixme_ref_jvp_local,
-                                     test_noncontig=op.name not in skip_noncontig)
+                                     fixme_ref_jvp_local=fixme_ref_jvp_local)
             if is_valid_inplace_sample_input(sample, op, inplace_variant):
                 self.jvp_opinfo_test(inplace_variant, sample,
                                      sample.output_process_fn_grad,
                                      clone_inputs=True,
-                                     fixme_ref_jvp_local=fixme_ref_jvp_local,
-                                     test_noncontig=op.name not in skip_noncontig)
+                                     fixme_ref_jvp_local=fixme_ref_jvp_local)
 
     def jvp_opinfo_test(self, fn, sample, output_process_fn,
-                        clone_inputs, fixme_ref_jvp_local, test_noncontig):
+                        clone_inputs, fixme_ref_jvp_local):
         # NB: we used requires_grad=True to determine where the primals are,
         # but don't need that information otherwise
         args = (sample.input,) + sample.args
@@ -572,6 +563,15 @@
         orig_primals = tree_map(lambda x: x.detach(), primals)
         orig_tangents = tree_map(lambda x: torch.randn_like(x), primals)
 
+        noncontig_sample = sample.noncontiguous()
+        noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
+        noncontig_kwargs = sample.kwargs
+        noncontig_fn, primals = normalize_op_input_output2(
+            fn, noncontig_args, noncontig_kwargs,
+            output_process_fn, requires_grad=True)
+        noncontig_primals = tree_map(lambda x: x.detach(), primals)
+        noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents)
+
         def maybe_clone_inputs():
             if clone_inputs:
                 primals = tree_map(torch.clone, orig_primals)
@@ -586,24 +586,15 @@
         primals, tangents = maybe_clone_inputs()
         primal_outs, tangent_outs = jvp(contig_fn, primals, tangents)
 
+        noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn,
+                                                            noncontig_primals,
+                                                            noncontig_tangents)
+
         self.assertEqual(primal_outs, expected_primal_outs)
         self.assertEqual(tangent_outs, expected_tangent_outs)
 
-        if test_noncontig:
-            noncontig_sample = sample.noncontiguous()
-            noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
-            noncontig_kwargs = sample.kwargs
-            noncontig_fn, primals = normalize_op_input_output2(
-                fn, noncontig_args, noncontig_kwargs,
-                output_process_fn, requires_grad=True)
-            noncontig_primals = tree_map(lambda x: x.detach(), primals)
-            noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents)
-            noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn,
-                                                                noncontig_primals,
-                                                                noncontig_tangents)
-
-            self.assertEqual(noncontig_primal_outs, expected_primal_outs)
-            self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)
+        self.assertEqual(noncontig_primal_outs, expected_primal_outs)
+        self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)
 
     @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@@ -664,22 +655,22 @@
                 result = fn(*primals)
                 cotangents = tree_map(lambda x: torch.randn_like(x), result)
 
+                noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous())
+                noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents)
+
                 out, vjp_fn = vjp(fn, *primals)
                 self.assertEqual(out, result)
                 result_vjps = vjp_fn(cotangents)
 
+                out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
+                self.assertEqual(out_noncontig, result)
+                noncontig_result_vjps = vjp_fn(noncontig_cotangents)
+
                 _, vjp_fn = ref_vjp(fn, *primals)
                 expected_vjps = vjp_fn(cotangents)
 
                 self.assertEqual(result_vjps, expected_vjps)
-
-                if op.name not in skip_noncontig:
-                    noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous())
-                    noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents)
-                    out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
-                    self.assertEqual(out_noncontig, result)
-                    noncontig_result_vjps = vjp_fn(noncontig_cotangents)
-                    self.assertEqual(noncontig_result_vjps, expected_vjps)
+                self.assertEqual(noncontig_result_vjps, expected_vjps)
 
         _test(op)
         for a_op in op.aliases:
@@ -839,8 +830,6 @@
         xfail("to_sparse"),
         xfail("native_batch_norm"),
         xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
     }))
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -932,8 +921,6 @@
         skip('linalg.svdvals'),  # # really annoying thing where it passes correctness check but not has_batch_rule
         skip("native_batch_norm"),
         skip("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        skip("_batch_norm_with_update"),
         xfail('__getitem__', ''),  # dynamic error
         xfail('nanquantile', device_type='cpu'),  # checks q via a .item() call
         xfail('nn.functional.gaussian_nll_loss'),  # checks var for if any value < 0
@@ -1058,8 +1045,6 @@
         xfail('nn.functional.batch_norm', 'without_cudnn'),
         xfail("native_batch_norm"),
         xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
 
         # https://github.com/pytorch/pytorch/issues/96560
         # ROCm: NotImplementedError
@@ -1245,8 +1230,6 @@
         xfail('sparse.mm', 'reduce'),
         xfail("native_batch_norm"),
         xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
         xfail("native_dropout_backward"),
         xfail("index_fill"),  # aten::_unique hit the vmap fallback which is currently disabled
     }))
@@ -1323,8 +1306,6 @@
         xfail('sparse.mm', 'reduce'),
         xfail("native_batch_norm"),
         xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
         xfail('as_strided', 'partial_views'),
     }))
     def test_vjpvmap(self, device, dtype, op):
@@ -1583,8 +1564,6 @@
         # place, were not batched.
         xfail("native_batch_norm"),
         xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
         xfail('native_dropout_backward'),
     }))
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 1b5723f..df46c95 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3625,8 +3625,6 @@
         # which will be updated in place, were not batched.
         xfail('native_batch_norm'),
         xfail('_native_batch_norm_legit'),
-        # TODO: implement batching rule
-        xfail('_batch_norm_with_update'),
         xfail('tril'),  # Exception not raised on error input
         xfail('triu'),  # Exception not raised on error input
         xfail('as_strided', 'partial_views'),
@@ -3666,8 +3664,6 @@
         # which will be updated in place, were not batched.
         xfail('native_batch_norm'),
         xfail('_native_batch_norm_legit'),
-        # TODO: implement batching rule
-        xfail('_batch_norm_with_update'),
         xfail('histogram'),
         xfail('scatter_reduce', 'sum'),
         xfail('scatter_reduce', 'mean'),
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index c51ec4b..8a35bc6 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -192,7 +192,6 @@
     "nn.functional.cosine_embedding_loss": {b8},
     "native_batch_norm": {f16, f32, f64},
     "_native_batch_norm_legit": {f16, f32, f64},
-    "_batch_norm_with_update": {f16, f32, f64},
 }
 
 if not SM80OrLater:
diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py
index b6a4176..4d38e70 100644
--- a/test/onnx/test_fx_op_consistency.py
+++ b/test/onnx/test_fx_op_consistency.py
@@ -157,11 +157,6 @@
         dtypes=(torch.float16,),
         reason="fixme: Assertion error: result mismatch and type error",
     ),
-    skip(
-        "_batch_norm_with_update",
-        dtypes=(torch.float16,),
-        reason="fixme: Assertion error: result mismatch and type error",
-    ),
     xfail(
         "_softmax_backward_data",
         reason=onnx_test_common.reason_dynamo_does_not_support("assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)")
@@ -1357,20 +1352,6 @@
         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
         reason="https://github.com/pytorch/pytorch/issues/115106",
     ),
-    skip(
-        "_batch_norm_with_update",
-        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
-        reason="https://github.com/pytorch/pytorch/issues/115106",
-    ),
-    # TODO: This test currently fails only for certain inputs, e.g. shape([3, 1]).
-    # Numerically the ONNX program is correct, but `save_mean` and `save_var`
-    # came back as 0-d tensors, e.g. tensor(-2.1268) instead of the correct
-    # 1-d tensor([-2.1268]).
-    skip(
-        "_batch_norm_with_update",
-        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
-        reason="not supported yet",
-    ),
     xfail(
         "addmm",  # xfail can't only use dtypes to catch all cases
         matcher=lambda sample: sample.input.dtype
diff --git a/test/test_meta.py b/test/test_meta.py
index 6a092b3..65e17ce 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -708,11 +708,8 @@
 meta_function_device_skips = defaultdict(dict)
 
 meta_function_device_expected_failures['cpu'] = {
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
     torch.native_batch_norm: {bf16, f16},
     torch._native_batch_norm_legit: {bf16, f16},
-    torch.ops.aten._batch_norm_with_update: {bf16, f16},
     torch.native_layer_norm: {bf16, f16},
 }
 
@@ -727,11 +724,8 @@
 }
 
 meta_function_device_skips['cpu'] = {
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
     torch.native_batch_norm: {f32, f64},
     torch._native_batch_norm_legit: {f32, f64},
-    torch.ops.aten._batch_norm_with_update: {f32, f64},
 }
 
 meta_function_device_skips['cuda'] = {
@@ -856,13 +850,9 @@
 meta_dispatch_device_skips = defaultdict(dict)
 
 meta_dispatch_device_expected_failures['cpu'] = {
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
     aten.native_batch_norm.default: {bf16, f16},
     aten._native_batch_norm_legit.default: {bf16, f16},
     aten._native_batch_norm_legit.no_stats: {bf16, f16},
-    aten._batch_norm_with_update.default: {bf16, f16},
-
     aten.native_layer_norm.default: {bf16, f16},
     aten.histc.default: {f16},
     aten.histc.out: {f16},
@@ -887,13 +877,9 @@
 
 meta_dispatch_device_skips['cpu'] = {
     aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},
-
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
     aten.native_batch_norm.default: {f32, f64},
     aten._native_batch_norm_legit.default: {f32, f64},
     aten._native_batch_norm_legit.no_stats: {f32, f64},
-    aten._batch_norm_with_update.default: {f32, f64},
 
     # If the computation dtype is different from the input
     # dtype this will fail. CPU execution may also have a
diff --git a/test/test_mps.py b/test/test_mps.py
index af26cbb..b5f54a6 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -11384,7 +11384,6 @@
         'nn.functional.gelu',
         'nn.functional.glu',
         '_native_batch_norm_legit',
-        '_batch_norm_with_update',
         'native_batch_norm',
         'softmax',
         '_softmax_backward_data',
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 08947ba..3338344 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1927,15 +1927,6 @@
 }
 
 out_symbolic_tensor_failures = {
-    # Cast error details: Unable to cast (...) to Tensor
-    #
-    # This happens because the test is set up to call the out variant using the `out` kwarg:
-    #   torch._some_op(arg1, arg2, out=(out1, out2, out3))
-    #
-    # However, this only works on torch ops, not aten ops. For `_batch_norm_with_update`,
-    # this fails because the op has no python bindings, so it doesn't support the `out` kwarg
-    # way of calling its out variant.
-    xfail('_batch_norm_with_update', ''),
     xfail('_native_batch_norm_legit', ''),
     xfail('angle', ''),
     xfail('argmax', ''),
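
The comment removed above concerns how these tests call out variants: for multi-output ops, the Python `out=` kwarg takes a tuple of preallocated tensors, which only works for ops with Python bindings. For example, with an op that does have them:

```python
import torch

x = torch.randn(5)
values = torch.empty(0)
indices = torch.empty(0, dtype=torch.long)

# Both preallocated tensors are resized and filled in place. Aten-only ops
# such as the removed _batch_norm_with_update cannot be called this way.
torch.sort(x, out=(values, indices))
assert torch.equal(values, x.sort().values)
```
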
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index f85af41..e692ae9 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1250,20 +1250,6 @@
   self: grad.neg()
   result: auto_element_wise
 
-- name: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/true, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
-  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps)
-
-- name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
-  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps)
-
-- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
-  input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask)
-  save_mean: not_implemented("batch_norm_backward save_mean")
-  save_var: not_implemented("batch_norm_backward save_var")
-  reserve: not_implemented("batch_norm_backward reserve")
-
 - name: nextafter(Tensor self, Tensor other) -> Tensor
   self: not_implemented("nextafter")
   other: not_implemented("nextafter")
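
For reference, the training-mode gradients that `batch_norm_backward` computes are the standard batch norm derivatives (textbook form, not necessarily the kernel's exact evaluation order), with $\hat{x}_i = (x_i - \mu)/\sqrt{\sigma^2 + \varepsilon}$ and sums taken per channel over batch and spatial positions:

$$
\frac{\partial L}{\partial \beta} = \sum_i \frac{\partial L}{\partial y_i},
\qquad
\frac{\partial L}{\partial \gamma} = \sum_i \frac{\partial L}{\partial y_i}\,\hat{x}_i,
$$

$$
\frac{\partial L}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma^2+\varepsilon}}
\left(\frac{\partial L}{\partial y_i}
 - \frac{1}{N}\sum_j \frac{\partial L}{\partial y_j}
 - \frac{\hat{x}_i}{N}\sum_j \frac{\partial L}{\partial y_j}\,\hat{x}_j\right).
$$
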
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 29ccf12..438754d 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -158,12 +158,9 @@
     "fill.Tensor",  # only used by the functionalization pass
     "fill.Scalar",  # only used by the functionalization pass
     "lift.*",
-    "normal_functional",  # only used by the functionalization pass
+    "normal_functional",  # only used by the functionalization pas
     "nbytes",
     "itemsize",
-    "_batch_norm_with_update",
-    "_batch_norm_with_update_out",
-    "_batch_norm_no_update",
 ]
 
 SKIP_PYTHON_BINDINGS = [
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 0027881..516072b 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1132,7 +1132,6 @@
 def _stash_obj_in_tls(key: str, arg: Any) -> None: ...
 def _get_obj_in_tls(key: str) -> Any: ...
 def _is_key_in_tls(key: str) -> _bool: ...
-def _select_batch_norm_backend(*args, **kwargs) -> BatchNormBackend: ...
 def _select_conv_backend(*args, **kwargs) -> ConvBackend: ...
 def _conv_determine_backend_memory_format(
     input: Tensor,
@@ -1198,8 +1197,6 @@
     Cusolver: _LinalgBackend
     Magma: _LinalgBackend
 
-class BatchNormBackend(Enum): ...
-
 class ConvBackend(Enum): ...
 
 class Tag(Enum):
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 5c66749..bcd6652 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1843,114 +1843,6 @@
     return output, save_mean, save_rstd, new_running_mean, new_running_var
 
 
-def _get_batch_norm_reserve_tensor(
-    input: Tensor,
-    weight: Optional[Tensor],
-    bias: Optional[Tensor],
-    running_mean: Tensor,
-    running_var: Tensor,
-    eps: float,
-    training: bool,
-) -> Tensor:
-    """
-    Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the
-    backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`,
-    which support a variety of backends including cudnn. We create this tensor here to get
-    the correct shape in the traced graph if we detect that we will call the cudnn kernel,
-    and rely on DCE to avoid materializing this tensor.
-    """
-    backend = torch._C._select_batch_norm_backend(  # type: ignore[attr-defined]
-        input, weight, bias, running_mean, running_var, True, eps
-    )
-    reserve_size = 0
-    if backend == torch._C._BatchNormBackend.Cudnn:  # type: ignore[attr-defined]
-        reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size(input, training)  # type: ignore[attr-defined]
-    return torch.empty(
-        reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device
-    )
-
-
-@register_decomposition(aten._batch_norm_with_update.default)
-def _batch_norm_with_update(
-    input: Tensor,
-    weight: Optional[Tensor],
-    bias: Optional[Tensor],
-    running_mean: Tensor,
-    running_var: Tensor,
-    momentum: float,
-    eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
-        input,
-        weight,
-        bias,
-        running_mean,
-        running_var,
-        True,  # training
-        momentum,
-        eps,
-        False,  # functional
-    )
-    reserve = _get_batch_norm_reserve_tensor(
-        input, weight, bias, running_mean, running_var, eps, training=True
-    )
-    return output, save_mean, save_rstd, reserve
-
-
-@register_decomposition(aten._batch_norm_with_update_functional.default)
-def _batch_norm_with_update_functional(
-    input: Tensor,
-    weight: Optional[Tensor],
-    bias: Optional[Tensor],
-    running_mean: Tensor,
-    running_var: Tensor,
-    momentum: float,
-    eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
-    (
-        output,
-        save_mean,
-        save_rstd,
-        new_rm,
-        new_rv,
-    ) = native_batch_norm_helper(
-        input, weight, bias, running_mean, running_var, True, momentum, eps, True
-    )
-    reserve = _get_batch_norm_reserve_tensor(
-        input, weight, bias, running_mean, running_var, eps, training=True
-    )
-    assert new_rm is not None, "new_running_mean should not be None"
-    assert new_rv is not None, "new_running_var should not be None"
-    return (output, save_mean, save_rstd, reserve, new_rm, new_rv)
-
-
-@register_decomposition(aten._batch_norm_no_update.default)
-def _batch_norm_no_update(
-    input: Tensor,
-    weight: Optional[Tensor],
-    bias: Optional[Tensor],
-    running_mean: Tensor,
-    running_var: Tensor,
-    momentum: float,
-    eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
-        input,
-        weight,
-        bias,
-        running_mean,
-        running_var,
-        False,  # training
-        momentum,
-        eps,
-        False,  # functional
-    )
-    reserve = _get_batch_norm_reserve_tensor(
-        input, weight, bias, running_mean, running_var, eps, training=False
-    )
-    return output, save_mean, save_rstd, reserve
-
-
 @register_decomposition(aten._fused_dropout)
 @out_wrapper("out0", "out1")
 @pw_cast_for_opmath
@@ -2055,34 +1947,6 @@
     return x
 
 
-@register_decomposition(aten.batch_norm_backward.default)
-def batch_norm_backward(
-    grad_out: Tensor,
-    input: Tensor,
-    weight: Optional[Tensor],
-    running_mean: Optional[Tensor],
-    running_var: Optional[Tensor],
-    save_mean: Optional[Tensor],
-    save_invstd: Optional[Tensor],
-    train: bool,
-    eps: float,
-    output_mask: List[bool],
-    reserve: Tensor,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
-    return native_batch_norm_backward(
-        grad_out,
-        input,
-        weight,
-        running_mean,
-        running_var,
-        save_mean,
-        save_invstd,
-        train,
-        eps,
-        output_mask,
-    )
-
-
 @register_decomposition(aten.native_batch_norm_backward.default)
 def native_batch_norm_backward(
     grad_out: Tensor,
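
A minimal sketch of how entries registered via register_decomposition, like the one above, are picked up during tracing (get_decompositions and make_fx are existing torch APIs; the shapes are arbitrary):

    import torch
    from torch._decomp import get_decompositions
    from torch.fx.experimental.proxy_tensor import make_fx

    aten = torch.ops.aten
    table = get_decompositions([aten.native_batch_norm])

    def f(x, w, b, rm, rv):
        return aten.native_batch_norm(x, w, b, rm, rv, True, 0.1, 1e-5)

    gm = make_fx(f, decomposition_table=table)(
        torch.randn(2, 3, 4), torch.randn(3), torch.randn(3),
        torch.zeros(3), torch.ones(3),
    )
    print(gm.graph)  # contains the decomposed ops, not native_batch_norm
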
diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py
index 81946c3..19dfaed 100644
--- a/torch/_decomp/decompositions_for_jvp.py
+++ b/torch/_decomp/decompositions_for_jvp.py
@@ -291,34 +291,6 @@
     return (grad_input, grad_weight, grad_bias)
 
 
-@register_decomposition_for_jvp(aten.batch_norm_backward)
-def batch_norm_backward(
-    grad_out: Tensor,
-    input: Tensor,
-    weight: Tensor,
-    running_mean: Optional[Tensor],
-    running_var: Optional[Tensor],
-    save_mean: Optional[Tensor],
-    save_var: Optional[Tensor],
-    update: bool,
-    eps: float,
-    output_mask: List[bool],
-    reserve: Tensor,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
-    return native_batch_norm_backward(
-        grad_out,
-        input,
-        weight,
-        running_mean,
-        running_var,
-        save_mean,
-        save_var,
-        update,
-        eps,
-        output_mask,
-    )
-
-
 _register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
 _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
 _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
@@ -328,4 +300,3 @@
 _register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
 _register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
 _register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
-_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default)
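
The jvp registrations in this file feed forward-mode AD; a minimal sketch of pushing a tangent through batch norm with torch.func.jvp (weight and bias omitted for brevity):

    import torch
    from torch.func import jvp

    def f(x):
        rm = torch.zeros(3)
        rv = torch.ones(3)
        return torch.nn.functional.batch_norm(x, rm, rv, training=True)

    x = torch.randn(2, 3, 4)
    t = torch.randn(2, 3, 4)  # tangent pushed forward through f
    out, out_tangent = jvp(f, (x,), (t,))
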
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index c8c97f9..9d9ec2c 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -999,7 +999,6 @@
         "torch._C._scatter_out",
         "torch._C._scatter",
         "torch._C._select_conv_backend",
-        "torch._C._select_batch_norm_backend",
         "torch._C._set_autograd_fallback_mode",
         "torch._C._set_backcompat_broadcast_warn",
         "torch._C._set_backcompat_keepdim_warn",
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index 18d0ed1..dd8f860 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -753,7 +753,7 @@
     recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
 
     random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
-    compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit, aten._batch_norm_with_update, aten.batch_norm_backward]  # noqa: E501,B950
+    compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit]  # noqa: E501,B950
 
     fusible_ops = recomputable_ops | set(random_ops)
     if AOT_PARTITIONER_DEBUG:
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 49d4003..80d3d1e 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -54,10 +54,6 @@
         aten._native_batch_norm_legit,
         aten._native_batch_norm_legit_functional,
         aten._native_batch_norm_legit_no_training,
-        aten._batch_norm_with_update,
-        aten._batch_norm_with_update_functional,
-        aten._batch_norm_no_update,
-        aten.batch_norm_backward,
         aten.native_batch_norm,
         aten.native_group_norm,
         aten.native_layer_norm,
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index a0d6b6e..21af884 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -18,7 +18,6 @@
 #include <ATen/dlpack.h>
 #include <ATen/native/ConvUtils.h>
 #include <ATen/native/ForeachUtils.h>
-#include <ATen/native/Normalization.h>
 #include <c10/core/DispatchKeySet.h>
 #include <c10/util/AbortHandler.h>
 #include <c10/util/Backtrace.h>
@@ -92,10 +91,7 @@
 #include <ATen/native/transformers/sdp_utils_cpp.h>
 #include <torch/csrc/profiler/combined_traceback.h>
 #include <sstream>
-
 #ifdef USE_CUDA
-#include <ATen/cuda/CUDAConfig.h>
-#include <ATen/native/cudnn/BatchNorm.h>
 #include <ATen/native/transformers/cuda/sdp_utils.h>
 #endif
 
@@ -2099,44 +2095,6 @@
       },
       "Checks if a tensor's data pointer is COW");
 
-  py_module.def(
-      "_get_cudnn_batch_norm_reserve_space_size",
-      [](const at::Tensor& input, bool training) {
-#ifdef USE_CUDA
-        return at::native::_get_cudnn_batch_norm_reserve_space_size(
-            input, training);
-#else
-        TORCH_CHECK(false, "PyTorch was not built with cuda");
-#endif
-      },
-      py::arg("input"),
-      py::arg("training"));
-
-  py::enum_<at::native::BatchNormBackend>(py_module, "_BatchNormBackend")
-      .value("Native", at::native::BatchNormBackend::Native)
-      .value("Cudnn", at::native::BatchNormBackend::Cudnn)
-      .value("Miopen", at::native::BatchNormBackend::Miopen);
-
-  py_module.def(
-      "_select_batch_norm_backend",
-      [](const at::Tensor& input,
-         const at::Tensor& weight,
-         const at::Tensor& bias,
-         const at::Tensor& running_mean,
-         const at::Tensor& running_var,
-         bool training,
-         double eps) {
-        return at::native::_select_batch_norm_backend(
-            input, weight, bias, running_mean, running_var, training, eps);
-      },
-      py::arg("input"),
-      py::arg("weight"),
-      py::arg("bias"),
-      py::arg("running_mean"),
-      py::arg("running_var"),
-      py::arg("training"),
-      py::arg("eps"));
-
   const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
   THPDefaultCPUGenerator =
       (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
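
Before this revert, the binding removed above was reachable from Python roughly as follows (a sketch; _select_batch_norm_backend and _BatchNormBackend no longer exist once this lands, and any result other than Native requires the matching GPU build):

    import torch

    x = torch.randn(2, 3, 4)
    w = torch.randn(3)
    b = torch.randn(3)
    rm = torch.zeros(3)
    rv = torch.ones(3)

    backend = torch._C._select_batch_norm_backend(x, w, b, rm, rv, True, 1e-5)
    print(backend == torch._C._BatchNormBackend.Native)  # True for CPU inputs
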
diff --git a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
index 231aba3..943d43f 100644
--- a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
+++ b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
@@ -3312,7 +3312,6 @@
     {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
     {"aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
     {"aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
-    {"aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "native_batch_norm"},
     {"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "cross_entropy_loss"},
     {"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"},
     {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},
diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py
index 9c1da98..5151503 100644
--- a/torch/jit/_shape_functions.py
+++ b/torch/jit/_shape_functions.py
@@ -1431,11 +1431,6 @@
     native_batch_norm,
 )
 add_shape_compute_mapping(
-    "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
-    native_batch_norm,
-)
-
-add_shape_compute_mapping(
     "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
     cross_entropy_loss,
 )
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 09ff394..ca3f461 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -509,19 +509,6 @@
         else:
             yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps))
 
-def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs):
-    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
-    for sample in samples:
-        # torch.native_batch_norm does not support 0 numel tensors
-        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
-        if sample.input.numel() == 0:
-            continue
-        args = sample.args
-        momentum = sample.kwargs.get('momentum', 0.5)
-        eps = sample.kwargs.get('eps', 1e-5)
-        if any(args[i] is None for i in range(4)):
-            continue
-        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps))
 
 def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -13002,42 +12989,6 @@
                             "TestCompositeCompliance", "test_forward_ad"),
            )
            ),
-    OpInfo('_batch_norm_with_update',
-           op=torch.ops.aten._batch_norm_with_update,
-           aten_name='_batch_norm_with_update',
-           dtypes=floating_types_and(torch.float16, torch.bfloat16),
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           assert_jit_shape_analysis=True,
-           # TODO: Avoid COW materialize
-           supports_cow_input_no_materialize=False,
-           sample_inputs_func=sample_inputs__batch_norm_with_update,
-           skips=(
-               # NotImplementedError: Could not run
-               # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
-               # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
-               # Problem with _get_numerical_jacobian
-               # IndexError: tuple index out of range
-               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
-               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
-               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
-               # https://github.com/pytorch/pytorch/issues/85960
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
-               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
-                            "TestCompositeCompliance", "test_forward_ad"),
-               # _batch_norm_with_update expects contiguous inputs for cudnn and miopen
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"),
-               DecorateInfo(unittest.expectedFailure,
-                            'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"),
-               # _batch_norm_with_update does not have python bindings
-               DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
-               # aten out variants do not accept out= kwarg, only python out variants
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
-           )
-           ),
     OpInfo('nn.functional.cosine_similarity',
            aten_name="cosine_similarity",
            dtypes=floating_types_and(torch.half, torch.bfloat16),