[pytorch] make is_tracing scriptable (#49853)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49853
fix https://github.com/pytorch/pytorch/issues/47379
Test Plan: buck test mode/dev-nosan //caffe2/test:jit -- 'test_script_is_tracing'
Reviewed By: SplitInfinity
Differential Revision: D25704315
fbshipit-source-id: 33c09c5bc1f1b62ef254f58e18ab1e951dbd1790
diff --git a/test/test_jit.py b/test/test_jit.py
index f31f7a5..9c09c6f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9569,9 +9569,18 @@
return x - 1
inp = torch.randn(3, 3)
-
self.checkScript(test_if_tracing, (inp,))
+ def test_script_is_tracing(self):
+ def test_is_tracing(x):
+ if torch.jit.is_tracing():
+ return x + 1
+ else:
+ return x - 1
+
+ inp = torch.randn(3, 3)
+ self.checkScript(test_is_tracing, (inp,))
+
def test_is_scripting(self):
def foo():
return torch.jit.is_scripting()
diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py
index 54e5fa5..c84a1ad 100644
--- a/torch/jit/_trace.py
+++ b/torch/jit/_trace.py
@@ -20,7 +20,7 @@
from torch.jit._state import _python_cu, _enabled
from torch.jit._script import ScriptModule, _CachedForward, script
-from torch._jit_internal import _qualified_name, get_callable_argument_names
+from torch._jit_internal import _qualified_name, is_scripting, get_callable_argument_names
from torch.autograd import function
from torch.nn import Module
@@ -993,6 +993,8 @@
Returns ``True`` in tracing (if a function is called during the tracing of
code with ``torch.jit.trace``) and ``False`` otherwise.
"""
+ if is_scripting():
+ return False
return torch._C._is_tracing()