[BE]: Use `itertools.chain.from_iterable` where possible (#116376)

This is more readable and more efficient: `chain.from_iterable(seqs)` consumes the outer iterable lazily instead of unpacking it into call arguments, which matters when chaining together a large number of sequences.
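
For illustration only (not part of the diff), a minimal sketch of the two equivalent spellings being swapped throughout this change:

```python
import itertools

seqs = [[1, 2], [3, 4], [5, 6]]

# Old spelling: eagerly unpacks the outer list into call arguments.
flat_star = list(itertools.chain(*seqs))

# New spelling: consumes the outer iterable lazily, with no unpacking;
# this also works when `seqs` is itself a generator.
flat_from_iterable = list(itertools.chain.from_iterable(seqs))

assert flat_star == flat_from_iterable == [1, 2, 3, 4, 5, 6]
```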

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116376
Approved by: https://github.com/albanD
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 8d60d66..e20d11f 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -3907,7 +3907,7 @@
                 self._check_einsum(equation, *operands, np_args=(equation, *np_operands))
 
                 # test sublist format
-                args = [*itertools.chain(*zip(operands, sublists))]
+                args = list(itertools.chain.from_iterable(zip(operands, sublists)))
                 self._check_einsum(*args, np_args=(equation, *np_operands))
 
                 # generate an explicit output
diff --git a/test/torch_np/numpy_tests/core/test_numeric.py b/test/torch_np/numpy_tests/core/test_numeric.py
index 131ef7b..e94c3f2 100644
--- a/test/torch_np/numpy_tests/core/test_numeric.py
+++ b/test/torch_np/numpy_tests/core/test_numeric.py
@@ -2132,7 +2132,7 @@
     # Test ones, zeros, empty and full.
 
     def setUp(self):
-        # dtypes = {np.dtype(tp) for tp in itertools.chain(*np.sctypes.values())}
+        # dtypes = {np.dtype(tp) for tp in itertools.chain.from_iterable(np.sctypes.values())}
         dtypes = {np.dtype(tp) for tp in "efdFDBbhil?"}
         self.dtypes = dtypes
         self.orders = {
diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py
index 0eca115..854a6b1 100644
--- a/torch/_dynamo/variables/iter.py
+++ b/torch/_dynamo/variables/iter.py
@@ -45,9 +45,7 @@
             and all(arg.has_unpack_var_sequence(tx) for arg in args)
         ):
             seqs = [arg.unpack_var_sequence(tx) for arg in args]
-            items = []
-            for item in itertools.chain(*seqs):
-                items.append(item)
+            items = list(itertools.chain.from_iterable(seqs))
             return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
         elif self.value is itertools.accumulate:
             from .builtin import BuiltinVariable
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 1adede8..8f67877 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1193,7 +1193,7 @@
         new_ranges, return_getters_groups = self._split_iteration_ranges(
             groups, lengths
         )
-        itervars = list(itertools.chain(*self.set_ranges(*new_ranges)))
+        itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges)))
         return [[fn(itervars) for fn in fns] for fns in return_getters_groups]
 
     def is_indirect_indexing(self, index: sympy.Expr):
@@ -3033,11 +3033,11 @@
         dep_sources = [rw.reads, rw.writes]
         assert all(
             isinstance(dep, (MemoryDep, StarDep))
-            for dep in itertools.chain(*dep_sources)
+            for dep in itertools.chain.from_iterable(dep_sources)
         )
         deps = [
             dep
-            for dep in itertools.chain(*dep_sources)
+            for dep in itertools.chain.from_iterable(dep_sources)
             if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep)
         ]
         write_names = {dep.name for dep in rw.writes}
diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py
index 5a5982c..7f2e9b8 100644
--- a/torch/_inductor/dependencies.py
+++ b/torch/_inductor/dependencies.py
@@ -363,7 +363,7 @@
     if normalize:
         range_vars = []  # Number of vars could differ due to normalization
     else:
-        range_vars = [*itertools.chain(*args)]
+        range_vars = list(itertools.chain.from_iterable(args))
 
     inner = rw.parent_handler.parent_handler
     return ReadWrites(
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 10e9b60..893aa67 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -381,7 +381,8 @@
 
     assert new_size is not None
     dtype = functools.reduce(
-        torch.promote_types, [x.get_dtype() for x in itertools.chain(*inputs)]
+        torch.promote_types,
+        [x.get_dtype() for x in itertools.chain.from_iterable(inputs)],
     )
     device = inputs[0][0].get_device()
     kernel = ir.ConcatKernel(
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 59667d5..ee50839 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -6546,7 +6546,7 @@
         return self.indexing[name]
 
     def __call__(self, *indices):
-        index = list(itertools.chain(*indices))
+        index = list(itertools.chain.from_iterable(indices))
         assert len(index) == len(self.var_ranges), (index, self.var_ranges)
         assert all(v not in self.var_ranges for v in index)
         replacements = dict(zip(self.var_ranges.keys(), index))
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 0a75147..82263ff 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -1116,7 +1116,7 @@
 
     def get_nodes(self):
         """Returns all nodes contained in this kernel, unpacking fused nodes into their constituent scheduler nodes."""
-        return list(itertools.chain(*[x.get_nodes() for x in self.snodes]))
+        return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes))
 
     def get_first_name(self):
         return self.snodes[0].get_first_name()
diff --git a/torch/_numpy/_funcs_impl.py b/torch/_numpy/_funcs_impl.py
index 76a7fe7..2d9b61e 100644
--- a/torch/_numpy/_funcs_impl.py
+++ b/torch/_numpy/_funcs_impl.py
@@ -1353,7 +1353,7 @@
             has_sublistout = len(operands) % 2 == 1
             if has_sublistout:
                 sublistout = operands[-1]
-            operands = list(itertools.chain(*zip(tensors, sublists)))
+            operands = list(itertools.chain.from_iterable(zip(tensors, sublists)))
             if has_sublistout:
                 operands.append(sublistout)
 
diff --git a/torch/ao/quantization/pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py
index 2390baf..bacb4d8 100644
--- a/torch/ao/quantization/pt2e/graph_utils.py
+++ b/torch/ao/quantization/pt2e/graph_utils.py
@@ -98,7 +98,7 @@
     for partition_type in partition_types:
         types_to_match = _get_matching_types(partition_type)
         partitions = get_source_partitions(gm.graph, types_to_match, filter_fn)
-        typed_partitions[partition_type] = list(itertools.chain(*partitions.values()))
+        typed_partitions[partition_type] = list(itertools.chain.from_iterable(partitions.values()))
 
     typed_partitions_list = list(typed_partitions.values())
     fusion_candidates = itertools.product(*typed_partitions_list)
diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
index da8fd97..4f0aa7f 100644
--- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
+++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -760,7 +760,7 @@
         conv_partitions = get_source_partitions(
             gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d]
         )
-        conv_partitions = list(itertools.chain(*conv_partitions.values()))
+        conv_partitions = list(itertools.chain.from_iterable(conv_partitions.values()))
         for conv_partition in conv_partitions:
             if len(conv_partition.output_nodes) > 1:
                 raise ValueError("conv partition has more than one output node")
@@ -935,7 +935,9 @@
         linear_partitions = get_source_partitions(
             gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
         )
-        linear_partitions = list(itertools.chain(*linear_partitions.values()))
+        linear_partitions = list(
+            itertools.chain.from_iterable(linear_partitions.values())
+        )
         for partition in linear_partitions:
             if len(partition.output_nodes) > 1:
                 raise ValueError(
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
index 8ad9505..9763cb4 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -471,7 +471,7 @@
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
     gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn)
-    gru_partitions = list(itertools.chain(*gru_partitions.values()))
+    gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values()))
     annotated_partitions = []
     for gru_partition in gru_partitions:
         annotated_partitions.append(gru_partition.nodes)
@@ -525,7 +525,7 @@
     module_partitions = get_source_partitions(
         gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn
     )
-    maxpool_partitions = list(itertools.chain(*module_partitions.values()))
+    maxpool_partitions = list(itertools.chain.from_iterable(module_partitions.values()))
     annotated_partitions = []
     for maxpool_partition in maxpool_partitions:
         annotated_partitions.append(maxpool_partition.nodes)
@@ -577,7 +577,7 @@
     module_partitions = get_source_partitions(
         gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
     )
-    partitions = list(itertools.chain(*module_partitions.values()))
+    partitions = list(itertools.chain.from_iterable(module_partitions.values()))
     annotated_partitions = []
     for partition in partitions:
         pool_node = partition.output_nodes[0]
@@ -701,7 +701,7 @@
     add_partitions = get_source_partitions(
         gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
     )
-    add_partitions = list(itertools.chain(*add_partitions.values()))
+    add_partitions = list(itertools.chain.from_iterable(add_partitions.values()))
     annotated_partitions = []
     for add_partition in add_partitions:
         annotated_partitions.append(add_partition.nodes)
@@ -800,7 +800,7 @@
     mul_partitions = get_source_partitions(
         gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
     )
-    mul_partitions = list(itertools.chain(*mul_partitions.values()))
+    mul_partitions = list(itertools.chain.from_iterable(mul_partitions.values()))
     annotated_partitions = []
     for mul_partition in mul_partitions:
         annotated_partitions.append(mul_partition.nodes)
@@ -844,7 +844,7 @@
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
     cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
-    cat_partitions = list(itertools.chain(*cat_partitions.values()))
+    cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values()))
     annotated_partitions = []
     for cat_partition in cat_partitions:
         cat_node = cat_partition.output_nodes[0]
diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py
index e389653..32700ff 100644
--- a/torch/autograd/profiler_legacy.py
+++ b/torch/autograd/profiler_legacy.py
@@ -184,7 +184,7 @@
     record_stack = []
 
     # '__start_profile' is not guaranteed to be first, so we must find it here
-    for record in itertools.chain(*thread_records):
+    for record in itertools.chain.from_iterable(thread_records):
         name = record.name()
         if start_record is None and name == "__start_profile":
             start_record = record
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index a5fd1c9..b61a453 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -868,7 +868,7 @@
 
 def _ensure_all_tensors_same_dtype(*tensors) -> None:
     last_dtype = None
-    for tensor in itertools.chain(*map(_as_iterable, tensors)):
+    for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
         tensor_dtype = tensor.dtype
         # Mixing complex and its element type is allowed
         if tensor_dtype.is_complex:
diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py
index ad26492..dc3f4c1 100644
--- a/torch/distributed/rpc/server_process_global_profiler.py
+++ b/torch/distributed/rpc/server_process_global_profiler.py
@@ -163,7 +163,7 @@
             process_global_function_events.append(thread_local_function_events)
 
         flattened_function_events = list(
-            itertools.chain(*process_global_function_events)
+            itertools.chain.from_iterable(process_global_function_events)
         )
         self.function_events = torch.autograd.profiler_util.EventList(
             flattened_function_events,
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
index 153a840..d866117 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
@@ -1017,7 +1017,7 @@
     """
     dims, counter = gen_lists_of_dims(4, i, counter)
     [d1, d2, d3, d4] = dims
-    nat_dims_i = gen_nat_constraints(list(itertools.chain(*dims)))
+    nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims)))
 
     initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12,
                                                                                   d1, d2, d3, d4)
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 1ede684..1b14fa3 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -150,12 +150,12 @@
     if isinstance(obj, torch.Tensor):
         return [obj]
     if isinstance(obj, (list, tuple)):
-        return itertools.chain(*map(_find_tensors, obj))
+        return itertools.chain.from_iterable(map(_find_tensors, obj))
     if isinstance(obj, dict):
-        return itertools.chain(*map(_find_tensors, obj.values()))
+        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
     if is_dataclass(obj):
-        return itertools.chain(
-            *map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
+        return itertools.chain.from_iterable(
+            map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
         )
 
     return []