[dynamo] add a handler for itertools_chain_from_iterable and test (#116849)
1. add a handler for itertools_chain_from_iterable
2. a test for itertools_chain_from_iterable
Fixes #116463
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116849
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 2b7d8fa..ee4ab5a 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -135,6 +135,13 @@
return v
@make_test
+ def test_itertools_chain_from_iterable(a, b):
+ v = a
+ for x in itertools.chain.from_iterable([[a, b], [1, 2]]):
+ v = v + x
+ return v
+
+ @make_test
def test_itertools_combinations(a, b):
combs = []
for size in itertools.combinations((1, 2, 3, 4), 2):
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index a544208..8cc5f21 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -676,6 +676,15 @@
) -> "VariableTracker":
if self.fn == dict and name == "fromkeys":
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
+ if self.fn == itertools.chain and name == "from_iterable":
+ assert len(args) == 1
+ assert len(kwargs) == 0
+ obj = args[0]
+ items = []
+ for item in obj.unpack_var_sequence(tx):
+ items.extend(item.unpack_var_sequence(tx))
+ return variables.TupleVariable(items)
+
return super().call_method(tx, name, args, kwargs)
def _call_min_max(self, tx, *args):