Fix weight_norm export for dim=0 (#31015)
Summary:
The exported weight_norm incorrectly reduces over axis 0 as well when dim is set to 0, because `dim=0` is falsy and was treated the same as `dim=None`.
The previous test case only covered a weight with size(0) == 1, which yields the same result whether axis 0 is reduced over or not, so the bug went undetected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31015
Reviewed By: hl475
Differential Revision: D18900894
Pulled By: houseroad
fbshipit-source-id: 19004f51933b37f848dbe4138e617a7a8e35a9ec
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 00d48ef..5032846 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1610,6 +1610,10 @@
x = torch.randn(1, 1, 5, requires_grad=True)
self.run_test(model, x)
+ model = torch.nn.utils.weight_norm(torch.nn.Conv1d(3, 6, 3), name='weight')
+ x = torch.randn(3, 3, 5, requires_grad=True)
+ self.run_test(model, x)
+
def test_weight_norm_nodim(self):
model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None)
x = torch.randn(3, 4, 5, requires_grad=True)
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index bcca3c9..c00e8e7 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -2153,7 +2153,7 @@
# This conflicts the logic for negative axes to access dims backwards
# TODO: Might need a fix in torch group_norm module
axes = list(range(rank))
- if dim:
+ if dim is not None:
if dim < -1:
dim += rank
if dim != -1: