Implicit conversion from null tensor to NoneType (#55823)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55823
Test Plan: Imported from OSS
Reviewed By: gmagogsfm
Differential Revision: D27717324
Pulled By: tugsbayasgalan
fbshipit-source-id: a071b90bcea9e8f2b5da633a8dadd11772fb5101
diff --git a/test/test_jit.py b/test/test_jit.py
index d8b5791..e7fb08f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10309,6 +10309,54 @@
self.checkScript(t, (torch.zeros(3, 2, 3),))
+ def test_slice_dynamic_index(self):
+ def t(x):
+ slice1 = x[0:1]
+ zero = 0
+ one = zero + 1
+ slice2 = x[zero:one]
+ return slice1 + slice2
+
+ self.checkScript(t, (torch.zeros(3, 2, 3),))
+
+ def test_torch_ignore_conversion_to_none(self):
+ class A(torch.nn.Module):
+ def __init__(self):
+ super(A, self).__init__()
+
+ @torch.jit.ignore
+ def ignored(self, a: int) -> None:
+ l: int = len([2 for i in range(a) if i > 2])
+ return
+
+ def forward(self) -> int:
+ a: int = 4
+ b: int = 5
+ self.ignored(a)
+ return a + b
+
+ class B(torch.nn.Module):
+ def __init__(self):
+ super(B, self).__init__()
+
+ @torch.jit.ignore
+ def ignored(self, a: int):
+ l: int = len([2 for i in range(a) if i > 2])
+ return
+
+ def forward(self) -> int:
+ a: int = 4
+ b: int = 5
+ self.ignored(a)
+ return a + b
+
+ modelA = torch.jit.script(A())
+ self.assertEqual(modelA(), 9)
+
+ with self.assertRaisesRegexWithHighlight(RuntimeError, "expected value of type Tensor", "self.ignored"):
+ modelB = torch.jit.script(B())
+ modelB()
+
def test_addmm_grad(self):
""" This test checks several things:
1. An expand node was inserted before the addmm operating on the
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index a693d0c..457f194 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -250,13 +250,9 @@
if all(ann is sig.empty for ann in all_annots):
return None
- def as_ann(ann):
- # sig.empty is really annoying so convert it to None
- return ann if ann is not sig.empty else None
-
- arg_types = [ann_to_type(as_ann(p.annotation), loc)
+ arg_types = [ann_to_type(p.annotation, loc)
for p in sig.parameters.values()]
- return_type = ann_to_type(as_ann(sig.return_annotation), loc)
+ return_type = ann_to_type(sig.return_annotation, loc)
return arg_types, return_type
@@ -294,8 +290,10 @@
def try_ann_to_type(ann, loc):
- if ann is None:
+ if ann is inspect.Signature.empty:
return TensorType.getInferred()
+ if ann is None:
+ return NoneType.get()
if inspect.isclass(ann) and is_tensor(ann):
return TensorType.get()
if is_tuple(ann):