rename torch.Assert to torch._assert (#47763)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47763
Changing the name due to the discussion in
https://github.com/pytorch/pytorch/pull/47399.
Test Plan:
```
python test/test_utils.py TestAssert.test_assert_true
python test/test_fx.py TestFX.test_symbolic_trace_assert
python test/test_fx_experimental.py
```
Imported from OSS
Reviewed By: ezyang
Differential Revision: D24891767
fbshipit-source-id: 01c7a5acd83bf9c962751552780930c242134dd2
diff --git a/test/test_fx.py b/test/test_fx.py
index 4d3a862..f1425a7 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -750,7 +750,7 @@
class AssertsTensorShape(torch.nn.Module):
def forward(self, x):
- torch.Assert(x.shape[1] > 4, message)
+ torch._assert(x.shape[1] > 4, message)
return x
m = AssertsTensorShape()
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 11a3ea2..dd67908 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -392,7 +392,7 @@
# Check the IR to make sure there's a call_function node with target == "Assert"
self.assertTrue(
any(
- node.op == "call_function" and node.target == torch.Assert
+ node.op == "call_function" and node.target == torch._assert
for node in traced.graph.nodes
)
)
@@ -420,7 +420,7 @@
# Check the IR to make sure there's a call_function node with target == "Assert"
self.assertTrue(
any(
- node.op == "call_function" and node.target == torch.Assert
+ node.op == "call_function" and node.target == torch._assert
for node in traced.graph.nodes
)
)
@@ -448,7 +448,7 @@
# Check the IR to make sure there's a call_function node with target == "Assert"
self.assertTrue(
any(
- node.op == "call_function" and node.target == torch.Assert
+ node.op == "call_function" and node.target == torch._assert
for node in traced.graph.nodes
)
)
@@ -480,7 +480,7 @@
# Check the IR to make sure there's a call_function node with target == "Assert"
self.assertTrue(
any(
- node.op == "call_function" and node.target == torch.Assert
+ node.op == "call_function" and node.target == torch._assert
for node in traced.graph.nodes
)
)
diff --git a/test/test_utils.py b/test/test_utils.py
index ad0d508..63e7d58 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -637,9 +637,9 @@
class TestAssert(TestCase):
def test_assert_true(self):
# verify assertions work as expected
- torch.Assert(True, "foo")
+ torch._assert(True, "foo")
with self.assertRaisesRegex(AssertionError, "bar"):
- torch.Assert(False, "bar")
+ torch._assert(False, "bar")
if __name__ == '__main__':
diff --git a/torch/__init__.py b/torch/__init__.py
index 4d84b2c..24a770f 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -620,11 +620,11 @@
quantized_gru = torch.ops.aten.quantized_gru
-def Assert(condition, message):
+def _assert(condition, message):
r"""A wrapper around Python's assert which is symbolically traceable.
"""
from .overrides import has_torch_function, handle_torch_function
if type(condition) is not torch.Tensor and has_torch_function((condition,)):
- return handle_torch_function(Assert, (condition,), condition, message)
+ return handle_torch_function(_assert, (condition,), condition, message)
assert condition, message
diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py
index 215abc5..93c10d2 100644
--- a/torch/fx/experimental/rewriter.py
+++ b/torch/fx/experimental/rewriter.py
@@ -47,10 +47,10 @@
def visit_Assert(self, node):
"""
Swap out the Assert node (Python's `assert`) with a callsite to the
- symbolically-traceable torch.Assert function
+ symbolically-traceable torch._assert function
"""
# Create the Call node
- n = ast.parse('torch.Assert()', mode='eval')
+ n = ast.parse('torch._assert()', mode='eval')
assert isinstance(n, ast.Expression)
call_node = n.body
assert isinstance(call_node, ast.Call)
@@ -61,7 +61,7 @@
expr_wrapper = ast.Expr(value=call_node)
# Return the new Call node to signify that we want to use it as
- # a replacement for the original Assert node
+ # a replacement for the original _assert node
return ast.copy_location(expr_wrapper, node)