[dynamo] support FakeTensor for SYM_INT/SYM_INT_LIST/INT_LIST param in python-to-cpp argument parsing (#103448)

Before this PR, when compiling a function whose signature takes a symint/symintlist/intlist parameter, we hit a runtime error like ```argument 'shifts' must be tuple of ints, not FakeTensor```. See the newly added unit tests in test/dynamo/test_misc.py for a repro.
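
A minimal sketch of the failure mode (hypothetical example; the checked-in unit tests below use np.int_ shifts and a tensor permute list instead):

```python
import torch

def f(x, shift):
    # under torch.compile, `shift` is traced as a FakeTensor; before this
    # PR the arg parser rejected it for the `shifts` int-list parameter
    return torch.roll(x, shifts=shift, dims=0)

x = torch.rand(4, 4)
shift = torch.tensor(1)  # 0-dim integral scalar tensor
f(x, shift)  # eager accepts a 0-dim tensor here
# before this PR the compiled call failed with:
#   argument 'shifts' must be tuple of ints, not FakeTensor
torch.compile(f, backend="eager")(x, shift)
```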

After this PR, for a FakeTensor with empty sizes and numel() == 1, we try to convert it into a symint/symintlist. During the conversion we will likely see the expected exception ```torch._subclasses.fake_tensor.DataDependentOutputException``` (raised from ```aten._local_scalar_dense.default```).
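
For illustration, a minimal sketch of where that exception comes from, assuming FakeTensorMode's default settings (no shape_env): reading the value of a 0-dim FakeTensor dispatches to aten._local_scalar_dense.default, whose output is data-dependent:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, DataDependentOutputException

mode = FakeTensorMode()
# a 0-dim integral fake tensor: sizes() == (), numel() == 1
fake = mode.from_tensor(torch.tensor(4, dtype=torch.int32))
with mode:
    try:
        fake.item()  # dispatches to aten._local_scalar_dense.default
    except DataDependentOutputException as e:
        print("expected:", e)
```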

Reference PRs:
* handle FakeTensor for symintlist as the first vararg: https://github.com/pytorch/pytorch/pull/97508
* handle FakeTensor for intlist in a similar way: https://github.com/pytorch/pytorch/pull/85759/files
* call local_scalar_dense on a FakeTensor: https://github.com/pytorch/pytorch/commit/f7365eca901b0ed5c5edc0d2ae92834b7b75c0d2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103448
Approved by: https://github.com/yanboliang
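
For readers skimming the C++ below, a Python paraphrase of the acceptance predicate this PR adds (a hypothetical helper, not part of the PR):

```python
import torch

INTEGRAL_DTYPES = {
    torch.bool, torch.uint8, torch.int8,
    torch.int16, torch.int32, torch.int64,
}

def acceptable_as_symint(var: torch.Tensor) -> bool:
    # mirrors the C++ check: numel() == 1, sizes().empty(),
    # and isIntegralType(..., include_bool=True)
    return var.numel() == 1 and var.dim() == 0 and var.dtype in INTEGRAL_DTYPES
```
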
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 094358f..6297dba 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -5648,6 +5648,70 @@
         opt_out = opt_model(x)
         self.assertTrue(same(orig_out, opt_out))
 
+    def test_scalar_tensor_is_equivalent_to_symint_argument(self):
+        class GumbelTopKSampler(torch.nn.Module):
+            def __init__(self, T, k):
+                super(GumbelTopKSampler, self).__init__()
+                self.T = torch.nn.Parameter(
+                    torch.tensor(T, dtype=torch.float32), requires_grad=False
+                )
+                self.k = torch.nn.Parameter(
+                    torch.tensor(k, dtype=torch.int32), requires_grad=False
+                )
+
+            def sample_discrete(self, logits):
+                threshold = torch.topk(logits, self.k, sorted=True)[0][..., -1]
+                samples = torch.ge(logits.squeeze(1), threshold).float()
+                return samples
+
+            def forward(self, logits):
+                dsamples = self.sample_discrete(logits)
+                return dsamples
+
+        x = torch.rand([4, 4, 4, 4])
+        m = GumbelTopKSampler(T=4, k=4)
+        orig_out = m(x)
+        opt_m = torch.compile(backend="eager")(m)
+        opt_out = opt_m(x)
+        self.assertTrue(same(orig_out, opt_out))
+
+    def test_scalar_tensor_is_equivalent_to_symint_list_argument(self):
+        class Jitter(torch.nn.Module):
+            def __init__(self, jitter_val):
+                super(Jitter, self).__init__()
+                self.jitter_val = jitter_val
+
+            def roll_tensor(self, input):
+                h_shift = np.int_(self.jitter_val - 1)
+                w_shift = np.int_(self.jitter_val + 1)
+                return torch.roll(
+                    torch.roll(input, shifts=h_shift, dims=2), shifts=w_shift, dims=3
+                )
+
+            def forward(self, input):
+                return self.roll_tensor(input)
+
+        x = torch.rand([4, 4, 4, 4])
+        m = Jitter(jitter_val=4)
+        orig_out = m(x)
+        opt_m = torch.compile(backend="eager")(m)
+        opt_out = opt_m(x)
+        self.assertTrue(same(orig_out, opt_out))
+
+    def test_scalar_tensor_is_equivalent_to_int_list_argument(self):
+        class MyModel(torch.nn.Module):
+            def forward(self, input):
+                permute = torch.tensor([0, 2, 1])
+                x = input.permute(*permute)
+                return x
+
+        x = torch.randn(2, 3, 4)
+        m = MyModel()
+        orig_out = m(x)
+        opt_m = torch.compile(backend="eager")(m)
+        opt_out = opt_m(x)
+        self.assertTrue(same(orig_out, opt_out))
+
     def test_torch_variable_hasattr(self):
         def fn(x):
             if hasattr(torch.nn, "Module"):
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index 1a42c5b..6a15cf7 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -704,6 +704,20 @@
       return true;
     }
 
+    // in dynamo, FakeTensor is qualified for INT_LIST
+    if (is_dynamo_compiling && THPVariable_Check(item.ptr())) {
+      auto& var = THPVariable_Unpack(item.ptr());
+      if (var.numel() != 1 || !var.sizes().empty() ||
+          !at::isIntegralType(
+              var.dtype().toScalarType(), /*include_bool*/ true)) {
+        if (failed_idx != nullptr) {
+          *failed_idx = 0;
+        }
+        return false;
+      }
+      return true;
+    }
+
     // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
     // in an intlist argument. Even float or complex scalar tensors.
     bool r =
@@ -724,8 +738,23 @@
   // which may have side effects if obj is a symint node
   // so we do `is_symint` check first
   // TODO: maybe we should be using checkLong here?
+  if (torch::is_symint(py::handle(obj))) {
+    return true;
+  }
 
-  return torch::is_symint(py::handle(obj)) || THPUtils_checkIndex(obj);
+  if (THPUtils_checkIndex(obj)) {
+    return true;
+  }
+
+  // FakeTensor(..., size=()) is qualified for SymInt param
+  if (is_dynamo_compiling && THPVariable_Check(obj)) {
+    auto& var = THPVariable_Unpack(obj);
+    if (var.numel() == 1 && var.sizes().empty() &&
+        at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
+      return true;
+    }
+  }
+  return false;
 }
 
 static bool is_int_or_symint_list(
@@ -742,12 +771,6 @@
       return true;
     }
 
-    // NOTE: In dynamo, allow fake tensor as int
-    if (is_dynamo_compiling && THPVariable_Check(item.ptr()) &&
-        THPVariable_Unpack(item.ptr()).sizes().empty()) {
-      return true;
-    }
-
     // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
     // in an intlist argument. Even float or complex scalar tensors.
     bool r =
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 8cead32..42788f4 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -536,6 +536,16 @@
     return std::vector<c10::SymInt>(size1, si);
   }
 
+  if (is_dynamo_compiling && size1 > 0 && THPVariable_Check(args[i])) {
+    auto& var = THPVariable_Unpack(args[i]);
+    if (size1 == 1 && var.numel() == 1 && var.sizes().empty() &&
+        at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
+      auto scalar = var.item();
+      TORCH_CHECK(scalar.isIntegral(/*include bool*/ false));
+      return std::vector<c10::SymInt>(size1, scalar.toSymInt());
+    }
+  }
+
   PyObject* arg = args[i];
   auto tuple = PyTuple_Check(arg);
   // NOLINTNEXTLINE(bugprone-branch-clone)
@@ -923,6 +933,7 @@
 }
 
 inline c10::SymInt PythonArgs::toSymInt(int i) {
+  PyObject* obj = args[i];
   if (!args[i]) {
     return c10::SymInt(signature.params[i].default_int);
   }
@@ -933,6 +944,28 @@
         signature.params[i].name, idx, var, c10::IntType::get());
   }
 
+  // convert FakeTensor to SymInt
+  // expect empty sizes, numel = 1
+  // and ScalarType::Int
+  if (is_dynamo_compiling && THPVariable_Check(obj)) {
+    auto& var = THPVariable_Unpack(obj);
+
+    if (var.numel() != 1 || !var.sizes().empty() ||
+        !at::isIntegralType(
+            var.dtype().toScalarType(), /*include_bool*/ true)) {
+      throw TypeError(
+          "%s(): argument '%s' must be %s, failed to convert %s with sizes.empty()=%d",
+          signature.name.c_str(),
+          signature.params[i].name.c_str(),
+          signature.params[i].type_name().c_str(),
+          Py_TYPE(obj)->tp_name,
+          var.sizes().empty());
+    }
+    auto scalar = var.item();
+    TORCH_CHECK(scalar.isIntegral(/*include bool*/ false));
+    return scalar.toSymInt();
+  }
+
   return py::cast<c10::SymInt>(py::handle(args[i]));
 }