[BE]: Use `itertools.chain.from_iterable` where possible (#116376)
`itertools.chain.from_iterable(seqs)` is more readable and more efficient than `itertools.chain(*seqs)` when chaining many sequences, since it consumes the outer iterable lazily instead of first unpacking every sequence into call arguments.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116376
Approved by: https://github.com/albanD
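For illustration, a minimal standalone sketch of the two idioms (not part of the diff below; the variable names are made up for the example):

    import itertools

    seqs = [[1, 2], [3, 4], [5, 6]]

    # Before: unpacks `seqs` into positional arguments, which is eager and
    # less readable when the collection of sequences is large or lazy.
    flat_old = list(itertools.chain(*seqs))

    # After: takes the iterable of sequences directly and consumes it
    # lazily, one sequence at a time.
    flat_new = list(itertools.chain.from_iterable(seqs))

    assert flat_old == flat_new == [1, 2, 3, 4, 5, 6]

    # The same idiom also works with a generator that `chain(*...)` would
    # have to exhaust up front just to build the argument list.
    lazy = (range(i) for i in range(3))
    assert list(itertools.chain.from_iterable(lazy)) == [0, 0, 1]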
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 8d60d66..e20d11f 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -3907,7 +3907,7 @@
self._check_einsum(equation, *operands, np_args=(equation, *np_operands))
# test sublist format
- args = [*itertools.chain(*zip(operands, sublists))]
+ args = list(itertools.chain.from_iterable(zip(operands, sublists)))
self._check_einsum(*args, np_args=(equation, *np_operands))
# generate an explicit output
diff --git a/test/torch_np/numpy_tests/core/test_numeric.py b/test/torch_np/numpy_tests/core/test_numeric.py
index 131ef7b..e94c3f2 100644
--- a/test/torch_np/numpy_tests/core/test_numeric.py
+++ b/test/torch_np/numpy_tests/core/test_numeric.py
@@ -2132,7 +2132,7 @@
# Test ones, zeros, empty and full.
def setUp(self):
- # dtypes = {np.dtype(tp) for tp in itertools.chain(*np.sctypes.values())}
+ # dtypes = {np.dtype(tp) for tp in itertools.chain.from_iterable(np.sctypes.values())}
dtypes = {np.dtype(tp) for tp in "efdFDBbhil?"}
self.dtypes = dtypes
self.orders = {
diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py
index 0eca115..854a6b1 100644
--- a/torch/_dynamo/variables/iter.py
+++ b/torch/_dynamo/variables/iter.py
@@ -45,9 +45,7 @@
and all(arg.has_unpack_var_sequence(tx) for arg in args)
):
seqs = [arg.unpack_var_sequence(tx) for arg in args]
- items = []
- for item in itertools.chain(*seqs):
- items.append(item)
+ items = list(itertools.chain.from_iterable(seqs))
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
elif self.value is itertools.accumulate:
from .builtin import BuiltinVariable
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 1adede8..8f67877 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1193,7 +1193,7 @@
new_ranges, return_getters_groups = self._split_iteration_ranges(
groups, lengths
)
- itervars = list(itertools.chain(*self.set_ranges(*new_ranges)))
+ itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges)))
return [[fn(itervars) for fn in fns] for fns in return_getters_groups]
def is_indirect_indexing(self, index: sympy.Expr):
@@ -3033,11 +3033,11 @@
dep_sources = [rw.reads, rw.writes]
assert all(
isinstance(dep, (MemoryDep, StarDep))
- for dep in itertools.chain(*dep_sources)
+ for dep in itertools.chain.from_iterable(dep_sources)
)
deps = [
dep
- for dep in itertools.chain(*dep_sources)
+ for dep in itertools.chain.from_iterable(dep_sources)
if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep)
]
write_names = {dep.name for dep in rw.writes}
diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py
index 5a5982c..7f2e9b8 100644
--- a/torch/_inductor/dependencies.py
+++ b/torch/_inductor/dependencies.py
@@ -363,7 +363,7 @@
if normalize:
range_vars = [] # Number of vars could differ due to normalization
else:
- range_vars = [*itertools.chain(*args)]
+ range_vars = list(itertools.chain.from_iterable(args))
inner = rw.parent_handler.parent_handler
return ReadWrites(
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 10e9b60..893aa67 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -381,7 +381,8 @@
assert new_size is not None
dtype = functools.reduce(
- torch.promote_types, [x.get_dtype() for x in itertools.chain(*inputs)]
+ torch.promote_types,
+ [x.get_dtype() for x in itertools.chain.from_iterable(inputs)],
)
device = inputs[0][0].get_device()
kernel = ir.ConcatKernel(
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 59667d5..ee50839 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -6546,7 +6546,7 @@
return self.indexing[name]
def __call__(self, *indices):
- index = list(itertools.chain(*indices))
+ index = list(itertools.chain.from_iterable(indices))
assert len(index) == len(self.var_ranges), (index, self.var_ranges)
assert all(v not in self.var_ranges for v in index)
replacements = dict(zip(self.var_ranges.keys(), index))
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 0a75147..82263ff 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -1116,7 +1116,7 @@
def get_nodes(self):
"""Returns all nodes contained in this kernel, unpacking fused nodes into their constituent scheduler nodes."""
- return list(itertools.chain(*[x.get_nodes() for x in self.snodes]))
+ return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes))
def get_first_name(self):
return self.snodes[0].get_first_name()
diff --git a/torch/_numpy/_funcs_impl.py b/torch/_numpy/_funcs_impl.py
index 76a7fe7..2d9b61e 100644
--- a/torch/_numpy/_funcs_impl.py
+++ b/torch/_numpy/_funcs_impl.py
@@ -1353,7 +1353,7 @@
has_sublistout = len(operands) % 2 == 1
if has_sublistout:
sublistout = operands[-1]
- operands = list(itertools.chain(*zip(tensors, sublists)))
+ operands = list(itertools.chain.from_iterable(zip(tensors, sublists)))
if has_sublistout:
operands.append(sublistout)
diff --git a/torch/ao/quantization/pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py
index 2390baf..bacb4d8 100644
--- a/torch/ao/quantization/pt2e/graph_utils.py
+++ b/torch/ao/quantization/pt2e/graph_utils.py
@@ -98,7 +98,7 @@
for partition_type in partition_types:
types_to_match = _get_matching_types(partition_type)
partitions = get_source_partitions(gm.graph, types_to_match, filter_fn)
- typed_partitions[partition_type] = list(itertools.chain(*partitions.values()))
+ typed_partitions[partition_type] = list(itertools.chain.from_iterable(partitions.values()))
typed_partitions_list = list(typed_partitions.values())
fusion_candidates = itertools.product(*typed_partitions_list)
diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
index da8fd97..4f0aa7f 100644
--- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
+++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -760,7 +760,7 @@
conv_partitions = get_source_partitions(
gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d]
)
- conv_partitions = list(itertools.chain(*conv_partitions.values()))
+ conv_partitions = list(itertools.chain.from_iterable(conv_partitions.values()))
for conv_partition in conv_partitions:
if len(conv_partition.output_nodes) > 1:
raise ValueError("conv partition has more than one output node")
@@ -935,7 +935,9 @@
linear_partitions = get_source_partitions(
gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
)
- linear_partitions = list(itertools.chain(*linear_partitions.values()))
+ linear_partitions = list(
+ itertools.chain.from_iterable(linear_partitions.values())
+ )
for partition in linear_partitions:
if len(partition.output_nodes) > 1:
raise ValueError(
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
index 8ad9505..9763cb4 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -471,7 +471,7 @@
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn)
- gru_partitions = list(itertools.chain(*gru_partitions.values()))
+ gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values()))
annotated_partitions = []
for gru_partition in gru_partitions:
annotated_partitions.append(gru_partition.nodes)
@@ -525,7 +525,7 @@
module_partitions = get_source_partitions(
gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn
)
- maxpool_partitions = list(itertools.chain(*module_partitions.values()))
+ maxpool_partitions = list(itertools.chain.from_iterable(module_partitions.values()))
annotated_partitions = []
for maxpool_partition in maxpool_partitions:
annotated_partitions.append(maxpool_partition.nodes)
@@ -577,7 +577,7 @@
module_partitions = get_source_partitions(
gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
)
- partitions = list(itertools.chain(*module_partitions.values()))
+ partitions = list(itertools.chain.from_iterable(module_partitions.values()))
annotated_partitions = []
for partition in partitions:
pool_node = partition.output_nodes[0]
@@ -701,7 +701,7 @@
add_partitions = get_source_partitions(
gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
)
- add_partitions = list(itertools.chain(*add_partitions.values()))
+ add_partitions = list(itertools.chain.from_iterable(add_partitions.values()))
annotated_partitions = []
for add_partition in add_partitions:
annotated_partitions.append(add_partition.nodes)
@@ -800,7 +800,7 @@
mul_partitions = get_source_partitions(
gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
)
- mul_partitions = list(itertools.chain(*mul_partitions.values()))
+ mul_partitions = list(itertools.chain.from_iterable(mul_partitions.values()))
annotated_partitions = []
for mul_partition in mul_partitions:
annotated_partitions.append(mul_partition.nodes)
@@ -844,7 +844,7 @@
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
- cat_partitions = list(itertools.chain(*cat_partitions.values()))
+ cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values()))
annotated_partitions = []
for cat_partition in cat_partitions:
cat_node = cat_partition.output_nodes[0]
diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py
index e389653..32700ff 100644
--- a/torch/autograd/profiler_legacy.py
+++ b/torch/autograd/profiler_legacy.py
@@ -184,7 +184,7 @@
record_stack = []
# '__start_profile' is not guaranteed to be first, so we must find it here
- for record in itertools.chain(*thread_records):
+ for record in itertools.chain.from_iterable(thread_records):
name = record.name()
if start_record is None and name == "__start_profile":
start_record = record
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index a5fd1c9..b61a453 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -868,7 +868,7 @@
def _ensure_all_tensors_same_dtype(*tensors) -> None:
last_dtype = None
- for tensor in itertools.chain(*map(_as_iterable, tensors)):
+ for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
tensor_dtype = tensor.dtype
# Mixing complex and its element type is allowed
if tensor_dtype.is_complex:
diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py
index ad26492..dc3f4c1 100644
--- a/torch/distributed/rpc/server_process_global_profiler.py
+++ b/torch/distributed/rpc/server_process_global_profiler.py
@@ -163,7 +163,7 @@
process_global_function_events.append(thread_local_function_events)
flattened_function_events = list(
- itertools.chain(*process_global_function_events)
+ itertools.chain.from_iterable(process_global_function_events)
)
self.function_events = torch.autograd.profiler_util.EventList(
flattened_function_events,
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
index 153a840..d866117 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
@@ -1017,7 +1017,7 @@
"""
dims, counter = gen_lists_of_dims(4, i, counter)
[d1, d2, d3, d4] = dims
- nat_dims_i = gen_nat_constraints(list(itertools.chain(*dims)))
+ nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims)))
initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12,
d1, d2, d3, d4)
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 1ede684..1b14fa3 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -150,12 +150,12 @@
if isinstance(obj, torch.Tensor):
return [obj]
if isinstance(obj, (list, tuple)):
- return itertools.chain(*map(_find_tensors, obj))
+ return itertools.chain.from_iterable(map(_find_tensors, obj))
if isinstance(obj, dict):
- return itertools.chain(*map(_find_tensors, obj.values()))
+ return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
if is_dataclass(obj):
- return itertools.chain(
- *map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
+ return itertools.chain.from_iterable(
+ map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
)
return []