make_fx can now SymIntify int inputs (#113452)
This PR also contains a basket of fixes that were turned up by now testing more arguments with SymInt. I fixed as many of the easy ones as I could easily get earlier in this stack and a bunch here, but there are some more annoying ones I xfailed.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113452
Approved by: https://github.com/Chillee
ghstack dependencies: #113877, #113911
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index b152175..5c9b981 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -923,6 +923,16 @@
lambda: interp.run(torch.randn(3, 3))
)
+ def test_int_input(self):
+ def f(x, y):
+ return x.view(y)
+
+ r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12).code).strip()
+ self.assertExpectedInline(r, """\
+def forward(self, x_1, y_1):
+ view = torch.ops.aten.view.default(x_1, [y_1]); x_1 = y_1 = None
+ return view""")
+
def test_resize_from_zero(self):
def f(x, y):
x.resize_(y.size(0))
@@ -1705,11 +1715,20 @@
xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
- xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
+ # AssertionError: False != True - https://github.com/pytorch/pytorch/issues/113905
+ xfail('dist', ''),
+ xfail('norm', ''),
+ xfail('linalg.vector_norm', ''),
+ xfail('linalg.norm', 'subgradients_at_zero'),
+ xfail('renorm', ''),
+
+ xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but...
+ xfail('randint_like', ''), # when unpacking SymInt, expected int but got s0
+
# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),
xfail('fft.hfft2', ''),
@@ -1736,6 +1755,11 @@
outplace_symbolic_tensor_failures = {
xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
+
+ xfail('linalg.norm', ''),
+ xfail('round', 'decimals_0'), # Cannot call numel() on tensor with symbolic sizes/strides
+ xfail('round', 'decimals_3'), # Cannot call numel() on tensor with symbolic sizes/strides
+ xfail('round', 'decimals_neg_3'), # Cannot call numel() on tensor with symbolic sizes/strides
}
inplace_symbolic_tensor_failures = {
@@ -1799,6 +1823,11 @@
xfail('topk', ''),
xfail('triangular_solve', ''),
xfail('view_copy', ''),
+
+ # SymIntArrayRef expected to contain only concrete
+ xfail('ones', ''),
+ xfail('randn', ''),
+ xfail('zeros', ''),
}
out_symbolic_tensor_segfaults = {
diff --git a/torch/__init__.py b/torch/__init__.py
index 92d5425..fe699ad 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -287,6 +287,9 @@
def __sym_float__(self):
raise AssertionError("type stub not overridden")
+ def __neg__(self):
+ raise AssertionError("type stub not overridden")
+
def __repr__(self):
return str(self.node)
diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp
index c2e2509..1565d73 100644
--- a/torch/csrc/jit/python/pybind_utils.cpp
+++ b/torch/csrc/jit/python/pybind_utils.cpp
@@ -111,6 +111,9 @@
case TypeKind::StorageType:
return py::cast<at::Storage>(obj);
case TypeKind::FloatType:
+ if (torch::is_symfloat(py::handle(obj))) {
+ return py::cast<c10::SymFloat>(obj).guard_float(__FILE__, __LINE__);
+ }
return py::cast<double>(obj);
case TypeKind::ComplexType: {
auto c_obj = py::cast<std::complex<double>>(obj.ptr());
@@ -139,6 +142,9 @@
auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
return static_cast<int8_t>(memory_format->memory_format);
}
+ if (torch::is_symint(py::handle(obj))) {
+ return py::cast<c10::SymInt>(obj).guard_int(__FILE__, __LINE__);
+ }
return py::cast<int64_t>(obj);
case TypeKind::LayoutType: {
if (THPLayout_Check(obj.ptr())) {
@@ -186,6 +192,9 @@
}
return {};
case TypeKind::BoolType:
+ if (torch::is_symbool(obj.ptr())) {
+ return py::cast<c10::SymBool>(obj).guard_bool(__FILE__, __LINE__);
+ }
return py::cast<bool>(obj);
case TypeKind::TupleType: {
py::tuple tuple = py::cast<py::tuple>(obj);
diff --git a/torch/functional.py b/torch/functional.py
index 19117c7..a6c1241 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1171,7 +1171,7 @@
if has_torch_function_variadic(a, b):
return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
- if not isinstance(dims, (tuple, list, torch.Tensor, int)):
+ if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
raise RuntimeError("tensordot expects dims to be int or "
+ "Tuple[List[int], List[int]] or "
+ "List[List[int]] containing two lists, but got "
@@ -1196,7 +1196,7 @@
dims_a = list(range(-dims_val, 0))
dims_b = list(range(dims_val))
- if isinstance(dims, int):
+ if isinstance(dims, (int, torch.SymInt)):
if dims < 0:
raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
if dims > min(a.dim(), b.dim()):
@@ -1597,7 +1597,7 @@
if input.layout == torch.strided and input.device.type in \
("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
if dim is not None:
- if isinstance(dim, int):
+ if isinstance(dim, (int, torch.SymInt)):
_dim = [dim]
else:
_dim = dim
@@ -1605,7 +1605,7 @@
_dim = None # type: ignore[assignment]
if isinstance(p, str):
- if p == "fro" and (dim is None or isinstance(dim, int) or len(dim) <= 2):
+ if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
if out is None:
return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
else:
@@ -1642,7 +1642,7 @@
# remove the overloads where dim is an int and replace with BraodcastingList1
# and remove next four lines, replace _dim with dim
if dim is not None:
- if isinstance(dim, int):
+ if isinstance(dim, (int, torch.SymInt)):
_dim = [dim]
else:
_dim = dim
@@ -1657,22 +1657,22 @@
if _dim is None:
_dim = list(range(ndim))
if out is None:
- return _VF.frobenius_norm(input, _dim, keepdim=keepdim)
+ return _VF.frobenius_norm(input, _dim, keepdim=keepdim) # type: ignore[arg-type]
else:
- return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)
+ return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out) # type: ignore[arg-type]
elif p == "nuc":
if dtype is not None:
raise ValueError("dtype argument is not supported in nuclear norm")
if _dim is None:
if out is None:
- return _VF.nuclear_norm(input, keepdim=keepdim)
+ return _VF.nuclear_norm(input, keepdim=keepdim) # type: ignore[arg-type]
else:
- return _VF.nuclear_norm(input, keepdim=keepdim, out=out)
+ return _VF.nuclear_norm(input, keepdim=keepdim, out=out) # type: ignore[arg-type]
else:
if out is None:
- return _VF.nuclear_norm(input, _dim, keepdim=keepdim)
+ return _VF.nuclear_norm(input, _dim, keepdim=keepdim) # type: ignore[arg-type]
else:
- return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)
+ return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out) # type: ignore[arg-type]
raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
else:
if _dim is None:
@@ -1750,15 +1750,15 @@
lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")
torch._check_type(
- isinstance(shape, (int, Sequence)),
+ isinstance(shape, (int, torch.SymInt, Sequence)),
lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")
- if isinstance(shape, int):
+ if isinstance(shape, (int, torch.SymInt)):
shape = torch.Size([shape])
else:
for dim in shape:
torch._check_type(
- isinstance(dim, int),
+ isinstance(dim, (int, torch.SymInt)),
lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
shape = torch.Size(shape)
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 19d7089..04219b3 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -824,15 +824,18 @@
def wrap_fake(x):
nonlocal arg_count
+ # TODO: it would be nice to line these up with the names
+ # FX will choose for the placeholders, but we don't
+ # actually know what the names will be at this point yet
+ # NB: the Source here is actually meaningless
+ from torch._dynamo.source import ConstantSource
+ source = ConstantSource(f"input{arg_count}")
if isinstance(x, torch.Tensor):
- # TODO: it would be nice to line these up with the names
- # FX will choose for the placeholders, but we don't
- # actually know what the names will be at this point yet
- # NB: the Source here is actually meaningless
- from torch._dynamo.source import ConstantSource
- source = ConstantSource(f"input{arg_count}")
arg_count += 1
return fake_tensor_mode.from_tensor(x, source=source) # type: ignore[attr-defined]
+ # NB: don't match on bools
+ elif type(x) is int and tracing_mode == "symbolic":
+ return shape_env.create_symintnode(shape_env.create_symbol(x, source, positive=None), hint=x, source=source)
return x
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index b1d89fc..fffb650 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -465,7 +465,7 @@
return FloatType.get()
if ann is complex:
return ComplexType.get()
- if ann is int:
+ if ann is int or ann is torch.SymInt:
return IntType.get()
if ann is str:
return StringType.get()
diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py
index 93a5b5e..083af67 100644
--- a/torch/masked/_ops.py
+++ b/torch/masked/_ops.py
@@ -463,7 +463,7 @@
if dim is None:
return tuple(range(ndim))
ndim = max(ndim, 1)
- dim_ = (dim,) if isinstance(dim, int) else dim
+ dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim
for d in dim_:
if d in dims:
raise RuntimeError(f"dim={d} appears multiple times in the list of dims")