[dynamo] support FakeTensor for SYM_INT/SYM_INT_LIST/INT_LIST param in python-to-cpp argument parsing (#103448)
Before this PR, compiling a function whose signature takes a symint/symintlist/intlist parameter failed with a runtime error like ```argument 'shifts' must be tuple of ints, not FakeTensor```. See the newly added unit tests in test/dynamo/test_misc.py for repros.
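For example, this condensed sketch of the new Jitter unit test (mirroring test/dynamo/test_misc.py below) reproduces the failure:

```python
import numpy as np
import torch

def jitter_roll(x, jitter_val):
    # the numpy scalar shifts become FakeTensors while dynamo traces the graph
    h_shift = np.int_(jitter_val - 1)
    w_shift = np.int_(jitter_val + 1)
    return torch.roll(torch.roll(x, shifts=h_shift, dims=2), shifts=w_shift, dims=3)

x = torch.rand(4, 4, 4, 4)
opt_fn = torch.compile(jitter_roll, backend="eager")
opt_fn(x, 4)  # before this PR: argument 'shifts' must be tuple of ints, not FakeTensor
```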
After this PR, for a FakeTensor with empty sizes and numel() == 1, we try to convert it into a symint/symintlist. During the conversion we will likely see the expected exception ```torch._subclasses.fake_tensor.DataDependentOutputException / aten._local_scalar_dense.default```.
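That exception is FakeTensor's normal behavior for data-dependent ops; a minimal sketch (not part of this PR) of how it arises:

```python
import torch
from torch._subclasses.fake_tensor import (
    DataDependentOutputException,
    FakeTensorMode,
)

# .item() dispatches to aten._local_scalar_dense, whose result depends on the
# tensor's data, so a FakeTensorMode without a ShapeEnv refuses to run it.
with FakeTensorMode():
    t = torch.empty((), dtype=torch.int64)
    try:
        t.item()
    except DataDependentOutputException as e:
        print("expected:", e)  # aten._local_scalar_dense.default
```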
Reference PRs:
* handle FakeTensor for symintlist as the 1st vararg: https://github.com/pytorch/pytorch/pull/97508
* handle FakeTensor for intlist in a similar way: https://github.com/pytorch/pytorch/pull/85759/files
* call local_scalar_dense on a FakeTensor: https://github.com/pytorch/pytorch/commit/f7365eca901b0ed5c5edc0d2ae92834b7b75c0d2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103448
Approved by: https://github.com/yanboliang
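The C++ checks in the diff below all share one eligibility predicate; paraphrased in Python (illustrative only, the function name is made up), a tensor qualifies when:

```python
import torch

def qualifies_as_symint(t: torch.Tensor) -> bool:
    # mirrors the checks added in python_arg_parser.cpp: a 0-dim,
    # single-element tensor with an integral (or bool) dtype
    return (
        t.numel() == 1
        and t.dim() == 0  # sizes().empty()
        and not t.dtype.is_floating_point
        and not t.dtype.is_complex
    )
```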
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 094358f..6297dba 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -5648,6 +5648,70 @@
opt_out = opt_model(x)
self.assertTrue(same(orig_out, opt_out))
+ def test_scalar_tensor_is_equivalent_to_symint_argument(self):
+ class GumbelTopKSampler(torch.nn.Module):
+ def __init__(self, T, k):
+ super(GumbelTopKSampler, self).__init__()
+ self.T = torch.nn.Parameter(
+ torch.tensor(T, dtype=torch.float32), requires_grad=False
+ )
+ self.k = torch.nn.Parameter(
+ torch.tensor(k, dtype=torch.int32), requires_grad=False
+ )
+
+ def sample_discrete(self, logits):
+ threshold = torch.topk(logits, self.k, sorted=True)[0][..., -1]
+ samples = torch.ge(logits.squeeze(1), threshold).float()
+ return samples
+
+ def forward(self, logits):
+ dsamples = self.sample_discrete(logits)
+ return dsamples
+
+ x = torch.rand([4, 4, 4, 4])
+ m = GumbelTopKSampler(T=4, k=4)
+ orig_out = m(x)
+ opt_m = torch.compile(backend="eager")(m)
+ opt_out = opt_m(x)
+ self.assertTrue(same(orig_out, opt_out))
+
+ def test_scalar_tensor_is_equivalent_to_symint_list_argument(self):
+ class Jitter(torch.nn.Module):
+ def __init__(self, jitter_val):
+ super(Jitter, self).__init__()
+ self.jitter_val = jitter_val
+
+ def roll_tensor(self, input):
+ h_shift = np.int_(self.jitter_val - 1)
+ w_shift = np.int_(self.jitter_val + 1)
+ return torch.roll(
+ torch.roll(input, shifts=h_shift, dims=2), shifts=w_shift, dims=3
+ )
+
+ def forward(self, input):
+ return self.roll_tensor(input)
+
+ x = torch.rand([4, 4, 4, 4])
+ m = Jitter(jitter_val=4)
+ orig_out = m(x)
+ opt_m = torch.compile(backend="eager")(m)
+ opt_out = opt_m(x)
+ self.assertTrue(same(orig_out, opt_out))
+
+ def test_scalar_tensor_is_equivalent_to_int_list_argument(self):
+ class MyModel(torch.nn.Module):
+ def forward(self, input):
+ permute = torch.tensor([0, 2, 1])
+ x = input.permute(*permute)
+ return x
+
+ x = torch.randn(2, 3, 4)
+ m = MyModel()
+ orig_out = m(x)
+ opt_m = torch.compile(backend="eager")(m)
+ opt_out = opt_m(x)
+ self.assertTrue(same(orig_out, opt_out))
+
def test_torch_variable_hasattr(self):
def fn(x):
if hasattr(torch.nn, "Module"):
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index 1a42c5b..6a15cf7 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -704,6 +704,20 @@
return true;
}
+ // in dynamo, FakeTensor is qualified for INT_LIST
+ if (is_dynamo_compiling && THPVariable_Check(item.ptr())) {
+ auto& var = THPVariable_Unpack(item.ptr());
+ if (var.numel() != 1 || !var.sizes().empty() ||
+ !at::isIntegralType(
+ var.dtype().toScalarType(), /*include_bool*/ true)) {
+ if (failed_idx != nullptr) {
+ *failed_idx = 0;
+ }
+ return false;
+ }
+ return true;
+ }
+
// NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
// in an intlist argument. Even float or complex scalar tensors.
bool r =
@@ -724,8 +738,23 @@
// which may have side effects if obj is a symint node
// so we do `is_symint` check first
// TODO: maybe we should be using checkLong here?
+ if (torch::is_symint(py::handle(obj))) {
+ return true;
+ }
- return torch::is_symint(py::handle(obj)) || THPUtils_checkIndex(obj);
+ if (THPUtils_checkIndex(obj)) {
+ return true;
+ }
+
+ // FakeTensor(..., size=()) is qualified for SymInt param
+ if (is_dynamo_compiling && THPVariable_Check(obj)) {
+ auto& var = THPVariable_Unpack(obj);
+ if (var.numel() == 1 && var.sizes().empty() &&
+ at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
+ return true;
+ }
+ }
+ return false;
}
static bool is_int_or_symint_list(
@@ -742,12 +771,6 @@
return true;
}
- // NOTE: In dynamo, allow fake tensor as int
- if (is_dynamo_compiling && THPVariable_Check(item.ptr()) &&
- THPVariable_Unpack(item.ptr()).sizes().empty()) {
- return true;
- }
-
// NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
// in an intlist argument. Even float or complex scalar tensors.
bool r =
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 8cead32..42788f4 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -536,6 +536,16 @@
return std::vector<c10::SymInt>(size1, si);
}
+ if (is_dynamo_compiling && size1 > 0 && THPVariable_Check(args[i])) {
+ auto& var = THPVariable_Unpack(args[i]);
+ if (size1 == 1 && var.numel() == 1 && var.sizes().empty() &&
+ at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
+ auto scalar = var.item();
+ TORCH_CHECK(scalar.isIntegral(/*include bool*/ false));
+ return std::vector<c10::SymInt>(size1, scalar.toSymInt());
+ }
+ }
+
PyObject* arg = args[i];
auto tuple = PyTuple_Check(arg);
// NOLINTNEXTLINE(bugprone-branch-clone)
@@ -923,6 +933,7 @@
}
inline c10::SymInt PythonArgs::toSymInt(int i) {
+ PyObject* obj = args[i];
if (!args[i]) {
return c10::SymInt(signature.params[i].default_int);
}
@@ -933,6 +944,28 @@
signature.params[i].name, idx, var, c10::IntType::get());
}
+ // convert FakeTensor to SymInt
+ // expect empty sizes, numel = 1
+ // and ScalarType::Int
+ if (is_dynamo_compiling && THPVariable_Check(obj)) {
+ auto& var = THPVariable_Unpack(obj);
+
+ if (var.numel() != 1 || !var.sizes().empty() ||
+ !at::isIntegralType(
+ var.dtype().toScalarType(), /*include_bool*/ true)) {
+ throw TypeError(
+ "%s(): argument '%s' must be %s, failed to convert %s with sizes.empty()=%d",
+ signature.name.c_str(),
+ signature.params[i].name.c_str(),
+ signature.params[i].type_name().c_str(),
+ Py_TYPE(obj)->tp_name,
+ var.sizes().empty());
+ }
+ auto scalar = var.item();
+ TORCH_CHECK(scalar.isIntegral(/*include bool*/ false));
+ return scalar.toSymInt();
+ }
+
return py::cast<c10::SymInt>(py::handle(args[i]));
}