Rectify `native_batch_norm` schema by splitting it into two legit schemas (#88697)
This was verified with the same repro from the issue (but with `BatchNorm2d`).
Rectifies the `native_batch_norm` schema by splitting it into two:
1. one variant has NON-optional, aliasable `running_mean` and `running_var` inputs
2. the other drops those parameters entirely (the `no_stats` variation); see the usage sketch below
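A rough sketch of how the two variants would be invoked (op and overload names as introduced in this PR; the shapes and hyperparameters are illustrative, not taken from the tests):

```python
import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# Variant 1: running_mean/running_var are required and declared as mutable
# (aliased) inputs, so the schema now admits they are updated in place.
out, save_mean, save_invstd = torch.ops.aten._native_batch_norm_legit(
    x, weight, bias, running_mean, running_var, True, 0.1, 1e-5)

# Variant 2 (`no_stats`): the running-stat arguments are dropped entirely.
out2, save_mean2, save_invstd2 = torch.ops.aten._native_batch_norm_legit.no_stats(
    x, weight, bias, True, 0.1, 1e-5)
```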
**Calling for name suggestions!**
## Test plan
I've added tests in `test_functionalization.py` as well as an entry in `common_methods_invocations.py` for `_native_batch_norm_legit`.
CI should pass.
## Next steps
For BC/FC reasons, `native_batch_norm` is rerouted to the new schemas ONLY through the Python dispatcher; in two weeks or so, we should make `_native_batch_norm_legit` the official batch norm op.
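A rough illustration of that rerouting, assuming the Python-dispatcher decomposition registered in this PR (the exact op that shows up, e.g. the `_functional` variant, depends on whether functionalization is also active):

```python
import torch
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, running_mean, running_var):
    # The old op, whose schema claims to be functional even though it mutates
    # running_mean/running_var in place.
    return torch.native_batch_norm(x, None, None, running_mean, running_var,
                                   True, 0.1, 1e-5)

with enable_python_dispatcher():
    gm = make_fx(f)(torch.randn(2, 3, 4, 4), torch.zeros(3), torch.ones(3))

# With the python dispatcher enabled, the traced graph should contain
# aten._native_batch_norm_legit rather than aten.native_batch_norm.
print(gm.graph)
```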
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88697
Approved by: https://github.com/albanD
diff --git a/.gitignore b/.gitignore
index 5ca1885..597ae39 100644
--- a/.gitignore
+++ b/.gitignore
@@ -46,6 +46,7 @@
log
usage_log.txt
test-reports/
+test/*.bak
test/.coverage
test/.hypothesis/
test/cpp/api/mnist
diff --git a/aten/src/ATen/functorch/BatchRulesNorm.cpp b/aten/src/ATen/functorch/BatchRulesNorm.cpp
index 5e6f855..d53d4f6 100644
--- a/aten/src/ATen/functorch/BatchRulesNorm.cpp
+++ b/aten/src/ATen/functorch/BatchRulesNorm.cpp
@@ -875,10 +875,28 @@
return at::miopen_batch_norm_backward(input, grad_out, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps);
}
+// NB: This is NOT good. In the ideal world, we do NOT want to convert the new legit op back into native_batch_norm
+// as native_batch_norm has a problematic schema--it promises it is functional when it is not. However, vmap doesn't
+// work with dynamo anyway so we gain some buffer room to do wrong things here. The (reasonable) hope is that we will
+// make native_batch_norm composite implicit within a few weeks and we can fix this before vmap works with dynamo.
+std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit_batch(
+ const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+ Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps) {
+ return at::native_batch_norm(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
+}
+
+std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit_no_stats_batch(
+ const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+ bool train, double momentum, double eps) {
+ return at::native_batch_norm(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
+}
+
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(native_batch_norm, NATIVE_BATCH_NORM_BATCH_RULE(native_batch_norm));
VMAP_SUPPORT(cudnn_batch_norm, CUDNN_BATCH_NORM_BATCH_RULE(cudnn_batch_norm));
VMAP_SUPPORT(miopen_batch_norm, MIOPEN_BATCH_NORM_BATCH_RULE(miopen_batch_norm));
+ m.impl("_native_batch_norm_legit", _native_batch_norm_legit_batch);
+ m.impl("_native_batch_norm_legit.no_stats", _native_batch_norm_legit_no_stats_batch);
m.impl("native_batch_norm_backward", NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(native_batch_norm_backward));
m.impl("cudnn_batch_norm_backward", CUDNN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::cudnn_batch_norm_backward_wrapper));
m.impl("miopen_batch_norm_backward", MIOPEN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::miopen_batch_norm_backward_wrapper));
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 69196a3..ab9094d 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -787,6 +787,30 @@
return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
}
+
+std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
+ const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+ Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps) {
+ return batch_norm_cpu(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
+}
+
+std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_stats_cpu(
+ const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+ bool train, double momentum, double eps) {
+ return batch_norm_cpu(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
+}
+
+
+std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_cpu_out(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) {
+ return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps, out, save_mean, save_var);
+}
+
+
+std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) {
+ return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
+}
+
+
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
bool train, double eps, std::array<bool,3> grad_input_mask) {
// See [Note: hacky wrapper removal for optional tensor]
diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu
index df46044..a8eff15 100644
--- a/aten/src/ATen/native/cuda/Normalization.cu
+++ b/aten/src/ATen/native/cuda/Normalization.cu
@@ -473,6 +473,22 @@
return std::make_tuple(output, save_mean, save_invstd);
}
+std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
+ return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
+}
+
+std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_stats_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, bool train, double momentum, double epsilon) {
+ return batch_norm_cuda(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_cuda_out(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
+ return batch_norm_cuda_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon, output, save_mean, save_invstd);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
+ return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
+}
+
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp
index 1be6224..d017186 100644
--- a/aten/src/ATen/native/mkldnn/Normalization.cpp
+++ b/aten/src/ATen/native/mkldnn/Normalization.cpp
@@ -41,6 +41,23 @@
TORCH_CHECK(false, "mkldnn_layer_norm_last_index_weight_bias_f32: ATen not compiled with MKLDNN support");
}
+std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
+ const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
+ bool train,
+ double momentum,
+ double eps) {
+ TORCH_CHECK(false, "_mkldnn_batch_norm_legit: ATen not compiled with MKLDNN support");
+}
+
+
+std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
+ const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+ bool train,
+ double momentum,
+ double eps) {
+ TORCH_CHECK(false, "_mkldnn_batch_norm_legit_no_stats: ATen not compiled with MKLDNN support");
+}
+
} // namespace native
} // namespace at
@@ -173,6 +190,25 @@
}
}
+
+std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
+ const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
+ bool train,
+ double momentum,
+ double eps) {
+ return mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
+}
+
+
+std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
+ const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+ bool train,
+ double momentum,
+ double eps) {
+ return mkldnn_batch_norm(input, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
+}
+
+
std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(const Tensor& grad_output,
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
bool train,
diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm
index 5384ee6..49f1e05 100644
--- a/aten/src/ATen/native/mps/operations/Normalization.mm
+++ b/aten/src/ATen/native/mps/operations/Normalization.mm
@@ -411,6 +411,54 @@
return std::make_tuple(output, save_mean, save_var);
}
+std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_mps
+ (const Tensor& self,
+ const c10::optional<Tensor>& weight_opt,
+ const c10::optional<Tensor>& bias_opt,
+ Tensor& running_mean,
+ Tensor& running_var,
+ bool train,
+ double momentum,
+ double epsilon) {
+
+ return batch_norm_mps(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
+}
+
+std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_no_stats_mps
+ (const Tensor& self,
+ const c10::optional<Tensor>& weight_opt,
+ const c10::optional<Tensor>& bias_opt,
+ bool train,
+ double momentum,
+ double epsilon) {
+
+ return batch_norm_mps(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_mps_out
+ (const Tensor& self,
+ const c10::optional<Tensor>& weight_opt,
+ const c10::optional<Tensor>& bias_opt,
+ Tensor& running_mean,
+ Tensor& running_var,
+ bool train, double momentum, double epsilon,
+ Tensor& output,
+ Tensor& save_mean,
+ Tensor& save_var) {
+ return batch_norm_mps_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon, output, save_mean, save_var);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_mps_out
+ (const Tensor& self,
+ const c10::optional<Tensor>& weight_opt,
+ const c10::optional<Tensor>& bias_opt,
+ bool train, double momentum, double epsilon,
+ Tensor& output,
+ Tensor& save_mean,
+ Tensor& save_var) {
+ return batch_norm_mps_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_var);
+}
+
string get_mem_string(c10::MemoryFormat memory_format) {
string mem_format_key;
switch(memory_format) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3917be0..9aa3a2c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3811,6 +3811,35 @@
MPS: batch_norm_mps_out
CPU: batch_norm_cpu_out
+# TODO: In 2 weeks, we should make native_batch_norm composite implicit so that this correct schema percolates correctly through our dispatching
+- func: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+ dispatch:
+ CPU: _batch_norm_legit_cpu
+ CUDA: _batch_norm_legit_cuda
+ MPS: _batch_norm_legit_mps
+ MkldnnCPU: _mkldnn_batch_norm_legit
+ autogen: _native_batch_norm_legit_functional
+
+- func: _native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!))
+ dispatch:
+ CPU: _batch_norm_legit_cpu_out
+ CUDA: _batch_norm_legit_cuda_out
+ MPS: _batch_norm_legit_mps_out
+
+- func: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+ dispatch:
+ CPU: _batch_norm_legit_no_stats_cpu
+ CUDA: _batch_norm_legit_no_stats_cuda
+ MPS: _batch_norm_legit_no_stats_mps
+ MkldnnCPU: _mkldnn_batch_norm_legit_no_stats
+ tags: canonical
+
+- func: _native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+ dispatch:
+ CPU: _batch_norm_legit_no_stats_cpu_out
+ CUDA: _batch_norm_legit_no_stats_cuda_out
+ MPS: _batch_norm_legit_no_stats_mps_out
+
- func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)
dispatch:
CUDA: batch_norm_stats_cuda
diff --git a/functorch/_src/partitioners.py b/functorch/_src/partitioners.py
index e12840f..712c9a0 100644
--- a/functorch/_src/partitioners.py
+++ b/functorch/_src/partitioners.py
@@ -349,7 +349,7 @@
recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
- compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward] # noqa: E501
+ compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit] # noqa: E501
unrecomputable_ops = random_ops + compute_intensive_ops
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index e9451b5..c0ae683 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -400,7 +400,9 @@
skip('nn.functional.max_unpool1d'), # fails everywhere except on mac
skip('nn.functional.max_unpool2d'), # fails everywhere except on windows
skip('nn.functional.max_unpool3d'), # fails everywhere except on mac
- xfail("native_batch_norm"),
+ xfail("native_batch_norm"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
+ xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
+
xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'),
xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented
@@ -689,6 +691,7 @@
# view doesn't work on sparse
xfail("to_sparse"),
xfail("native_batch_norm"),
+ xfail("_native_batch_norm_legit"),
}))
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -773,6 +776,7 @@
# All of the following are bugs and need to be fixed
skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule
skip("native_batch_norm"),
+ skip("_native_batch_norm_legit"),
xfail('__getitem__', ''), # dynamic error
xfail('linalg.eig'), # Uses aten::allclose
xfail('nanquantile', device_type='cpu'), # checks q via a .item() call
@@ -888,6 +892,7 @@
xfail('nn.functional.batch_norm'),
xfail('nn.functional.batch_norm', 'without_cudnn'),
xfail("native_batch_norm"),
+ xfail("_native_batch_norm_legit"),
# ----------------------------------------------------------------------
}
@@ -1090,6 +1095,7 @@
xfail('segment_reduce', 'lengths'),
xfail('sparse.sampled_addmm', ''),
xfail("native_batch_norm"),
+ xfail("_native_batch_norm_legit"),
xfail("native_dropout_backward"),
}))
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
@@ -1162,6 +1168,7 @@
xfail('as_strided_scatter', ''),
xfail('sparse.sampled_addmm', ''),
xfail("native_batch_norm"),
+ xfail("_native_batch_norm_legit"),
}))
def test_vjpvmap(self, device, dtype, op):
# NB: there is no vjpvmap_has_batch_rule test because that is almost
@@ -1419,6 +1426,7 @@
# input while the running_mean or running_var, which will be updated in
# place, were not batched.
xfail("native_batch_norm"),
+ xfail("_native_batch_norm_legit"),
xfail('native_dropout_backward',)
}))
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 9b3293a..4b46056 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3294,7 +3294,10 @@
))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
+ # RuntimeError: Batch norm got a batched tensor as input while the running_mean or running_var,
+ # which will be updated in place, were not batched.
xfail('native_batch_norm'),
+ xfail('_native_batch_norm_legit'),
xfail('tril'), # Exception not raised on error input
xfail('triu'), # Exception not raised on error input
# The error inputs are vectors, that pass when batched as they are treated as a matrix
@@ -3317,7 +3320,10 @@
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
xfail('complex'),
xfail('copysign'),
+ # Batch norm got a batched tensor as input while the running_mean or running_var,
+ # which will be updated in place, were not batched.
xfail('native_batch_norm'),
+ xfail('_native_batch_norm_legit'),
xfail('histogram'),
xfail('index_fill'),
xfail('nansum'),
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 2a79107..c9a9147 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -135,6 +135,7 @@
"jiterator_unary": {b8, f16, f32, f64, i32, i64},
# flaky
"native_batch_norm": {f16, f32, f64},
+ "_native_batch_norm_legit": {f16, f32, f64},
}
inductor_expected_failures_single_sample = defaultdict(dict)
diff --git a/test/lazy/test_reuse_ir.py b/test/lazy/test_reuse_ir.py
index 2d19fe1..f7024e9 100644
--- a/test/lazy/test_reuse_ir.py
+++ b/test/lazy/test_reuse_ir.py
@@ -111,6 +111,7 @@
# BatchNorm2d does extra checks on dimensions which SymInts don't support yet
# so we call `torch.ops.aten.native_batch_norm` to bypass the checks.
z, _, _ = torch.ops.aten.native_batch_norm(x, weight, bias, None, None, True, 0.1, 1e-5)
+ z_legit, _, _ = torch.ops.aten._native_batch_norm_legit(x, weight, bias, True, 0.1, 1e-5)
device = "lazy"
x_lazy = x.detach().clone().to(device=device)
@@ -118,12 +119,15 @@
bias_lazy = bias.detach().clone().to(device=device)
for i in range(10):
z_lazy, _, _ = torch.ops.aten.native_batch_norm(x_lazy, weight_lazy, bias_lazy, None, None, True, 0.1, 1e-5)
+ z_legit_lazy, _, _ = torch.ops.aten._native_batch_norm_legit(x_lazy, weight_lazy, bias_lazy, True, 0.1, 1e-5)
torch._lazy.mark_step()
torch.testing.assert_close(z.cpu(), z_lazy.cpu())
+ torch.testing.assert_close(z_legit.cpu(), z_legit_lazy.cpu())
assert metrics.counter_value("IrNodeReused_torch::lazy::NativeBatchNorm") >= 7
metrics.reset()
torch._lazy.ir_cache.reset()
+
if __name__ == '__main__':
run_tests()
diff --git a/test/test_decomp.py b/test/test_decomp.py
index d69d727..73f8c7a 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -159,6 +159,10 @@
(torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2,
(torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
(torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
+ (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
+ (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
+ (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
+ (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
(torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-5,
(torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-5,
(torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
@@ -306,6 +310,8 @@
# _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32
("cpu", torch.bfloat16, "_softmax_backward_data"),
(None, None, "norm"),
+ # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
+ (None, None, "native_batch_norm"),
}
CROSS_REF_BACKWARD_EXCLUDE_SET = {
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index 0731cae..d699c03 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -3,14 +3,14 @@
import torch
from contextlib import nullcontext
from torch.testing._internal.common_utils import (
- TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO,
+ TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, IS_WINDOWS,
xfail_inherited_tests
)
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace
-from torch._dispatch.python import enable_crossref_functionalize
+from torch._dispatch.python import enable_crossref_functionalize, enable_python_dispatcher
import unittest
@@ -1228,6 +1228,205 @@
return zeros
""")
+
+ def test_instance_norm(self):
+ def f(x):
+ with enable_python_dispatcher():
+ return torch.instance_norm(x, None, None, running_mean=torch.zeros(100), running_var=torch.ones(100),
+ use_input_stats=True, momentum=0.1, eps=1e-5, cudnn_enabled=False)
+ self.assert_functionalization(f, torch.randn(20, 100, 35, 45))
+ # On Windows, for instance_norm, the alias_copy's are reordered to come right before they need to be used
+ # whereas on other platforms, the alias_copy's are before the view_copy's.
+ # e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment.
+ if not IS_WINDOWS:
+ logs = self.get_logs(f, torch.randn(20, 100, 35, 45))
+ self.assertExpectedInline(logs, """\
+
+
+
+def forward(self, a_1):
+ zeros = torch.ops.aten.zeros.default([100], device = device(type='cpu'), pin_memory = False)
+ ones = torch.ops.aten.ones.default([100], device = device(type='cpu'), pin_memory = False)
+ repeat = torch.ops.aten.repeat.default(zeros, [20])
+ repeat_1 = torch.ops.aten.repeat.default(ones, [20])
+ view_copy = torch.ops.aten.view_copy.default(a_1, [1, 2000, 35, 45]); a_1 = None
+ empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
+ _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None
+ getitem = _native_batch_norm_legit_functional[0]
+ getitem_1 = _native_batch_norm_legit_functional[1]
+ getitem_2 = _native_batch_norm_legit_functional[2]
+ getitem_3 = _native_batch_norm_legit_functional[3]
+ getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
+ alias_copy = torch.ops.aten.alias_copy.default(zeros); zeros = None
+ view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100])
+ view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None
+ mean = torch.ops.aten.mean.dim(view_copy_2, [0]); view_copy_2 = None
+ copy = torch.ops.aten.copy.default(alias_copy, mean); alias_copy = mean = None
+ alias_copy_1 = torch.ops.aten.alias_copy.default(ones); ones = None
+ view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100])
+ view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None
+ mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]); view_copy_4 = None
+ copy_1 = torch.ops.aten.copy.default(alias_copy_1, mean_1); alias_copy_1 = mean_1 = None
+ view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None
+ return view_copy_5
+ """) # noqa: B950
+
+ reinplaced_logs = self.get_logs(f, torch.randn(20, 100, 35, 45), reapply_views=True, run_reinplace=True)
+ self.assertExpectedInline(reinplaced_logs, """\
+
+
+
+def forward(self, a_1):
+ zeros = torch.ops.aten.zeros.default([100], device = device(type='cpu'), pin_memory = False)
+ ones = torch.ops.aten.ones.default([100], device = device(type='cpu'), pin_memory = False)
+ repeat = torch.ops.aten.repeat.default(zeros, [20])
+ repeat_1 = torch.ops.aten.repeat.default(ones, [20])
+ view = torch.ops.aten.view.default(a_1, [1, 2000, 35, 45]); a_1 = None
+ empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
+ _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None
+ getitem = _native_batch_norm_legit_functional[0]
+ getitem_1 = _native_batch_norm_legit_functional[1]
+ getitem_2 = _native_batch_norm_legit_functional[2]
+ getitem_3 = _native_batch_norm_legit_functional[3]
+ getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
+ alias = torch.ops.aten.alias.default(zeros); zeros = None
+ view_1 = torch.ops.aten.view.default(getitem_3, [20, 100])
+ view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None
+ mean = torch.ops.aten.mean.dim(view_2, [0]); view_2 = None
+ copy = torch.ops.aten.copy_.default(alias, mean); alias = mean = None
+ alias_1 = torch.ops.aten.alias.default(ones); ones = None
+ view_3 = torch.ops.aten.view.default(getitem_4, [20, 100])
+ view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None
+ mean_1 = torch.ops.aten.mean.dim(view_4, [0]); view_4 = None
+ copy_1 = torch.ops.aten.copy_.default(alias_1, mean_1); alias_1 = mean_1 = None
+ view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None
+ return view_5
+ """) # noqa: B950
+
+
+ def test_instance_norm_running_mean_is_x(self):
+ def f(x):
+ with enable_python_dispatcher():
+ return torch.instance_norm(torch.randn(20, 100, 35, 45), None, None, running_mean=x, running_var=torch.ones(100),
+ use_input_stats=True, momentum=0.1, eps=1e-5, cudnn_enabled=False)
+ # TODO: uncomment following line after functionalization can handle input mutations
+ # self.assert_functionalization(f, torch.zeros(100))
+ logs = self.get_logs(f, torch.zeros(100))
+ # On Windows, for instance_norm, the alias_copy's are reordered to come right before they need to be used
+ # whereas on other platforms, the alias_copy's are before the view_copy's.
+ # e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment.
+ if not IS_WINDOWS:
+ self.assertExpectedInline(logs, """\
+
+
+
+def forward(self, a_1):
+ randn = torch.ops.aten.randn.default([20, 100, 35, 45], device = device(type='cpu'), pin_memory = False)
+ ones = torch.ops.aten.ones.default([100], device = device(type='cpu'), pin_memory = False)
+ repeat = torch.ops.aten.repeat.default(a_1, [20])
+ repeat_1 = torch.ops.aten.repeat.default(ones, [20])
+ view_copy = torch.ops.aten.view_copy.default(randn, [1, 2000, 35, 45]); randn = None
+ empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
+ _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None
+ getitem = _native_batch_norm_legit_functional[0]
+ getitem_1 = _native_batch_norm_legit_functional[1]
+ getitem_2 = _native_batch_norm_legit_functional[2]
+ getitem_3 = _native_batch_norm_legit_functional[3]
+ getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
+ alias_copy = torch.ops.aten.alias_copy.default(a_1)
+ view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100])
+ view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None
+ mean = torch.ops.aten.mean.dim(view_copy_2, [0]); view_copy_2 = None
+ copy = torch.ops.aten.copy.default(alias_copy, mean); alias_copy = mean = None
+ alias_copy_1 = torch.ops.aten.alias_copy.default(ones); ones = None
+ view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100])
+ view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None
+ mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]); view_copy_4 = None
+ copy_1 = torch.ops.aten.copy.default(alias_copy_1, mean_1); alias_copy_1 = mean_1 = None
+ view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None
+ alias_copy_2 = torch.ops.aten.alias_copy.default(copy); copy = None
+ copy_ = torch.ops.aten.copy_.default(a_1, alias_copy_2); a_1 = alias_copy_2 = None
+ return view_copy_5
+ """) # noqa: B950
+
+ reinplaced_logs = self.get_logs(f, torch.zeros(100), reapply_views=True, run_reinplace=True)
+ self.assertExpectedInline(reinplaced_logs, """\
+
+
+
+def forward(self, a_1):
+ randn = torch.ops.aten.randn.default([20, 100, 35, 45], device = device(type='cpu'), pin_memory = False)
+ ones = torch.ops.aten.ones.default([100], device = device(type='cpu'), pin_memory = False)
+ repeat = torch.ops.aten.repeat.default(a_1, [20])
+ repeat_1 = torch.ops.aten.repeat.default(ones, [20])
+ view = torch.ops.aten.view.default(randn, [1, 2000, 35, 45]); randn = None
+ empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
+ _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None
+ getitem = _native_batch_norm_legit_functional[0]
+ getitem_1 = _native_batch_norm_legit_functional[1]
+ getitem_2 = _native_batch_norm_legit_functional[2]
+ getitem_3 = _native_batch_norm_legit_functional[3]
+ getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
+ alias = torch.ops.aten.alias.default(a_1)
+ view_1 = torch.ops.aten.view.default(getitem_3, [20, 100])
+ view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None
+ mean = torch.ops.aten.mean.dim(view_2, [0]); view_2 = None
+ copy = torch.ops.aten.copy.default(alias, mean); alias = mean = None
+ alias_1 = torch.ops.aten.alias.default(ones); ones = None
+ view_3 = torch.ops.aten.view.default(getitem_4, [20, 100])
+ view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None
+ mean_1 = torch.ops.aten.mean.dim(view_4, [0]); view_4 = None
+ copy_1 = torch.ops.aten.copy_.default(alias_1, mean_1); alias_1 = mean_1 = None
+ view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None
+ alias_2 = torch.ops.aten.alias.default(copy); copy = None
+ copy_ = torch.ops.aten.copy_.default(a_1, alias_2); a_1 = alias_2 = None
+ return view_5
+ """) # noqa: B950
+
+
+ def test_batch_norm(self):
+ def f(x):
+ with enable_python_dispatcher():
+ return torch.batch_norm(x, None, None, torch.zeros(100), torch.ones(100), False, 0.1, 1e-5, False)
+
+ self.assert_functionalization(f, torch.randn(20, 100, 35, 45))
+ logs = self.get_logs(f, torch.randn(20, 100, 35, 45))
+ self.assertExpectedInline(logs, """\
+
+
+
+def forward(self, a_1):
+ zeros = torch.ops.aten.zeros.default([100], device = device(type='cpu'), pin_memory = False)
+ ones = torch.ops.aten.ones.default([100], device = device(type='cpu'), pin_memory = False)
+ empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
+ _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(a_1, None, None, zeros, ones, False, 0.1, 1e-05); a_1 = zeros = ones = None
+ getitem = _native_batch_norm_legit_functional[0]
+ getitem_1 = _native_batch_norm_legit_functional[1]
+ getitem_2 = _native_batch_norm_legit_functional[2]
+ getitem_3 = _native_batch_norm_legit_functional[3]
+ getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
+ return getitem
+ """) # noqa: B950
+
+ reinplaced_logs = self.get_logs(f, torch.randn(20, 100, 35, 45), reapply_views=True, run_reinplace=True)
+ self.assertExpectedInline(reinplaced_logs, """\
+
+
+
+def forward(self, a_1):
+ zeros = torch.ops.aten.zeros.default([100], device = device(type='cpu'), pin_memory = False)
+ ones = torch.ops.aten.ones.default([100], device = device(type='cpu'), pin_memory = False)
+ empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
+ _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(a_1, None, None, zeros, ones, False, 0.1, 1e-05); a_1 = zeros = ones = None
+ getitem = _native_batch_norm_legit_functional[0]
+ getitem_1 = _native_batch_norm_legit_functional[1]
+ getitem_2 = _native_batch_norm_legit_functional[2]
+ getitem_3 = _native_batch_norm_legit_functional[3]
+ getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
+ return getitem
+ """) # noqa: B950
+
+
@xfail_inherited_tests([
"test_as_strided",
"test_copy_",
diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py
index 7de6f17..0a13fdb 100644
--- a/test/test_jit_cuda_fuser.py
+++ b/test/test_jit_cuda_fuser.py
@@ -143,7 +143,7 @@
disabled_ops = ("aten::batch_norm",
"aten::_batch_norm_impl_index",
"aten::_batch_norm_impl_index_backward",
- "aten::native_batch_norm_backward")
+ "aten::native_batch_norm_backward",)
for op in disabled_ops:
disabled_flag = torch._C._jit_set_nvfuser_skip_node_kind(op, False)
if disabled_flag:
diff --git a/test/test_meta.py b/test/test_meta.py
index 0e3cfb6..af81d14 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -710,6 +710,7 @@
meta_function_device_expected_failures['cpu'] = {
torch.native_batch_norm: {bf16},
+ torch._native_batch_norm_legit: {bf16},
torch.native_layer_norm: {bf16},
}
@@ -744,6 +745,7 @@
meta_function_device_skips['cpu'] = {
torch.native_batch_norm: {f32, f64},
+ torch._native_batch_norm_legit: {f32, f64},
}
meta_function_device_skips['cuda'] = {
@@ -927,6 +929,8 @@
meta_dispatch_device_expected_failures['cpu'] = {
aten.native_batch_norm.default: {bf16},
+ aten._native_batch_norm_legit.default: {bf16},
+ aten._native_batch_norm_legit.no_stats: {bf16},
aten.native_layer_norm.default: {bf16},
}
@@ -972,6 +976,8 @@
meta_dispatch_device_skips['cpu'] = {
aten._embedding_bag_forward_only.default: {f16, f32, f64},
aten.native_batch_norm.default: {f32, f64},
+ aten._native_batch_norm_legit.default: {f32, f64},
+ aten._native_batch_norm_legit.no_stats: {f32, f64},
}
meta_dispatch_device_skips['cuda'] = {
diff --git a/test/test_ops.py b/test/test_ops.py
index 7e0a995..62d4403 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1574,7 +1574,7 @@
def check_inplace_view(func, input, rs, input_size, input_strides):
if func is None:
return
- # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out
+ # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm(_legit).out
# which mutate not necessarily the first input.
if isinstance(rs, torch.Tensor) and rs is input:
unequal_size = rs.size() != input_size
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 6e1b456..3500bd2 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1144,6 +1144,14 @@
input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps)
+- name: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+ input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+ result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps)
+
+- name: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+ input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, Tensor(), Tensor(), result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+ result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, Tensor(), Tensor(), result1, result2, training, eps)
+
- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
save_mean: not_implemented("native_batch_norm_backward save_mean")
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 57e068d..b9c9225 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1312,8 +1312,7 @@
)
-@register_decomposition(aten.native_batch_norm)
-def native_batch_norm(
+def native_batch_norm_helper(
input: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
@@ -1322,16 +1321,21 @@
training: bool,
momentum: float,
eps: float,
-) -> Tuple[Tensor, Tensor, Tensor]:
+ functional: bool,
+) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
reduction_dims = [0] + list(range(2, input.dim()))
computation_dtype = utils.get_computation_dtype(input.dtype)
+ new_running_mean = running_mean
+ new_running_var = running_var
if training:
output, mean, rstd = normalize(input, reduction_dims, eps)
save_mean = _squeeze_multiple(mean, reduction_dims)
save_rstd = _squeeze_multiple(rstd, reduction_dims)
if running_mean is not None:
- running_mean.copy_(momentum * save_mean + (1 - momentum) * running_mean)
+ new_running_mean = momentum * save_mean + (1 - momentum) * running_mean
+ if not functional:
+ running_mean.copy_(new_running_mean)
if running_var is not None:
n = input.numel() / input.shape[1]
# This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
@@ -1340,11 +1344,15 @@
unbiased_var = torch.var(input, reduction_dims, unbiased=False) * (
n / (n - 1)
)
- running_var.copy_(momentum * unbiased_var + (1 - momentum) * running_var)
+ new_running_var = momentum * unbiased_var + (1 - momentum) * running_var
+ if not functional:
+ running_var.copy_(new_running_var)
else:
assert running_mean is not None and running_var is not None
running_mean = running_mean.to(dtype=computation_dtype, copy=True)
+ new_running_mean = running_mean
running_var = running_var.to(dtype=computation_dtype, copy=True)
+ new_running_var = running_var
mean = running_mean
invstd = 1 / (torch.sqrt(running_var + eps))
# Very annoying inconsistency where CPU and CUDA give different shapes
@@ -1370,7 +1378,127 @@
if input.device.type == "cpu":
save_mean = save_mean.to(dtype=input.dtype)
save_rstd = save_rstd.to(dtype=input.dtype)
- return output.to(dtype=input.dtype), save_mean, save_rstd
+ return (
+ output.to(dtype=input.dtype),
+ save_mean,
+ save_rstd,
+ new_running_mean,
+ new_running_var,
+ )
+
+
+@register_decomposition(aten.native_batch_norm)
+def native_batch_norm(
+ input: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ running_mean: Optional[Tensor],
+ running_var: Optional[Tensor],
+ training: bool,
+ momentum: float,
+ eps: float,
+) -> Tuple[Tensor, Tensor, Tensor]:
+ output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
+ input, weight, bias, running_mean, running_var, training, momentum, eps, False
+ )
+ return output, save_mean, save_rstd
+
+
+# TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm
+# with our new correctly schema'd _native_batch_norm_legit and its variants, but
+# we cannot do that immediately in the C++ because it would be forwards incompatible
+# with some mobile use cases.
+#
+# Since this change is most impactful for aot autograd/functionalization, we simply
+# register this decomposition on the Autograd key for the python dispatcher (which is
+# currently only used by aot autograd/functionalization and no one else, really).
+# In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm
+# to be _native_batch_norm_legit and have the right schema (stating that there are input mutations).
+@torch.ops.aten.native_batch_norm.default.py_impl(DispatchKey.Autograd)
+def native_batch_norm_decomposition(
+ input: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ running_mean: Optional[Tensor],
+ running_var: Optional[Tensor],
+ training: bool,
+ momentum: float,
+ eps: float,
+) -> Tuple[Tensor, Tensor, Tensor]:
+ if running_mean is None and running_var is None:
+ return aten._native_batch_norm_legit(
+ input, weight, bias, training, momentum, eps
+ )
+ if running_mean is None:
+ raise RuntimeError(
+ "running_mean is None, but running_var is provided. "
+ "They should both be None or both be provided."
+ )
+ if running_var is None:
+ raise RuntimeError(
+ "running_var is None, but running_mean is provided. "
+ "They should both be None or both be provided."
+ )
+ return aten._native_batch_norm_legit(
+ input, weight, bias, running_mean, running_var, training, momentum, eps
+ )
+
+
+@register_decomposition(aten._native_batch_norm_legit.default)
+def _native_batch_norm_legit(
+ input: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ running_mean: Tensor,
+ running_var: Tensor,
+ training: bool,
+ momentum: float,
+ eps: float,
+) -> Tuple[Tensor, Tensor, Tensor]:
+ output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
+ input, weight, bias, running_mean, running_var, training, momentum, eps, False
+ )
+ return output, save_mean, save_rstd
+
+
+@register_decomposition(aten._native_batch_norm_legit.no_stats)
+def _native_batch_norm_legit_no_stats(
+ input: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ training: bool,
+ momentum: float,
+ eps: float,
+) -> Tuple[Tensor, Tensor, Tensor]:
+ output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
+ input, weight, bias, None, None, training, momentum, eps, False
+ )
+ return output, save_mean, save_rstd
+
+
+@register_decomposition(aten._native_batch_norm_legit_functional.default)
+def _native_batch_norm_legit_functional(
+ input: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ running_mean: Tensor,
+ running_var: Tensor,
+ training: bool,
+ momentum: float,
+ eps: float,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ (
+ output,
+ save_mean,
+ save_rstd,
+ new_running_mean,
+ new_running_var,
+ ) = native_batch_norm_helper(
+ input, weight, bias, running_mean, running_var, training, momentum, eps, True
+ )
+ assert new_running_mean is not None, "new_running_mean should not be None"
+ assert new_running_var is not None, "new_running_var should not be None"
+ return output, save_mean, save_rstd, new_running_mean, new_running_var
@register_decomposition(aten._fused_dropout)
diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py
index ffb4463..5eace95 100644
--- a/torch/jit/_shape_functions.py
+++ b/torch/jit/_shape_functions.py
@@ -1091,6 +1091,8 @@
add_shape_compute_mapping("aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)", nll_loss_forward)
add_shape_compute_mapping("aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", native_layer_norm)
add_shape_compute_mapping("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
+add_shape_compute_mapping("aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
+add_shape_compute_mapping("aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm)
# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)
# TODO: migrate over all of symbolic_shape_registry_util.cpp
diff --git a/torch/overrides.py b/torch/overrides.py
index cb44022..21cfe24 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -716,6 +716,7 @@
torch.narrow_copy: lambda input, dim, start, length: -1,
torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
+ torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
torch.native_dropout: lambda input, p, train: -1,
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 177dc66..f20f3a6 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -439,6 +439,23 @@
yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))
+def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs):
+ samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
+ for sample in samples:
+ # torch.native_batch_norm does not support 0 numel tensors
+ # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+ if sample.input.numel() == 0:
+ continue
+ args = sample.args
+ training = sample.kwargs.get('training', True)
+ momentum = sample.kwargs.get('momentum', 0.5)
+ eps = sample.kwargs.get('eps', 1e-5)
+ if args[0] is not None and args[1] is not None:
+ yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))
+ else:
+ yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps))
+
+
def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -10779,6 +10796,34 @@
DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
)
),
+ OpInfo('_native_batch_norm_legit',
+ aten_name='_native_batch_norm_legit',
+ dtypes=floating_types_and(torch.bfloat16),
+ dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ supports_forward_ad=True,
+ supports_fwgrad_bwgrad=True,
+ assert_jit_shape_analysis=True,
+ sample_inputs_func=sample_inputs__native_batch_norm_legit,
+ skips=(
+ # NotImplementedError: Could not run
+ # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
+ # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
+ # Problem with _get_numerical_jacobian
+ # IndexError: tuple index out of range
+ DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+ # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
+ DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+ # https://github.com/pytorch/pytorch/issues/85960
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
+ DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
+ "TestCompositeCompliance", "test_forward_ad"),
+ # Extremal value issue on aten::native_batch_norm, which returns 'nan' for mean on 'inf' inputs
+ # possibly because of the welford implementation.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
+ )
+ ),
OpInfo('nn.functional.cosine_similarity',
aten_name="cosine_similarity",
dtypes=floating_types_and(torch.bfloat16),