make_fx can now SymIntify int inputs (#113452)

This PR also contains a basket of fixes that were turned up now that more arguments are tested with SymInt. I fixed as many of the easy ones as I could, some earlier in this stack and a bunch here, but some of the more annoying ones are xfailed.
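
For illustration, a minimal sketch of the new behavior (it mirrors the test added below; names and the printed output are illustrative, not part of the change):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x, y):
        return x.view(y)

    # With tracing_mode="symbolic", the plain int 12 is now wrapped as a SymInt
    # backed by a fresh symbol, so it shows up as a graph placeholder (y_1)
    # instead of being burned into the trace as a constant.
    gm = make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12)
    print(gm.code)  # view = torch.ops.aten.view.default(x_1, [y_1])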

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113452
Approved by: https://github.com/Chillee
ghstack dependencies: #113877, #113911
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index b152175..5c9b981 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -923,6 +923,16 @@
             lambda: interp.run(torch.randn(3, 3))
         )
 
+    def test_int_input(self):
+        def f(x, y):
+            return x.view(y)
+
+        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, x_1, y_1):
+    view = torch.ops.aten.view.default(x_1, [y_1]);  x_1 = y_1 = None
+    return view""")
+
     def test_resize_from_zero(self):
         def f(x, y):
             x.resize_(y.size(0))
@@ -1705,11 +1715,20 @@
     xfail('nn.functional.interpolate', 'trilinear'),  # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
     xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
-    xfail('resize_', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('resize_as_', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
     xfail('unique', ''),  # aten._unique2.default - couldn't find symbolic meta function/decomposition
 
+    # AssertionError: False != True - https://github.com/pytorch/pytorch/issues/113905
+    xfail('dist', ''),
+    xfail('norm', ''),
+    xfail('linalg.vector_norm', ''),
+    xfail('linalg.norm', 'subgradients_at_zero'),
+    xfail('renorm', ''),
+
+    xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...
+    xfail('randint_like', ''),  # when unpacking SymInt, expected int but got s0
+
     # many complex operators incorrect striding, metadata
     xfail('fft.fft', ''),
     xfail('fft.hfft2', ''),
@@ -1736,6 +1755,11 @@
 
 outplace_symbolic_tensor_failures = {
     xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
+
+    xfail('linalg.norm', ''),
+    xfail('round', 'decimals_0'),  # Cannot call numel() on tensor with symbolic sizes/strides
+    xfail('round', 'decimals_3'),  # Cannot call numel() on tensor with symbolic sizes/strides
+    xfail('round', 'decimals_neg_3'),  # Cannot call numel() on tensor with symbolic sizes/strides
 }
 
 inplace_symbolic_tensor_failures = {
@@ -1799,6 +1823,11 @@
     xfail('topk', ''),
     xfail('triangular_solve', ''),
     xfail('view_copy', ''),
+
+    # SymIntArrayRef expected to contain only concrete
+    xfail('ones', ''),
+    xfail('randn', ''),
+    xfail('zeros', ''),
 }
 
 out_symbolic_tensor_segfaults = {
diff --git a/torch/__init__.py b/torch/__init__.py
index 92d5425..fe699ad 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -287,6 +287,9 @@
     def __sym_float__(self):
         raise AssertionError("type stub not overridden")
 
+    def __neg__(self):
+        raise AssertionError("type stub not overridden")
+
     def __repr__(self):
         return str(self.node)
 
diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp
index c2e2509..1565d73 100644
--- a/torch/csrc/jit/python/pybind_utils.cpp
+++ b/torch/csrc/jit/python/pybind_utils.cpp
@@ -111,6 +111,9 @@
     case TypeKind::StorageType:
       return py::cast<at::Storage>(obj);
     case TypeKind::FloatType:
+      if (torch::is_symfloat(py::handle(obj))) {
+        return py::cast<c10::SymFloat>(obj).guard_float(__FILE__, __LINE__);
+      }
       return py::cast<double>(obj);
     case TypeKind::ComplexType: {
       auto c_obj = py::cast<std::complex<double>>(obj.ptr());
@@ -139,6 +142,9 @@
         auto memory_format = reinterpret_cast<THPMemoryFormat*>(obj.ptr());
         return static_cast<int8_t>(memory_format->memory_format);
       }
+      if (torch::is_symint(py::handle(obj))) {
+        return py::cast<c10::SymInt>(obj).guard_int(__FILE__, __LINE__);
+      }
       return py::cast<int64_t>(obj);
     case TypeKind::LayoutType: {
       if (THPLayout_Check(obj.ptr())) {
@@ -186,6 +192,9 @@
       }
       return {};
     case TypeKind::BoolType:
+      if (torch::is_symbool(obj.ptr())) {
+        return py::cast<c10::SymBool>(obj).guard_bool(__FILE__, __LINE__);
+      }
       return py::cast<bool>(obj);
     case TypeKind::TupleType: {
       py::tuple tuple = py::cast<py::tuple>(obj);
diff --git a/torch/functional.py b/torch/functional.py
index 19117c7..a6c1241 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1171,7 +1171,7 @@
     if has_torch_function_variadic(a, b):
         return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
 
-    if not isinstance(dims, (tuple, list, torch.Tensor, int)):
+    if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
         raise RuntimeError("tensordot expects dims to be int or "
                            + "Tuple[List[int], List[int]] or "
                            + "List[List[int]] containing two lists, but got "
@@ -1196,7 +1196,7 @@
             dims_a = list(range(-dims_val, 0))
             dims_b = list(range(dims_val))
 
-    if isinstance(dims, int):
+    if isinstance(dims, (int, torch.SymInt)):
         if dims < 0:
             raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
         if dims > min(a.dim(), b.dim()):
@@ -1597,7 +1597,7 @@
     if input.layout == torch.strided and input.device.type in \
             ("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
         if dim is not None:
-            if isinstance(dim, int):
+            if isinstance(dim, (int, torch.SymInt)):
                 _dim = [dim]
             else:
                 _dim = dim
@@ -1605,7 +1605,7 @@
             _dim = None  # type: ignore[assignment]
 
         if isinstance(p, str):
-            if p == "fro" and (dim is None or isinstance(dim, int) or len(dim) <= 2):
+            if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
                 if out is None:
                     return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
                 else:
@@ -1642,7 +1642,7 @@
     # remove the overloads where dim is an int and replace with BraodcastingList1
     # and remove next four lines, replace _dim with dim
     if dim is not None:
-        if isinstance(dim, int):
+        if isinstance(dim, (int, torch.SymInt)):
             _dim = [dim]
         else:
             _dim = dim
@@ -1657,22 +1657,22 @@
             if _dim is None:
                 _dim = list(range(ndim))
             if out is None:
-                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)
+                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
             else:
-                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)
+                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
         elif p == "nuc":
             if dtype is not None:
                 raise ValueError("dtype argument is not supported in nuclear norm")
             if _dim is None:
                 if out is None:
-                    return _VF.nuclear_norm(input, keepdim=keepdim)
+                    return _VF.nuclear_norm(input, keepdim=keepdim)  # type: ignore[arg-type]
                 else:
-                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)
+                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)  # type: ignore[arg-type]
             else:
                 if out is None:
-                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)
+                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
                 else:
-                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)
+                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
         raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
     else:
         if _dim is None:
@@ -1750,15 +1750,15 @@
         lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")
 
     torch._check_type(
-        isinstance(shape, (int, Sequence)),
+        isinstance(shape, (int, torch.SymInt, Sequence)),
         lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")
 
-    if isinstance(shape, int):
+    if isinstance(shape, (int, torch.SymInt)):
         shape = torch.Size([shape])
     else:
         for dim in shape:
             torch._check_type(
-                isinstance(dim, int),
+                isinstance(dim, (int, torch.SymInt)),
                 lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
         shape = torch.Size(shape)
 
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 19d7089..04219b3 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -824,15 +824,18 @@
 
         def wrap_fake(x):
             nonlocal arg_count
+            # TODO: it would be nice to line these up with the names
+            # FX will choose for the placeholders, but we don't
+            # actually know what the names will be at this point yet
+            # NB: the Source here is actually meaningless
+            from torch._dynamo.source import ConstantSource
+            source = ConstantSource(f"input{arg_count}")
             if isinstance(x, torch.Tensor):
-                # TODO: it would be nice to line these up with the names
-                # FX will choose for the placeholders, but we don't
-                # actually know what the names will be at this point yet
-                # NB: the Source here is actually meaningless
-                from torch._dynamo.source import ConstantSource
-                source = ConstantSource(f"input{arg_count}")
                 arg_count += 1
                 return fake_tensor_mode.from_tensor(x, source=source)  # type: ignore[attr-defined]
+            # NB: don't match on bools
+            elif type(x) is int and tracing_mode == "symbolic":
+                return shape_env.create_symintnode(shape_env.create_symbol(x, source, positive=None), hint=x, source=source)
 
             return x
 
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index b1d89fc..fffb650 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -465,7 +465,7 @@
         return FloatType.get()
     if ann is complex:
         return ComplexType.get()
-    if ann is int:
+    if ann is int or ann is torch.SymInt:
         return IntType.get()
     if ann is str:
         return StringType.get()
diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py
index 93a5b5e..083af67 100644
--- a/torch/masked/_ops.py
+++ b/torch/masked/_ops.py
@@ -463,7 +463,7 @@
     if dim is None:
         return tuple(range(ndim))
     ndim = max(ndim, 1)
-    dim_ = (dim,) if isinstance(dim, int) else dim
+    dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim
     for d in dim_:
         if d in dims:
             raise RuntimeError(f"dim={d} appears multiple times in the list of dims")