Revert "Batch Norm Consolidation (#116092)"
This reverts commit 5680f565d5b7d4aa412a3988d3d91ca4c5679303.
Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/jeffdaily because it broke ROCm; the PR signal was clean but trunk was not, so the merge should have been blocked but wasn't ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1981373237))
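For reference, the consolidated ops removed by this revert were meant to be invoked roughly as follows. This is a sketch against a build that includes #116092 (the ops no longer exist once this revert lands), with illustrative shapes:

    import torch

    x = torch.randn(2, 3, 4, 4)
    w, b = torch.ones(3), torch.zeros(3)
    rm, rv = torch.zeros(3), torch.ones(3)

    # Training: updates the running stats in place and returns a reserve
    # tensor (non-empty only on the cudnn path).
    out, save_mean, save_var, reserve = torch.ops.aten._batch_norm_with_update(
        x, w, b, rm, rv, 0.1, 1e-5)

    # Inference: the no-update variant leaves the running stats untouched.
    out_eval, _, _, _ = torch.ops.aten._batch_norm_no_update(
        x, w, b, rm, rv, 0.1, 1e-5)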
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 64de958..2b00567 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -29,11 +29,6 @@
#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_native_batch_norm_legit_no_training.h>
#include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
-#include <ATen/ops/_batch_norm_with_update.h>
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/_batch_norm_no_update.h>
-#include <ATen/ops/_batch_norm_no_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/alias.h>
#include <ATen/ops/batch_norm.h>
#include <ATen/ops/batch_norm_native.h>
@@ -484,58 +479,10 @@
return std::make_tuple(grad_input, grad_weight, grad_bias);
}
-BatchNormBackend _select_batch_norm_backend(
- const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
- const Tensor& running_var, bool training, double eps) {
-
- auto& ctx = at::globalContext();
- bool cudnn_enabled = ctx.userEnabledCuDNN();
-
- if (
- input.is_cuda()
- && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
- && (input.scalar_type() != at::kHalf
- || weight.scalar_type() == at::kFloat)
- && weight.defined() && bias.defined()
- && ((running_mean.defined() && running_var.defined())
- || (!running_mean.defined() && !running_var.defined() && training))
- && (input.dim() >= 3)
- && ((input.sym_size(0) <= 880801 && training) // spatial, training
- ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
- && detail::getCUDAHooks().compiledWithCuDNN()
- && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
- && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
- && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
- ) {
- return BatchNormBackend::Cudnn;
- }
-
- if (
- input.is_cuda()
- && input.dim() <= MIOPEN_DIM_MAX
- && input.scalar_type() != at::kDouble
- && input.scalar_type() != at::kBFloat16
- && (weight.scalar_type() != at::kHalf)
- && weight.defined() && bias.defined()
- && ((running_mean.defined() && running_var.defined())
- || (!running_mean.defined() && !running_var.defined() && training))
- && detail::getCUDAHooks().compiledWithMIOpen()
- && cudnn_enabled
- && input.suggest_memory_format() != MemoryFormat::ChannelsLast
- && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
- ) {
- return BatchNormBackend::Miopen;
- }
-
- return BatchNormBackend::Native;
-}
-
-
// _batch_norm_impl_index(_backward) are used in the JIT to be able to keep the run-time selection
// of backends, while keeping track of which backend was used so that the
// corresponding backward implementation can be selected.
// XXX: The indices of backends need to be kept synchronized between this function and its _backward.
-// TODO: remove cudnn_enabled arg
std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -580,9 +527,24 @@
check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
}
- BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+ const bool use_cudnn = (
+ input.is_cuda()
+ && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
+ && (input.scalar_type() != at::kHalf
+ || weight.scalar_type() == at::kFloat)
+ && weight.defined() && bias.defined()
+ && ((running_mean.defined() && running_var.defined())
+ || (!running_mean.defined() && !running_var.defined() && training))
+ && (input.dim() >= 3)
+ && ((input.sym_size(0) <= 880801 && training) // spatial, training
+ ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
+ && detail::getCUDAHooks().compiledWithCuDNN()
+ && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+ && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
+ && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
+ );
- if (backend == BatchNormBackend::Cudnn) {
+ if (use_cudnn) {
auto input_c = input.contiguous(input.suggest_memory_format());
auto weight_c = weight.contiguous();
auto bias_c = bias.contiguous();
@@ -599,7 +561,19 @@
Tensor reserve = at::empty({0}, input.options().dtype(kByte));
- if (backend == BatchNormBackend::Miopen) {
+ bool use_miopen = (input.is_cuda()
+ && input.dim() <= MIOPEN_DIM_MAX
+ && input.scalar_type() != at::kDouble
+ && input.scalar_type() != at::kBFloat16
+ && (weight.scalar_type() != at::kHalf)
+ && weight.defined() && bias.defined()
+ && ((running_mean.defined() && running_var.defined())
+ || (!running_mean.defined() && !running_var.defined() && training))
+ && detail::getCUDAHooks().compiledWithMIOpen()
+ && cudnn_enabled
+ );
+
+ if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
return std::tuple_cat(
at::miopen_batch_norm(
input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -663,7 +637,6 @@
TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
}
-// TODO: remove cudnn_enabled arg
Tensor batch_norm(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
@@ -674,30 +647,6 @@
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
training, momentum, eps, cudnn_enabled));
- // TODO: switch to the new stack after the 2-week FC window
- // if (training) {
- // BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
- // if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
- // auto input_c = input;
- // if (backend == BatchNormBackend::Cudnn) {
- // input_c = input.contiguous(input.suggest_memory_format());
- // } else {
- // input_c = input.contiguous();
- // }
- // auto weight_c = weight.contiguous();
- // auto bias_c = bias.contiguous();
- // auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
- // auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
- // return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
- // const_cast<Tensor&>(rvar_c), momentum, eps));
- // } else {
- // return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
- // const_cast<Tensor&>(running_var), momentum, eps));
- // }
- // } else {
- // return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
- // momentum, eps));
- // }
}
Tensor instance_norm(
@@ -849,38 +798,6 @@
return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
}
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
- const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
- Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
- Tensor output, save_mean, save_var;
- std::tie(output, save_mean, save_var) =
- batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
- Tensor reserve = at::empty({0}, input.options().dtype(kByte));
- return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
- const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
- Tensor& running_mean, Tensor& running_var, double momentum, double eps,
- Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
- std::tie(out, save_mean, save_var) =
- batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
- return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
-
-
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
- const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
- const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
- double momentum, double eps) {
- const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
- const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
- Tensor output, save_mean, save_var;
- std::tie(output, save_mean, save_var) =
- batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
- Tensor reserve = at::empty({0}, input.options().dtype(kByte));
- return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
@@ -909,13 +826,6 @@
return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
}
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
- const Tensor& grad_output, const Tensor& input, const Tensor& weight,
- const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
- const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
- bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
- return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
-}
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
bool train, double eps, std::array<bool,3> grad_input_mask) {
diff --git a/aten/src/ATen/native/Normalization.h b/aten/src/ATen/native/Normalization.h
index 1ba99e7..6cd4dcd 100644
--- a/aten/src/ATen/native/Normalization.h
+++ b/aten/src/ATen/native/Normalization.h
@@ -8,12 +8,4 @@
using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
-enum class BatchNormBackend {
- Native,
- Cudnn,
- Miopen,
-};
-
-TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
-
} // namespace at::native
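The deleted helper was also bound into Python (see the torch/_C stubs and torch/csrc/Module.cpp hunks below). On pre-revert builds the selection could be probed like so; a CPU input is expected to select the Native backend (sketch):

    import torch

    x = torch.randn(4, 8, 16, 16)          # CPU input
    w, b = torch.ones(8), torch.zeros(8)
    rm, rv = torch.zeros(8), torch.ones(8)

    backend = torch._C._select_batch_norm_backend(x, w, b, rm, rv, True, 1e-5)
    assert backend == torch._C._BatchNormBackend.Native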
diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu
index 1b4c159..655d32b 100644
--- a/aten/src/ATen/native/cuda/Normalization.cu
+++ b/aten/src/ATen/native/cuda/Normalization.cu
@@ -1,7 +1,5 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/cuda/detail/IndexUtils.cuh>
-#include <ATen/detail/CUDAHooksInterface.h>
-#include <ATen/native/Normalization.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/Resize.h>
@@ -14,8 +12,6 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/batch_norm_backward_elemt_native.h>
#include <ATen/ops/batch_norm_backward_reduce_native.h>
#include <ATen/ops/batch_norm_elemt_native.h>
@@ -23,12 +19,8 @@
#include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
#include <ATen/ops/batch_norm_stats_native.h>
#include <ATen/ops/batch_norm_update_stats_native.h>
-#include <ATen/ops/cudnn_batch_norm.h>
-#include <ATen/ops/cudnn_batch_norm_backward.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/from_blob.h>
-#include <ATen/ops/miopen_batch_norm.h>
-#include <ATen/ops/miopen_batch_norm_backward.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
#include <ATen/ops/scalar_tensor.h>
@@ -481,54 +473,6 @@
return std::make_tuple(output, save_mean, save_invstd);
}
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
- const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
- Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
- // See [Note: hacky wrapper removal for optional tensor]
- c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
- const Tensor& weight = *weight_maybe_owned;
- const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
- Tensor output, save_mean, save_var, reserve;
-
- BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
- if (backend == BatchNormBackend::Cudnn) {
- std::tie(output, save_mean, save_var, reserve) =
- at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
- } else if (backend == BatchNormBackend::Miopen) {
- reserve = at::empty({0}, input.options().dtype(kByte));
- std::tie(output, save_mean, save_var) =
- at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
- } else {
- reserve = at::empty({0}, input.options().dtype(kByte));
- std::tie(output, save_mean, save_var) =
- batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
- }
- return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
- const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
- Tensor& running_mean, Tensor& running_var, double momentum, double eps,
- Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
- // See [Note: hacky wrapper removal for optional tensor]
- c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
- const Tensor& weight = *weight_maybe_owned;
- const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-
- BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
- if (backend == BatchNormBackend::Cudnn) {
- std::tie(out, save_mean, save_var, reserve) =
- at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
- } else if (backend == BatchNormBackend::Miopen) {
- std::tie(out, save_mean, save_var) =
- at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
- } else {
- std::tie(out, save_mean, save_var) =
- batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
- }
- return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
-
std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
}
@@ -545,28 +489,6 @@
return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
}
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
- const Tensor& grad_output, const Tensor& input, const Tensor& weight,
- const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
- const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
- bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
- const Tensor& dummy_bias = at::empty(1);
- const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
- const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
- const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
- const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});
-
- BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);
-
- if (backend == BatchNormBackend::Cudnn) {
- return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
- } else if (backend == BatchNormBackend::Miopen) {
- return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
- } else {
- return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
- }
-}
-
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp
index 2553139..2efe7a7 100644
--- a/aten/src/ATen/native/cudnn/BatchNorm.cpp
+++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp
@@ -2,7 +2,6 @@
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAConfig.h>
-#include <ATen/native/cudnn/BatchNorm.h>
#if !AT_CUDNN_ENABLED()
@@ -36,24 +35,18 @@
AT_ERROR("cudnn_batch_norm_backward: ATen not compiled with cuDNN support");
}
-size_t _get_cudnn_batch_norm_reserve_space_size(
- const Tensor& input_t,
- bool training) {
- AT_ERROR(
- "_get_cudnn_batch_norm_reserve_space_size: ATen not compiled with cuDNN support");
-}
-
} // namespace native
} // namespace at
#else // AT_CUDNN_ENABLED
-#include <ATen/TensorUtils.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
+#include <ATen/TensorUtils.h>
+
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@@ -105,21 +98,6 @@
} // namespace
-size_t _get_cudnn_batch_norm_reserve_space_size(
- const Tensor& input_t,
- bool training) {
- size_t reserve_size;
- TensorArg input{input_t, "input", 1};
- TensorDescriptor idesc{*input, 4};
- auto handle = getCudnnHandle();
- cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
- training, input->suggest_memory_format(), input->dim());
- auto op = CUDNN_BATCHNORM_OPS_BN;
- AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
- handle, mode, op, nullptr, idesc.desc(), &reserve_size));
- return reserve_size;
-}
-
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
const Tensor& input_t,
const Tensor& weight_t,
@@ -208,8 +186,9 @@
Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));
// get the reserved size and allocate as tensor
- size_t reserve_size =
- _get_cudnn_batch_norm_reserve_space_size(input_t, true /* training */);
+ size_t reserve_size;
+ AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
+ handle, mode, op, nullptr, idesc.desc(), &reserve_size));
reserve = at::empty(reserve_size, input->options().dtype(kByte));
AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
diff --git a/aten/src/ATen/native/cudnn/BatchNorm.h b/aten/src/ATen/native/cudnn/BatchNorm.h
deleted file mode 100644
index 3da76c0..0000000
--- a/aten/src/ATen/native/cudnn/BatchNorm.h
+++ /dev/null
@@ -1,6 +0,0 @@
-namespace at::native {
-
-TORCH_API size_t
-_get_cudnn_batch_norm_reserve_space_size(const Tensor& input_t, bool training);
-
-} // namespace at::native
diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp
index 0aced61..108ce35 100644
--- a/aten/src/ATen/native/mkldnn/Normalization.cpp
+++ b/aten/src/ATen/native/mkldnn/Normalization.cpp
@@ -6,8 +6,6 @@
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty_native.h>
@@ -61,20 +59,6 @@
TORCH_CHECK(false, "_mkldnn_batch_norm_legit_no_stats: ATen not compiled with MKLDNN support");
}
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
- const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
- Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
- TORCH_CHECK(false, "_batch_norm_with_update_mkldnn: ATen not compiled with MKLDNN support");
-}
-
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
- const Tensor& grad_output, const Tensor& input, const Tensor& weight,
- const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
- const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
- bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
- TORCH_CHECK(false, "_new_batch_norm_backward_mkldnn: ATen not compiled with MKLDNN support");
-}
-
} // namespace native
} // namespace at
@@ -208,17 +192,6 @@
}
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
- const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
- Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
- Tensor output, save_mean, save_var;
- std::tie(output, save_mean, save_var) =
- mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, /*train*/true, momentum, eps);
- Tensor reserve = empty_mkldnn({0}, input.scalar_type());
- return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-
std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
bool train,
@@ -237,15 +210,6 @@
}
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
- const Tensor& grad_output, const Tensor& input, const Tensor& weight,
- const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
- const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
- bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
- return mkldnn_batch_norm_backward(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
-}
-
-
std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(const Tensor& grad_output,
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
bool train,
diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm
index bdca3b0..eb754ae 100644
--- a/aten/src/ATen/native/mps/operations/Normalization.mm
+++ b/aten/src/ATen/native/mps/operations/Normalization.mm
@@ -10,9 +10,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
-#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/_native_batch_norm_legit_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
@@ -408,36 +406,6 @@
return std::make_tuple(output, save_mean, save_var);
}
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mps(const Tensor& input,
- const c10::optional<Tensor>& weight_opt,
- const c10::optional<Tensor>& bias_opt,
- Tensor& running_mean,
- Tensor& running_var,
- double momentum,
- double eps) {
- Tensor output, save_mean, save_var;
- std::tie(output, save_mean, save_var) =
- batch_norm_mps(input, weight_opt, bias_opt, running_mean, running_var, /*train*/ true, momentum, eps);
- Tensor reserve = at::empty({0}, input.options().dtype(kByte));
- return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_mps_out(const Tensor& input,
- const c10::optional<Tensor>& weight_opt,
- const c10::optional<Tensor>& bias_opt,
- Tensor& running_mean,
- Tensor& running_var,
- double momentum,
- double eps,
- Tensor& out,
- Tensor& save_mean,
- Tensor& save_var,
- Tensor& reserve) {
- std::tie(out, save_mean, save_var) = batch_norm_mps_out(
- input, weight_opt, bias_opt, running_mean, running_var, /*update*/ true, momentum, eps, out, save_mean, save_var);
- return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
-
std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_mps(const Tensor& self,
const c10::optional<Tensor>& weight_opt,
const c10::optional<Tensor>& bias_opt,
@@ -503,29 +471,6 @@
}
// Batch norm backward
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mps(const Tensor& grad_output,
- const Tensor& input,
- const Tensor& weight,
- const c10::optional<Tensor>& running_mean_opt,
- const c10::optional<Tensor>& running_var_opt,
- const c10::optional<Tensor>& save_mean_opt,
- const c10::optional<Tensor>& save_var_opt,
- bool update,
- double eps,
- std::array<bool, 3> grad_input_mask,
- const Tensor& reserve) {
- return batch_norm_backward_mps(grad_output,
- input,
- weight,
- running_mean_opt,
- running_var_opt,
- save_mean_opt,
- save_var_opt,
- update,
- eps,
- grad_input_mask);
-}
-
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps(const Tensor& grad_out,
const Tensor& input,
const c10::optional<Tensor>& weight_opt,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 809e9e9..2becda9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6486,32 +6486,6 @@
SparseCPU, SparseCUDA: norm_sparse
autogen: native_norm.ScalarOpt_dim_dtype_out
-- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
- dispatch:
- CPU: _batch_norm_with_update_cpu
- CUDA: _batch_norm_with_update_cuda
- MPS: _batch_norm_with_update_mps
- MkldnnCPU: _batch_norm_with_update_mkldnn
- autogen: _batch_norm_with_update_functional
-
-- func: _batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))
- dispatch:
- CPU: _batch_norm_with_update_cpu_out
- CUDA: _batch_norm_with_update_cuda_out
- MPS: _batch_norm_with_update_mps_out
-
-- func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
- dispatch:
- CompositeExplicitAutograd: _batch_norm_no_update
- autogen: _batch_norm_no_update.out
-
-- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
- dispatch:
- CPU: _new_batch_norm_backward_cpu
- CUDA: _new_batch_norm_backward_cuda
- MPS: _new_batch_norm_backward_mps
- MkldnnCPU: _new_batch_norm_backward_mkldnn
-
# TODO: reduce signatures down to one when optional args is available
- func: _sparse_sum(Tensor self) -> Tensor
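Per the schemas removed above, batch_norm_backward threads the reserve tensor produced by the forward op back into the backward pass; a sketch on a pre-revert build (output_mask requests all three gradients):

    import torch

    x = torch.randn(2, 3, 4, 4)
    w, b = torch.ones(3), torch.zeros(3)
    rm, rv = torch.zeros(3), torch.ones(3)

    out, save_mean, save_var, reserve = torch.ops.aten._batch_norm_with_update(
        x, w, b, rm, rv, 0.1, 1e-5)
    grad_in, grad_w, grad_b = torch.ops.aten.batch_norm_backward(
        torch.ones_like(out), x, w, rm, rv, save_mean, save_var,
        True, 1e-5, [True, True, True], reserve)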
diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
index 66b78f0..3e88ba4 100644
--- a/test/distributed/_tensor/test_dtensor_ops.py
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -113,7 +113,6 @@
xfail("as_strided", "partial_views"),
xfail("as_strided_scatter"),
xfail("bernoulli"),
- xfail("_batch_norm_with_update"),
xfail("block_diag"),
xfail("broadcast_shapes"),
xfail("cauchy"),
diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_batch_norm_with_update_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_batch_norm_with_update_cpu_float32
+++ /dev/null
diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_batch_norm_with_update_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_batch_norm_with_update_cpu_float32
+++ /dev/null
diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_batch_norm_with_update_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_batch_norm_with_update_cpu_float32
+++ /dev/null
diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_batch_norm_with_update_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_batch_norm_with_update_cpu_float32
+++ /dev/null
diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_batch_norm_with_update_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_batch_norm_with_update_cpu_float32
+++ /dev/null
diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index f23021e..c9f379e 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -6,9 +6,6 @@
aten::_adaptive_avg_pool2d.out
aten::_addmm_activation
aten::_addmm_activation.out
-aten::_batch_norm_no_update
-aten::_batch_norm_with_update
-aten::_batch_norm_with_update_functional
aten::_euclidean_dist.out
aten::_fused_dropout
aten::_fused_dropout.out
@@ -79,7 +76,6 @@
aten::atanh.out
aten::atanh_
aten::baddbmm_
-aten::batch_norm_backward
aten::bitwise_and.Scalar
aten::bitwise_and.Scalar_Tensor
aten::bitwise_and.Scalar_Tensor_out
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 917bcad..3b540d3 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -30,8 +30,6 @@
aten::_amp_update_scale_
aten::_assert_async
aten::_assert_async.msg
-aten::_batch_norm_no_update.out
-aten::_batch_norm_with_update.out
aten::_cdist_backward
aten::_cdist_backward.out
aten::_cdist_forward
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 138fafad..15d5ff0 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -368,10 +368,6 @@
# 'tensor_split' not composite compliant, see vjp_fail
}
-skip_noncontig = {
- '_batch_norm_with_update',
-}
-
@unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant")
@unMarkDynamoStrictTest
@@ -437,10 +433,9 @@
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
- if op.name not in skip_noncontig:
- noncontig_sample = sample.noncontiguous()
- noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
- noncontig_kwargs = noncontig_sample.kwargs
+ noncontig_sample = sample.noncontiguous()
+ noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
+ noncontig_kwargs = noncontig_sample.kwargs
diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg))
assert len(diff_argnums) > 0
@@ -463,12 +458,11 @@
return result
result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
+ result_noncontig = grad(wrapped_fn, diff_argnums)(*noncontig_args, **noncontig_kwargs)
expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args)
- self.assertEqual(result, expected)
- if op.name not in skip_noncontig:
- result_noncontig = grad(wrapped_fn, diff_argnums)(*noncontig_args, **noncontig_kwargs)
- self.assertEqual(result_noncontig, expected)
+ self.assertEqual(result, expected)
+ self.assertEqual(result_noncontig, expected)
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@@ -482,8 +476,7 @@
skip('nn.functional.max_unpool2d'), # fails everywhere except on windows
skip('nn.functional.max_unpool3d'), # fails everywhere except on mac
xfail("native_batch_norm"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
- xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
- xfail("_batch_norm_with_update"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
+ xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
xfail('nn.functional.scaled_dot_product_attention'),
xfail('torch.ops.aten._flash_attention_forward'),
@@ -552,17 +545,15 @@
self.jvp_opinfo_test(outplace_variant, sample,
sample.output_process_fn_grad,
clone_inputs=False,
- fixme_ref_jvp_local=fixme_ref_jvp_local,
- test_noncontig=op.name not in skip_noncontig)
+ fixme_ref_jvp_local=fixme_ref_jvp_local)
if is_valid_inplace_sample_input(sample, op, inplace_variant):
self.jvp_opinfo_test(inplace_variant, sample,
sample.output_process_fn_grad,
clone_inputs=True,
- fixme_ref_jvp_local=fixme_ref_jvp_local,
- test_noncontig=op.name not in skip_noncontig)
+ fixme_ref_jvp_local=fixme_ref_jvp_local)
def jvp_opinfo_test(self, fn, sample, output_process_fn,
- clone_inputs, fixme_ref_jvp_local, test_noncontig):
+ clone_inputs, fixme_ref_jvp_local):
# NB: we used requires_grad=True to determine where the primals are,
# but don't need that information otherwise
args = (sample.input,) + sample.args
@@ -572,6 +563,15 @@
orig_primals = tree_map(lambda x: x.detach(), primals)
orig_tangents = tree_map(lambda x: torch.randn_like(x), primals)
+ noncontig_sample = sample.noncontiguous()
+ noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
+ noncontig_kwargs = sample.kwargs
+ noncontig_fn, primals = normalize_op_input_output2(
+ fn, noncontig_args, noncontig_kwargs,
+ output_process_fn, requires_grad=True)
+ noncontig_primals = tree_map(lambda x: x.detach(), primals)
+ noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents)
+
def maybe_clone_inputs():
if clone_inputs:
primals = tree_map(torch.clone, orig_primals)
@@ -586,24 +586,15 @@
primals, tangents = maybe_clone_inputs()
primal_outs, tangent_outs = jvp(contig_fn, primals, tangents)
+ noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn,
+ noncontig_primals,
+ noncontig_tangents)
+
self.assertEqual(primal_outs, expected_primal_outs)
self.assertEqual(tangent_outs, expected_tangent_outs)
- if test_noncontig:
- noncontig_sample = sample.noncontiguous()
- noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
- noncontig_kwargs = sample.kwargs
- noncontig_fn, primals = normalize_op_input_output2(
- fn, noncontig_args, noncontig_kwargs,
- output_process_fn, requires_grad=True)
- noncontig_primals = tree_map(lambda x: x.detach(), primals)
- noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents)
- noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn,
- noncontig_primals,
- noncontig_tangents)
-
- self.assertEqual(noncontig_primal_outs, expected_primal_outs)
- self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)
+ self.assertEqual(noncontig_primal_outs, expected_primal_outs)
+ self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@@ -664,22 +655,22 @@
result = fn(*primals)
cotangents = tree_map(lambda x: torch.randn_like(x), result)
+ noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous())
+ noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents)
+
out, vjp_fn = vjp(fn, *primals)
self.assertEqual(out, result)
result_vjps = vjp_fn(cotangents)
+ out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
+ self.assertEqual(out_noncontig, result)
+ noncontig_result_vjps = vjp_fn(noncontig_cotangents)
+
_, vjp_fn = ref_vjp(fn, *primals)
expected_vjps = vjp_fn(cotangents)
self.assertEqual(result_vjps, expected_vjps)
-
- if op.name not in skip_noncontig:
- noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous())
- noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents)
- out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
- self.assertEqual(out_noncontig, result)
- noncontig_result_vjps = vjp_fn(noncontig_cotangents)
- self.assertEqual(noncontig_result_vjps, expected_vjps)
+ self.assertEqual(noncontig_result_vjps, expected_vjps)
_test(op)
for a_op in op.aliases:
@@ -839,8 +830,6 @@
xfail("to_sparse"),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
- # TODO: implement batching rule
- xfail("_batch_norm_with_update"),
}))
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -932,8 +921,6 @@
skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule
skip("native_batch_norm"),
skip("_native_batch_norm_legit"),
- # TODO: implement batching rule
- skip("_batch_norm_with_update"),
xfail('__getitem__', ''), # dynamic error
xfail('nanquantile', device_type='cpu'), # checks q via a .item() call
xfail('nn.functional.gaussian_nll_loss'), # checks var for if any value < 0
@@ -1058,8 +1045,6 @@
xfail('nn.functional.batch_norm', 'without_cudnn'),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
- # TODO: implement batching rule
- xfail("_batch_norm_with_update"),
# https://github.com/pytorch/pytorch/issues/96560
# ROCm: NotImplementedError
@@ -1245,8 +1230,6 @@
xfail('sparse.mm', 'reduce'),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
- # TODO: implement batching rule
- xfail("_batch_norm_with_update"),
xfail("native_dropout_backward"),
xfail("index_fill"), # aten::_unique hit the vmap fallback which is currently disabled
}))
@@ -1323,8 +1306,6 @@
xfail('sparse.mm', 'reduce'),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
- # TODO: implement batching rule
- xfail("_batch_norm_with_update"),
xfail('as_strided', 'partial_views'),
}))
def test_vjpvmap(self, device, dtype, op):
@@ -1583,8 +1564,6 @@
# place, were not batched.
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
- # TODO: implement batching rule
- xfail("_batch_norm_with_update"),
xfail('native_dropout_backward'),
}))
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 1b5723f..df46c95 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3625,8 +3625,6 @@
# which will be updated in place, were not batched.
xfail('native_batch_norm'),
xfail('_native_batch_norm_legit'),
- # TODO: implement batching rule
- xfail('_batch_norm_with_update'),
xfail('tril'), # Exception not raised on error input
xfail('triu'), # Exception not raised on error input
xfail('as_strided', 'partial_views'),
@@ -3666,8 +3664,6 @@
# which will be updated in place, were not batched.
xfail('native_batch_norm'),
xfail('_native_batch_norm_legit'),
- # TODO: implement batching rule
- xfail('_batch_norm_with_update'),
xfail('histogram'),
xfail('scatter_reduce', 'sum'),
xfail('scatter_reduce', 'mean'),
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index c51ec4b..8a35bc6 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -192,7 +192,6 @@
"nn.functional.cosine_embedding_loss": {b8},
"native_batch_norm": {f16, f32, f64},
"_native_batch_norm_legit": {f16, f32, f64},
- "_batch_norm_with_update": {f16, f32, f64},
}
if not SM80OrLater:
diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py
index b6a4176..4d38e70 100644
--- a/test/onnx/test_fx_op_consistency.py
+++ b/test/onnx/test_fx_op_consistency.py
@@ -157,11 +157,6 @@
dtypes=(torch.float16,),
reason="fixme: Assertion error: result mismatch and type error",
),
- skip(
- "_batch_norm_with_update",
- dtypes=(torch.float16,),
- reason="fixme: Assertion error: result mismatch and type error",
- ),
xfail(
"_softmax_backward_data",
reason=onnx_test_common.reason_dynamo_does_not_support("assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)")
@@ -1357,20 +1352,6 @@
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
reason="https://github.com/pytorch/pytorch/issues/115106",
),
- skip(
- "_batch_norm_with_update",
- model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
- reason="https://github.com/pytorch/pytorch/issues/115106",
- ),
- # TODO: This test currently fails only for certain inputs, e.g. shape([3, 1]).
- # Numerically the ONNX program is correct, but the output shapes for `save_mean`
- # and `save_var` were tensor(-2.1268) instead of the correct tensor([-2.1268])
- # for example.
- skip(
- "_batch_norm_with_update",
- model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
- reason="not supported yet",
- ),
xfail(
"addmm", # xfail can't only use dtypes to catch all cases
matcher=lambda sample: sample.input.dtype
diff --git a/test/test_meta.py b/test/test_meta.py
index 6a092b3..65e17ce 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -708,11 +708,8 @@
meta_function_device_skips = defaultdict(dict)
meta_function_device_expected_failures['cpu'] = {
- # TODO: The decomps for these batch norm ops return different dtypes depending
- # on the device. We should make this work better with meta tensors.
torch.native_batch_norm: {bf16, f16},
torch._native_batch_norm_legit: {bf16, f16},
- torch.ops.aten._batch_norm_with_update: {bf16, f16},
torch.native_layer_norm: {bf16, f16},
}
@@ -727,11 +724,8 @@
}
meta_function_device_skips['cpu'] = {
- # TODO: The decomps for these batch norm ops return different dtypes depending
- # on the device. We should make this work better with meta tensors.
torch.native_batch_norm: {f32, f64},
torch._native_batch_norm_legit: {f32, f64},
- torch.ops.aten._batch_norm_with_update: {f32, f64},
}
meta_function_device_skips['cuda'] = {
@@ -856,13 +850,9 @@
meta_dispatch_device_skips = defaultdict(dict)
meta_dispatch_device_expected_failures['cpu'] = {
- # TODO: The decomps for these batch norm ops return different dtypes depending
- # on the device. We should make this work better with meta tensors.
aten.native_batch_norm.default: {bf16, f16},
aten._native_batch_norm_legit.default: {bf16, f16},
aten._native_batch_norm_legit.no_stats: {bf16, f16},
- aten._batch_norm_with_update.default: {bf16, f16},
-
aten.native_layer_norm.default: {bf16, f16},
aten.histc.default: {f16},
aten.histc.out: {f16},
@@ -887,13 +877,9 @@
meta_dispatch_device_skips['cpu'] = {
aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},
-
- # TODO: The decomps for these batch norm ops return different dtypes depending
- # on the device. We should make this work better with meta tensors.
aten.native_batch_norm.default: {f32, f64},
aten._native_batch_norm_legit.default: {f32, f64},
aten._native_batch_norm_legit.no_stats: {f32, f64},
- aten._batch_norm_with_update.default: {f32, f64},
# If the computation dtype is different from the input
# dtype this will fail. CPU execution may also have a
diff --git a/test/test_mps.py b/test/test_mps.py
index af26cbb..b5f54a6 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -11384,7 +11384,6 @@
'nn.functional.gelu',
'nn.functional.glu',
'_native_batch_norm_legit',
- '_batch_norm_with_update',
'native_batch_norm',
'softmax',
'_softmax_backward_data',
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 08947ba..3338344 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1927,15 +1927,6 @@
}
out_symbolic_tensor_failures = {
- # Cast error details: Unable to cast (...) to Tensor
- #
- # This happens because the test is set up to call the out variant using the `out` kwarg:
- # torch._some_op(arg1, arg2, out=(out1, out2, out3))
- #
- # However, this only works on torch ops, not aten ops. For `_batch_norm_with_update`,
- # this fails because the op has no python bindings, so it doesn't support the `out` kwarg
- # way of calling its out variant.
- xfail('_batch_norm_with_update', ''),
xfail('_native_batch_norm_legit', ''),
xfail('angle', ''),
xfail('argmax', ''),
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index f85af41..e692ae9 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1250,20 +1250,6 @@
self: grad.neg()
result: auto_element_wise
-- name: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
- input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/true, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
- result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps)
-
-- name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
- input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
- result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps)
-
-- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
- input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask)
- save_mean: not_implemented("batch_norm_backward save_mean")
- save_var: not_implemented("batch_norm_backward save_var")
- reserve: not_implemented("batch_norm_backward reserve")
-
- name: nextafter(Tensor self, Tensor other) -> Tensor
self: not_implemented("nextafter")
other: not_implemented("nextafter")
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 29ccf12..438754d 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -158,12 +158,9 @@
"fill.Tensor", # only used by the functionalization pass
"fill.Scalar", # only used by the functionalization pass
"lift.*",
- "normal_functional", # only used by the functionalization pass
+ "normal_functional", # only used by the functionalization pas
"nbytes",
"itemsize",
- "_batch_norm_with_update",
- "_batch_norm_with_update_out",
- "_batch_norm_no_update",
]
SKIP_PYTHON_BINDINGS = [
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 0027881..516072b 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1132,7 +1132,6 @@
def _stash_obj_in_tls(key: str, arg: Any) -> None: ...
def _get_obj_in_tls(key: str) -> Any: ...
def _is_key_in_tls(key: str) -> _bool: ...
-def _select_batch_norm_backend(*args, **kwargs) -> BatchNormBackend: ...
def _select_conv_backend(*args, **kwargs) -> ConvBackend: ...
def _conv_determine_backend_memory_format(
input: Tensor,
@@ -1198,8 +1197,6 @@
Cusolver: _LinalgBackend
Magma: _LinalgBackend
-class BatchNormBackend(Enum): ...
-
class ConvBackend(Enum): ...
class Tag(Enum):
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 5c66749..bcd6652 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1843,114 +1843,6 @@
return output, save_mean, save_rstd, new_running_mean, new_running_var
-def _get_batch_norm_reserve_tensor(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- running_mean: Tensor,
- running_var: Tensor,
- eps: float,
- training: bool,
-) -> Tensor:
- """
- Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the
- backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`,
- which support a variety of backends including cudnn. We create this tensor here to get
- the correct shape in the traced graph if we detect that will call the cudnn kernel,
- and rely on DCE to avoid materializing this tensor.
- """
- backend = torch._C._select_batch_norm_backend( # type: ignore[attr-defined]
- input, weight, bias, running_mean, running_var, True, eps
- )
- reserve_size = 0
- if backend == torch._C._BatchNormBackend.Cudnn: # type: ignore[attr-defined]
- reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size(input, training) # type: ignore[attr-defined]
- return torch.empty(
- reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device
- )
-
-
-@register_decomposition(aten._batch_norm_with_update.default)
-def _batch_norm_with_update(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- running_mean: Tensor,
- running_var: Tensor,
- momentum: float,
- eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
- output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
- input,
- weight,
- bias,
- running_mean,
- running_var,
- True, # training
- momentum,
- eps,
- False, # functional
- )
- reserve = _get_batch_norm_reserve_tensor(
- input, weight, bias, running_mean, running_var, eps, training=True
- )
- return output, save_mean, save_rstd, reserve
-
-
-@register_decomposition(aten._batch_norm_with_update_functional.default)
-def _batch_norm_with_update_functional(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- running_mean: Tensor,
- running_var: Tensor,
- momentum: float,
- eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
- (
- output,
- save_mean,
- save_rstd,
- new_rm,
- new_rv,
- ) = native_batch_norm_helper(
- input, weight, bias, running_mean, running_var, True, momentum, eps, True
- )
- reserve = _get_batch_norm_reserve_tensor(
- input, weight, bias, running_mean, running_var, eps, training=True
- )
- assert new_rm is not None, "new_running_mean should not be None"
- assert new_rv is not None, "new_running_var should not be None"
- return (output, save_mean, save_rstd, reserve, new_rm, new_rv)
-
-
-@register_decomposition(aten._batch_norm_no_update.default)
-def _batch_norm_no_update(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- running_mean: Tensor,
- running_var: Tensor,
- momentum: float,
- eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
- output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
- input,
- weight,
- bias,
- running_mean,
- running_var,
- False, # training
- momentum,
- eps,
- False, # functional
- )
- reserve = _get_batch_norm_reserve_tensor(
- input, weight, bias, running_mean, running_var, eps, training=False
- )
- return output, save_mean, save_rstd, reserve
-
-
@register_decomposition(aten._fused_dropout)
@out_wrapper("out0", "out1")
@pw_cast_for_opmath
@@ -2055,34 +1947,6 @@
return x
-@register_decomposition(aten.batch_norm_backward.default)
-def batch_norm_backward(
- grad_out: Tensor,
- input: Tensor,
- weight: Optional[Tensor],
- running_mean: Optional[Tensor],
- running_var: Optional[Tensor],
- save_mean: Optional[Tensor],
- save_invstd: Optional[Tensor],
- train: bool,
- eps: float,
- output_mask: List[bool],
- reserve: Tensor,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
- return native_batch_norm_backward(
- grad_out,
- input,
- weight,
- running_mean,
- running_var,
- save_mean,
- save_invstd,
- train,
- eps,
- output_mask,
- )
-
-
@register_decomposition(aten.native_batch_norm_backward.default)
def native_batch_norm_backward(
grad_out: Tensor,
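In essence, the removed decompositions sized the reserve tensor via the _get_cudnn_batch_norm_reserve_space_size binding (deleted further below in torch/csrc/Module.cpp); a sketch for a pre-revert CUDA build:

    import torch

    x = torch.randn(4, 8, 16, 16, device="cuda")
    size = torch._C._get_cudnn_batch_norm_reserve_space_size(x, True)  # training=True
    reserve = torch.empty(size, dtype=torch.uint8, device=x.device)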
diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py
index 81946c3..19dfaed 100644
--- a/torch/_decomp/decompositions_for_jvp.py
+++ b/torch/_decomp/decompositions_for_jvp.py
@@ -291,34 +291,6 @@
return (grad_input, grad_weight, grad_bias)
-@register_decomposition_for_jvp(aten.batch_norm_backward)
-def batch_norm_backward(
- grad_out: Tensor,
- input: Tensor,
- weight: Tensor,
- running_mean: Optional[Tensor],
- running_var: Optional[Tensor],
- save_mean: Optional[Tensor],
- save_var: Optional[Tensor],
- update: bool,
- eps: float,
- output_mask: List[bool],
- reserve: Tensor,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
- return native_batch_norm_backward(
- grad_out,
- input,
- weight,
- running_mean,
- running_var,
- save_mean,
- save_var,
- update,
- eps,
- output_mask,
- )
-
-
_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
@@ -328,4 +300,3 @@
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
-_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default)
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index c8c97f9..9d9ec2c 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -999,7 +999,6 @@
"torch._C._scatter_out",
"torch._C._scatter",
"torch._C._select_conv_backend",
- "torch._C._select_batch_norm_backend",
"torch._C._set_autograd_fallback_mode",
"torch._C._set_backcompat_broadcast_warn",
"torch._C._set_backcompat_keepdim_warn",
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index 18d0ed1..dd8f860 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -753,7 +753,7 @@
recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
- compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit, aten._batch_norm_with_update, aten.batch_norm_backward] # noqa: E501,B950
+ compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit] # noqa: E501,B950
fusible_ops = recomputable_ops | set(random_ops)
if AOT_PARTITIONER_DEBUG:
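For context on the one-line list change above: `compute_intensive_ops` feeds the min-cut partitioner's heuristic for what is too expensive to recompute in the backward pass, so dropping `aten._batch_norm_with_update` and `aten.batch_norm_backward` retires their recompute ban along with the ops themselves. A deliberately simplified sketch of the membership test; the real logic in torch/_functorch/partitioners.py is more involved:

```python
# Simplified gist: random ops must not be re-run (results would change),
# and expensive ops are cheaper to stash than to recompute, so both are
# saved rather than recomputed.
def should_save_instead_of_recompute(op, random_ops, compute_intensive_ops):
    return op in set(random_ops) | set(compute_intensive_ops)
```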
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 49d4003..80d3d1e 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -54,10 +54,6 @@
aten._native_batch_norm_legit,
aten._native_batch_norm_legit_functional,
aten._native_batch_norm_legit_no_training,
- aten._batch_norm_with_update,
- aten._batch_norm_with_update_functional,
- aten._batch_norm_no_update,
- aten.batch_norm_backward,
aten.native_batch_norm,
aten.native_group_norm,
aten.native_layer_norm,
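After this hunk, inductor's table is built from the pre-consolidation batch-norm ops only. Which decompositions are actually registered can be inspected with `torch._decomp.get_decompositions`, the same helper decomposition.py builds its table from:

```python
# Inspect which batch-norm decompositions remain registered after the
# revert; get_decompositions resolves each overload packet to the
# decompositions known for it.
import torch
from torch._decomp import get_decompositions

aten = torch.ops.aten
table = get_decompositions([
    aten.native_batch_norm,
    aten._native_batch_norm_legit,
    aten._native_batch_norm_legit_no_training,
])
for op in table:
    print(op)
```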
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index a0d6b6e..21af884 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -18,7 +18,6 @@
#include <ATen/dlpack.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/ForeachUtils.h>
-#include <ATen/native/Normalization.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/AbortHandler.h>
#include <c10/util/Backtrace.h>
@@ -92,10 +91,7 @@
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <torch/csrc/profiler/combined_traceback.h>
#include <sstream>
-
#ifdef USE_CUDA
-#include <ATen/cuda/CUDAConfig.h>
-#include <ATen/native/cudnn/BatchNorm.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#endif
@@ -2099,44 +2095,6 @@
},
"Checks if a tensor's data pointer is COW");
- py_module.def(
- "_get_cudnn_batch_norm_reserve_space_size",
- [](const at::Tensor& input, bool training) {
-#ifdef USE_CUDA
- return at::native::_get_cudnn_batch_norm_reserve_space_size(
- input, training);
-#else
- TORCH_CHECK(false, "PyTorch was not built with cuda");
-#endif
- },
- py::arg("input"),
- py::arg("training"));
-
- py::enum_<at::native::BatchNormBackend>(py_module, "_BatchNormBackend")
- .value("Native", at::native::BatchNormBackend::Native)
- .value("Cudnn", at::native::BatchNormBackend::Cudnn)
- .value("Miopen", at::native::BatchNormBackend::Miopen);
-
- py_module.def(
- "_select_batch_norm_backend",
- [](const at::Tensor& input,
- const at::Tensor& weight,
- const at::Tensor& bias,
- const at::Tensor& running_mean,
- const at::Tensor& running_var,
- bool training,
- double eps) {
- return at::native::_select_batch_norm_backend(
- input, weight, bias, running_mean, running_var, training, eps);
- },
- py::arg("input"),
- py::arg("weight"),
- py::arg("bias"),
- py::arg("running_mean"),
- py::arg("running_var"),
- py::arg("training"),
- py::arg("eps"));
-
const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
THPDefaultCPUGenerator =
(THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
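The bindings removed above (`_get_cudnn_batch_norm_reserve_space_size`, `_select_batch_norm_backend`, and the `_BatchNormBackend` enum) were private probes added by the consolidation PR, so any Python caller has to tolerate their absence after this revert. A defensive lookup along these lines avoids an AttributeError:

```python
# Guarded access: these private bindings only exist on builds that still
# carry the consolidation PR, so probe with getattr instead of assuming.
import torch

select_backend = getattr(torch._C, "_select_batch_norm_backend", None)
if select_backend is None:
    print("binding absent (expected after this revert)")
```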
diff --git a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
index 231aba3..943d43f 100644
--- a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
+++ b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
@@ -3312,7 +3312,6 @@
{"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
{"aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
{"aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
- {"aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "native_batch_norm"},
{"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "cross_entropy_loss"},
{"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"},
{"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},
diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py
index 9c1da98..5151503 100644
--- a/torch/jit/_shape_functions.py
+++ b/torch/jit/_shape_functions.py
@@ -1431,11 +1431,6 @@
native_batch_norm,
)
add_shape_compute_mapping(
- "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
- native_batch_norm,
-)
-
-add_shape_compute_mapping(
"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
cross_entropy_loss,
)
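This is the Python twin of the serialized-registry entry removed above; one `native_batch_norm` shape function served both schemas. A quick runnable check of the per-channel invariant that makes the sharing valid:

```python
# Verify that native_batch_norm's saved statistics are per-channel, the
# property that let one shape function cover both schemas.
import torch

x = torch.randn(2, 3, 4)
weight = torch.ones(3)
bias = torch.zeros(3)
out, save_mean, save_rstd = torch.native_batch_norm(
    x, weight, bias, None, None, True, 0.1, 1e-5
)
print(out.shape, save_mean.shape, save_rstd.shape)
# torch.Size([2, 3, 4]) torch.Size([3]) torch.Size([3])
```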
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 09ff394..ca3f461 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -509,19 +509,6 @@
else:
yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps))
-def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs):
- samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
- for sample in samples:
- # torch.native_batch_norm does not support 0 numel tensors
- # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
- if sample.input.numel() == 0:
- continue
- args = sample.args
- momentum = sample.kwargs.get('momentum', 0.5)
- eps = sample.kwargs.get('eps', 1e-5)
- if any(args[i] is None for i in range(4)):
- continue
- yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps))
def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -13002,42 +12989,6 @@
"TestCompositeCompliance", "test_forward_ad"),
)
),
- OpInfo('_batch_norm_with_update',
- op=torch.ops.aten._batch_norm_with_update,
- aten_name='_batch_norm_with_update',
- dtypes=floating_types_and(torch.float16, torch.bfloat16),
- supports_forward_ad=True,
- supports_fwgrad_bwgrad=True,
- assert_jit_shape_analysis=True,
- # TODO: Avoid COW materialize
- supports_cow_input_no_materialize=False,
- sample_inputs_func=sample_inputs__batch_norm_with_update,
- skips=(
- # NotImplementedError: Could not run
- # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
- # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
- # Problem with _get_numerical_jacobian
- # IndexError: tuple index out of range
- DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
- # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
- DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
- # https://github.com/pytorch/pytorch/issues/85960
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
- DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
- "TestCompositeCompliance", "test_forward_ad"),
- # _batch_norm_with_update expects contiguous inputs for cudnn and miopen
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"),
- DecorateInfo(unittest.expectedFailure,
- 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"),
- # _batch_norm_with_update does not have python bindings
- DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
- # aten out variants do not accept out= kwarg, only python out variants
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
- )
- ),
OpInfo('nn.functional.cosine_similarity',
aten_name="cosine_similarity",
dtypes=floating_types_and(torch.half, torch.bfloat16),