Fix weight_norm export for dim=0 (#31015)

Summary:
Exported weight_norm is incorrectly reducing over axis 0 as well when dim is set to 0.
Previous test case only covers weight with size(0) == 1, which yields the same result whether reduced over or not.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31015

Reviewed By: hl475

Differential Revision: D18900894

Pulled By: houseroad

fbshipit-source-id: 19004f51933b37f848dbe4138e617a7a8e35a9ec
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 00d48ef..5032846 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1610,6 +1610,10 @@
         x = torch.randn(1, 1, 5, requires_grad=True)
         self.run_test(model, x)
 
+        model = torch.nn.utils.weight_norm(torch.nn.Conv1d(3, 6, 3), name='weight')
+        x = torch.randn(3, 3, 5, requires_grad=True)
+        self.run_test(model, x)
+
     def test_weight_norm_nodim(self):
         model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None)
         x = torch.randn(3, 4, 5, requires_grad=True)
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index bcca3c9..c00e8e7 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -2153,7 +2153,7 @@
         # This conflicts the logic for negative axes to access dims backwards
         # TODO: Might need a fix in torch group_norm module
         axes = list(range(rank))
-        if dim:
+        if dim is not None:
             if dim < -1:
                 dim += rank
             if dim != -1: