[pytorch] use correct warning type for tracer warnings (#53460)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53460
We have code to ignore this category of warnings and found this one is incorrect.
Use `stacklevel=2`, otherwise the warning is always filtered by TracerWarning.ignore_lib_warnings()
Test Plan: sandcastle
Reviewed By: wanchaol
Differential Revision: D26867290
fbshipit-source-id: cda1bc74a28d5965d52387d5ea2c4dcd1a2b1e86
diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py
index 9c26ceb..1ad99c8 100644
--- a/test/jit/test_tracer.py
+++ b/test/jit/test_tracer.py
@@ -515,6 +515,8 @@
with warnings.catch_warnings(record=True) as warns:
traced_fn = torch.jit.trace(fn, torch.tensor([1]))
+ for warn in warns:
+ self.assertIs(warn.category, torch.jit.TracerWarning)
warns = [str(w.message) for w in warns]
self.assertIn('a Python integer', warns[0])
self.assertIn('a Python boolean', warns[1])
diff --git a/torch/tensor.py b/torch/tensor.py
index e2c9e0d..8243795 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -593,7 +593,7 @@
warnings.warn('Iterating over a tensor might cause the trace to be incorrect. '
'Passing a tensor of different shape won\'t change the number of '
'iterations executed (and might lead to errors or silently give '
- 'incorrect results).', category=RuntimeWarning)
+ 'incorrect results).', category=torch.jit.TracerWarning, stacklevel=2)
return iter(self.unbind(0))
def __hash__(self):