[Inductor][CPP] Enable Local Buffer for Outer loop fusion (#126967)

**Summary**
Currently, the Inductor CPP backend's [generated code](https://gist.github.com/leslie-fang-intel/98f91d43dabed581a1ffe23daf133a65#file-bf16-softmax-generated-code-wo-local-buffer-py) for `Softmax` with the BF16 data type is significantly slower than the [ATen implementation](https://github.com/pytorch/pytorch/blob/9a2beb862d9c30f037c9b2eac878ec3f9227a5e2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L149). Comparing the generated code with ATen's, the performance bottleneck appears to be ATen's use of a [local buffer](https://github.com/pytorch/pytorch/blob/9a2beb862d9c30f037c9b2eac878ec3f9227a5e2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159-L160).

In the current implementation, the Inductor CPP backend uses an output buffer from the kernel group args to store and load temporary results (such as `exp`), since that buffer corresponds to a `SchedulerNode`. Each thread accesses its portion of this output buffer via indexing. However, since this buffer (take `exp` as an example) is only used internally within the decomposed `softmax`, it can be replaced with a thread-local buffer, similar to ATen's approach.
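
For illustration, here is a minimal sketch, not the actual Inductor lowering, of how `softmax` decomposes; the `exp` intermediate is produced and consumed entirely inside the fused outer loops, which is what makes it a candidate for a thread-local buffer:

```python
import torch

def softmax_decomposed(x: torch.Tensor) -> torch.Tensor:
    # Decomposition along the last dim, mirroring what Inductor fuses.
    amax = torch.amax(x, dim=-1, keepdim=True)    # reduction over last dim
    exp = torch.exp(x - amax)                     # intermediate ("exp"): only
                                                  # consumed by the ops below
    total = torch.sum(exp, dim=-1, keepdim=True)  # reduction over last dim
    return exp / total                            # pointwise
```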

This PR introduces the `LocalBuffer` optimization. With this enhancement, the [newly generated Inductor code with a local buffer](https://gist.github.com/leslie-fang-intel/98f91d43dabed581a1ffe23daf133a65#file-bf16-softmax-generated-code-w-local-buffer-py) for BF16 `Softmax` shows significantly improved performance. Running the benchmark [here](https://gist.github.com/leslie-fang-intel/37d81441237b5139c8295f5e6c4cd31a) for this BF16 `Softmax` case on an 8480 Xeon server shows performance on par with the ATen implementation.
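
A minimal timing sketch in the spirit of the linked benchmark gist (the input shape matches the new test case below; the warm-up and timing loop are assumptions):

```python
import time
import torch

def softmax_fn(x):
    return torch.nn.functional.softmax(x, dim=-1)

compiled = torch.compile(softmax_fn)
x = torch.randn(4, 12, 1023, 1022).to(torch.bfloat16)

compiled(x)  # warm-up run triggers compilation
start = time.perf_counter()
for _ in range(100):
    compiled(x)
print(f"avg latency: {(time.perf_counter() - start) / 100 * 1e3:.3f} ms")
```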

**Test Plan**
```
python -u -m pytest -s -v inductor/test_cpu_repro.py -k test_local_buffer_in_outer_loop_fusion
```

**Next Step**

- [ ] Support more than one Local Buffer/Global Buffer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126967
Approved by: https://github.com/jgong5, https://github.com/peterbell10
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index fcbac3a..2889e3d 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -2556,6 +2556,7 @@
                 self.common(fn, (x,))
                 assert metrics.generated_cpp_vec_kernel_count == 0
 
+    @config.patch(fx_graph_cache=False)
     def test_outer_loop_fusion(self):
         def fn(x):
             max = torch.amax(x, dim=-1, keepdim=True)
@@ -2567,8 +2568,47 @@
             torch._dynamo.reset()
             metrics.reset()
             self.common(fn, (x,))
-            assert len(metrics.cpp_outer_loop_fused_inner_counts) == 1
-            assert metrics.cpp_outer_loop_fused_inner_counts[0] == 2
+            self.assertEqual(
+                len(metrics.cpp_outer_loop_fused_inner_counts),
+                1,
+            )
+            self.assertEqual(
+                metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number,
+                2,
+            )
+
+    @config.patch(fx_graph_cache=False)
+    def test_local_buffer_in_outer_loop_fusion(self):
+        def fn(x):
+            max = torch.nn.functional.softmax(x, dim=-1)
+            return x - max
+
+        x = torch.randn(4, 12, 1023, 1022)
+
+        with config.patch({"cpp.simdlen": None}):
+            torch._dynamo.reset()
+            metrics.reset()
+            self.common(fn, (x,))
+            self.assertEqual(
+                len(metrics.cpp_outer_loop_fused_inner_counts),
+                1,
+            )
+            self.assertEqual(
+                metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number,
+                3,
+            )
+            self.assertEqual(
+                metrics.cpp_outer_loop_fused_inner_counts[0].local_buffer_number,
+                1,
+            )
+            # Check the number of global buffer allocations
+            torch._dynamo.reset()
+            metrics.reset()
+            _, code = run_and_get_cpp_code(
+                torch._dynamo.optimize("inductor")(fn),
+                x,
+            )
+            self.assertEqual(code.count("empty_strided_cpu("), 3)
 
     def test_argmin(self):
         def fn(x):
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 1c08853..caaaf03 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -7,6 +7,7 @@
 import math
 import re
 import sys
+from collections import namedtuple
 from copy import copy, deepcopy
 from enum import Enum
 from typing import Any, cast, Dict, List, Optional, Sequence, Set, Tuple, Union
@@ -69,6 +70,7 @@
     cexpr_index,
     DTYPE_TO_CPP,
     INDEX_TYPE,
+    LocalBufferContext,
     unify_mask_base_type,
     value_to_cpp,
 )
@@ -435,8 +437,6 @@
         loop_nest_list: List[LoopNestWithSplit] = [
             kernel.loop_nest for kernel in cpp_kernel_proxy_list
         ]
-        metrics.cpp_outer_loop_fused_inner_counts.append(len(loop_nest_list))
-
         kernel_group = cpp_kernel_proxy_list[0].kernel_group
 
         def _merge_outer_fusion_loop_levels(
@@ -1915,7 +1915,10 @@
         threads = parallel_num_threads()
         assert self.call_ranges is not None
         kernels = loop_nest.get_kernels()
-        if any(isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels):
+        has_outer_loop_kernel = any(
+            isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels
+        )
+        if has_outer_loop_kernel:
             assert len(kernels) == 1
             assert isinstance(kernels[0], OuterLoopFusedKernel)
             par_depth = kernels[0].decide_parallel_depth(
@@ -2045,6 +2048,31 @@
 
             stack.enter_context(code.indent())
             if loop_nest.root:
+                if (
+                    has_outer_loop_kernel
+                    and isinstance(V.local_buffer_context, LocalBufferContext)
+                    and V.local_buffer_context.local_buffers
+                ):
+                    # Allocate local buffer
+                    local_buffers = V.local_buffer_context.local_buffers
+                    assert len(local_buffers.items()) == 1
+                    local_buffer = next(iter(local_buffers.items()))[1]
+                    # For dynamic sizes, rename size symbols s* to ks* via rename_indexing
+                    local_buf_size = sympy_product(
+                        [
+                            self.rename_indexing(size_val)
+                            for size_val in local_buffer.get_layout().size
+                        ]
+                    )
+                    local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype]
+                    allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})"
+                    code.splice(
+                        f"std::unique_ptr<{local_buf_dtype} []> local_buffer = {allocate};"
+                    )
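+                    # The emitted std::unique_ptr ties the allocation's
+                    # lifetime to the enclosing scope, so the buffer is
+                    # freed automatically when the kernel body exits.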
+                    local_buffer_name = local_buffer.get_name()
+                    code.splice(
+                        f"{local_buf_dtype}* {local_buffer_name} = local_buffer.get();"
+                    )
                 gen_loops(loop_nest.root)
             else:
                 gen_kernel(loop_nest.kernel)
@@ -3500,6 +3528,18 @@
                 return node.codegen(index_vars)
 
         fn_list = [functools.partial(fn, node) for node in nodes]
+
+        if (
+            isinstance(V.local_buffer_context, LocalBufferContext)
+            and V.local_buffer_context.local_buffers
+        ):
+            fn_list = [
+                V.local_buffer_context.localize_function(
+                    fn,
+                )
+                for fn in fn_list
+            ]
+
         var_sizes_list = [node.group[1] for node in nodes]
         self.codegen_functions(fn_list, var_sizes_list, vec_dtype)
 
@@ -3807,6 +3847,159 @@
             self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
         ) or self.can_fuse_vertical_outer_loop(node1, node2)
 
+    def codegen_outer_loop_node(
+        self,
+        node: OuterLoopFusedSchedulerNode,
+    ):
+        """
+        Generate the code for the outer loop fused scheduler node.
+        1. Codegen with fused outer loop: depends on the analysis of
+            the outer loop fused scheduler node, with or without the local buffer.
+        2. If failed, fallback to standard codegen.
+        """
+        kernel_group = self.kernel_group
+        generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count
+        cpp_kernel_proxy_list: List[CppKernelProxy] = []
+        nodes_list: List[List[SchedulerNode]] = []
+        assert isinstance(node, OuterLoopFusedSchedulerNode)
+
+        def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode):
+            """
+            Generate code with a fused outer loop and a local buffer.
+            """
+            assert isinstance(node, OuterLoopFusedSchedulerNode)
+            cpp_kernel_proxy_list.clear()
+            nodes_list.clear()
+
+            def get_call_ranges(node: BaseSchedulerNode):
+                assert isinstance(node, (SchedulerNode, FusedSchedulerNode))
+                nodes: List[SchedulerNode] = node.get_nodes()  # type: ignore[assignment]
+                _, (group, reduction_group) = max(
+                    nodes, key=lambda x: int(x.is_reduction())
+                ).group
+                call_ranges = tuple(group) + tuple(reduction_group)
+                return call_ranges
+
+            LocalBuffer = namedtuple("LocalBuffer", ["local_buf", "global_buf"])
+            local_buffers: List[LocalBuffer] = []
+            if all(
+                len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1
+                for _node in node.get_outer_nodes()
+            ):
+                # Refer to the typical local-buffer case in
+                # https://github.com/pytorch/pytorch/blob/
+                # 1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159
+                # where the buffer has the size of the last dim and is contiguous.
+                # Only this typical case is supported at first.
+                for scheduler_node in node.get_nodes():
+                    # all users inside same OuterLoopFusedSchedulerNode
+                    if not scheduler_node.is_reduction() and all(
+                        user.node in node.get_nodes() for user in scheduler_node.users
+                    ):
+                        global_buffer = scheduler_node.node
+                        assert isinstance(global_buffer, ir.ComputedBuffer)
+                        global_buffer_layout = global_buffer.get_layout()
+                        size_offset = node.outer_loop_fusion_depth - len(
+                            get_call_ranges(scheduler_node)
+                        )
+
+                        def is_all_write_read_contiguous(scheduler_node):
+                            contiguous_index_expr = 0
+                            stride = 1
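+                            # E.g., var_ranges {x0: 8, x1: 16} (illustrative
+                            # sizes) yields x1 + 16 * x0, the contiguous
+                            # row-major indexing expression.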
+                            for var, range in reversed(
+                                scheduler_node._body.var_ranges.items()
+                            ):
+                                contiguous_index_expr += stride * var
+                                stride *= range
+                            write_index_expr = scheduler_node._body.writes_name2expr[
+                                scheduler_node.get_name()
+                            ]
+
+                            def is_contiguous_index(x):
+                                return x == contiguous_index_expr
+
+                            return is_contiguous_index(write_index_expr) and all(
+                                is_contiguous_index(
+                                    user.node._body.reads_name2expr[
+                                        scheduler_node.get_name()
+                                    ],
+                                )
+                                for user in scheduler_node.users
+                            )
+
+                        if not (
+                            global_buffer_layout.is_contiguous()
+                            and not scheduler_node.is_reduction()
+                            and is_all_write_read_contiguous(scheduler_node)
+                        ):
+                            continue
+                        # Local Buffer is a view of global buffer
+                        local_buffer_layout = ir.FixedLayout(
+                            global_buffer_layout.device,
+                            global_buffer_layout.dtype,
+                            global_buffer_layout.size[size_offset:],
+                            global_buffer_layout.stride[size_offset:],
+                        )
+                        local_buffers.append(
+                            LocalBuffer(
+                                local_buf=ir.Buffer(
+                                    "local_buffer_data", local_buffer_layout
+                                ),
+                                global_buf=global_buffer,
+                            )
+                        )
+                        # At most 1 node with local buf for each OuterLoopFusedSchedulerNode
+                        break
+            assert len(local_buffers) in [0, 1]
+
+            with LocalBufferContext(kernel_group.args) as scope:
+                if len(local_buffers) > 0:
+                    scope.add_local_buffer(
+                        local_buffers[0].local_buf, local_buffers[0].global_buf
+                    )
+                for _node in node.get_outer_nodes():
+                    assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
+                    cpp_kernel_proxy = CppKernelProxy(kernel_group)
+                    cpp_kernel_proxy.codegen_nodes(_node.get_nodes())  # type: ignore[arg-type]
+                    cpp_kernel_proxy_list.append(cpp_kernel_proxy)
+                    nodes_list.append(_node.get_nodes())  # type: ignore[arg-type]
+
+                if not node.check_outer_fusion_loop_level_attr(
+                    cpp_kernel_proxy_list, node.outer_loop_fusion_depth
+                ):
+                    return False
+                metrics.cpp_outer_loop_fused_inner_counts.append(
+                    metrics.CppOuterLoopFusedCount(
+                        len(cpp_kernel_proxy_list),
+                        local_buffer_number=len(local_buffers),
+                    )
+                )
+                outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
+                    cpp_kernel_proxy_list,
+                )
+                kernel_group.finalize_kernel(
+                    outer_fusion_cpp_kernel_proxy,
+                    [_node for _nodes in nodes_list for _node in _nodes],
+                )
+
+            return True
+
+        if not try_outer_loop_fusion_with_local_buf(node):
+            # Reset generated_cpp_vec_kernel_count before running codegen again
+            metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count
+            cpp_kernel_proxy_list.clear()
+            nodes_list.clear()
+            # Similar to the comment in
+            # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272
+            # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args.
+            with torch._inductor.config.patch(inplace_buffers=False):
+                for _node in node.get_outer_nodes():
+                    assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
+                    _nodes: List[SchedulerNode] = _node.get_nodes()  # type: ignore[assignment]
+                    cpp_kernel_proxy = CppKernelProxy(kernel_group)
+                    cpp_kernel_proxy.codegen_nodes(_nodes)
+                    kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes)
+
     def codegen_node(
         self,
         node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode],
@@ -3817,38 +4010,7 @@
         kernel_group = self.kernel_group
 
         if isinstance(node, OuterLoopFusedSchedulerNode):
-            cpp_kernel_proxy_list: List[CppKernelProxy] = []
-            nodes_list: List[List[SchedulerNode]] = []
-
-            for _node in node.get_outer_nodes():
-                assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
-                _nodes: List[SchedulerNode] = _node.get_nodes()  # type: ignore[assignment]
-                cpp_kernel_proxy = CppKernelProxy(kernel_group)
-                cpp_kernel_proxy.codegen_nodes(_nodes)
-
-                cpp_kernel_proxy_list.append(cpp_kernel_proxy)
-                nodes_list.append(_nodes)
-
-            # Note that, in the future, when every kernel can be vectorized,
-            # the function select_tiling will be much easier, and we'll be able to lift
-            # check_outer_fusion_loop_level_attr to the fusion phase,
-            # avoiding grouping kernels at fusion time that "look like we'll be able to fuse them"
-            # but then we actually won't.
-            if node.check_outer_fusion_loop_level_attr(
-                cpp_kernel_proxy_list, node.outer_loop_fusion_depth
-            ):
-                # Merge the cpp_kernel_proxy_list into cpp_kernel_proxy
-                outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
-                    cpp_kernel_proxy_list,
-                )
-                kernel_group.finalize_kernel(
-                    outer_fusion_cpp_kernel_proxy,
-                    [_node for _nodes in nodes_list for _node in _nodes],
-                )
-            else:
-                # Fall back to standard loop codegen
-                for _kernel_proxy, _nodes in zip(cpp_kernel_proxy_list, nodes_list):
-                    kernel_group.finalize_kernel(_kernel_proxy, _nodes)
+            self.codegen_outer_loop_node(node)
         else:
             nodes: List[SchedulerNode] = node.get_nodes()  # type: ignore[assignment]
             cpp_kernel_proxy = CppKernelProxy(kernel_group)
diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py
index 58adcda..a27ed7b 100644
--- a/torch/_inductor/codegen/cpp_template_kernel.py
+++ b/torch/_inductor/codegen/cpp_template_kernel.py
@@ -14,7 +14,7 @@
 from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
 from ..virtualized import V
 from .cpp import CppKernel, CppKernelProxy, KernelGroup
-from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferScope
+from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext
 
 
 def parse_expr_with_index_symbols(expr):
@@ -270,13 +270,11 @@
         if offsets:
             offsets = parse_expr_with_index_symbols(offsets)
         if epilogue_nodes:
-            with LocalBufferScope(self) as scope:
+            with LocalBufferContext(self.args) as scope:
                 assert orig_src is not None
                 if orig_src.get_name() != src.get_name():
-                    scope.add_local_buffer(src)
-                    epilogue_nodes = scope.localize_buffer(
-                        orig_src, src, epilogue_nodes
-                    )
+                    scope.add_local_buffer(src, orig_src)
+                    epilogue_nodes = scope.localize_nodes(epilogue_nodes)
                 return self.store_pointwise_nodes(
                     dst, epilogue_nodes, offsets, reindexers  # type: ignore[arg-type]
                 )
@@ -284,7 +282,7 @@
             if dst.get_name() != src.get_name():
                 # src is local
                 copy = L.copy(dst, src).data.data
-                with LocalBufferScope(self) as scope:
+                with LocalBufferContext(self.args) as scope:
                     scope.add_local_buffer(src)
                     return self.store_pointwise_nodes(dst, [copy])
             else:
diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py
index 886a296..b27975c 100644
--- a/torch/_inductor/codegen/cpp_utils.py
+++ b/torch/_inductor/codegen/cpp_utils.py
@@ -4,7 +4,7 @@
 import math
 
 from collections import namedtuple
-from typing import Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
 import sympy
@@ -12,11 +12,10 @@
 import torch
 from torch.utils._sympy.symbol import symbol_is_type, SymT
 from .. import ir
-from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix
+from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
 from ..virtualized import V
 
-from .common import CSEVariable, ExprPrinter, Kernel
-
+from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs
 
 DTYPE_TO_CPP = {
     torch.float32: "float",
@@ -304,7 +303,88 @@
         return f"static_cast<{cpp_type}>({repr(value)})"
 
 
-class LocalBufferScope:
+def rewrite_index_for_function(
+    localize_buffer_handler: "LocalizeBufferHandler",
+    index: sympy.Expr,
+):
+    # Local buffer at the inner dimensions
+    snode = V.graph.scheduler.name_to_node.get(
+        localize_buffer_handler.global_buf.get_name()
+    )
+    assert snode is not None
+    scheduler_nodes = snode.get_nodes()
+    _, (group, reduction_group) = max(
+        scheduler_nodes, key=lambda x: int(x.is_reduction())
+    ).group
+    call_ranges = tuple(group) + tuple(reduction_group)
+    indices_to_keep = [
+        f"x{len(call_ranges) - (idx + 1)}"
+        for idx in range(len(localize_buffer_handler.local_buf.get_layout().size))
+    ]
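+    # E.g., with call_ranges of length 3 and a 1-D local buffer, only "x2"
+    # (the innermost index) is kept; every other "x" symbol is zeroed below.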
+    sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)  # type: ignore[attr-defined]
+    replacements = {}
+    for x in sorted_symbols:
+        if x.name.startswith("x") and x.name not in indices_to_keep:  # type: ignore[attr-defined]
+            # Only keep index used by local buffer
+            replacements[x] = sympy.core.numbers.Zero()
+    index = sympy_subs(index, replacements)  # type: ignore[arg-type]
+    return index
+
+
+def rewrite_index_for_nodes(
+    localize_buffer_handler: "LocalizeBufferHandler",
+    index: sympy.Expr,
+):
+    used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
+    index_vars = []
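+    # Dimensions whose index symbol is absent from the original expression
+    # collapse to 0, so the local indexer strides only over the used dims.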
+    for i in range(len(localize_buffer_handler.local_buf.get_size())):
+        var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
+        index_vars.append(var if var in used_vars else 0)
+    index = localize_buffer_handler.local_buf.layout.make_indexer()(index_vars)
+    return index
+
+
+class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined]
+    def __init__(
+        self,
+        inner,
+        global_buf: ir.Buffer,
+        local_buf: ir.Buffer,
+        rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr], sympy.Expr],
+    ):
+        super().__init__(inner)
+        self.global_buf = global_buf
+        self.local_buf = local_buf
+        self.rewrite_index = rewrite_index
+
+    def localize(self, name: str, index: sympy.Expr):
+        if self.global_buf and name == self.global_buf.get_name():
+            assert self.rewrite_index is not None
+            name = self.local_buf.get_name()
+            index = self.rewrite_index(self, index)
+        return name, index
+
+    def load(self, name: str, index: sympy.Expr):
+        return self._inner.load(*self.localize(name, index))
+
+    def store(self, name, index, value, mode=None):
+        local_buffer_name, local_buffer_index = self.localize(name, index)
+        res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
+        if (
+            self.global_buf
+            and name == self.global_buf.get_name()
+            and isinstance(V.kernel, Kernel)
+        ):
+            # Remove the local buffer's name from Kernel.store_buffer_names;
+            # it was added there by Kernel.CSEProxy.store.
+            V.kernel.store_buffer_names.discard(local_buffer_name)
+        return res
+
+    def store_reduction(self, name, index, value):
+        return self._inner.store_reduction(*self.localize(name, index), value)
+
+
+class LocalBufferContext:
     """
     This class creates a context that helps to generate code involving Inductor IR with
     function local buffers. These buffers are constructed during the codegen process and
@@ -314,10 +394,13 @@
     these buffers without exposure to the outside world.
     """
 
-    def __init__(self, kernel: Kernel):
-        self.kernel = kernel
+    def __init__(self, kernel_args: KernelArgs):
+        self.kernel_args = kernel_args
         self.exit_stack = contextlib.ExitStack()
+        # Map Local Buffer name to Local Buffer
         self.local_buffers: Dict[str, ir.Buffer] = {}
+        # Map Local Buffer name to Global Buffer
+        self.local_to_global: Dict[str, ir.Buffer] = {}
 
     def __enter__(self):
         self.exit_stack.__enter__()
@@ -330,23 +413,26 @@
 
         self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype))
 
-        original_input = self.kernel.args.input
+        original_input = self.kernel_args.input
 
         def input(name):
             if name in self.local_buffers:
                 return name
             return original_input(name)
 
-        self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input))
+        self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input))
 
-        original_output = self.kernel.args.output
+        original_output = self.kernel_args.output
 
         def output(name):
             if name in self.local_buffers:
                 return name
             return original_output(name)
 
-        self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output))
+        self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output))
+
+        # Set current LocalBufferContext into V
+        self.exit_stack.enter_context(V.set_local_buffer_context(self))
 
         return self
 
@@ -354,53 +440,64 @@
         self.local_buffers.clear()
         self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
 
-    def add_local_buffer(self, buffer: ir.Buffer):
-        assert buffer.get_name() not in self.local_buffers
-        self.local_buffers[buffer.get_name()] = buffer
+    def add_local_buffer(
+        self, local_buffer: ir.Buffer, global_buffer: Optional[ir.Buffer] = None
+    ):
+        assert local_buffer.get_name() not in self.local_buffers
+        self.local_buffers[local_buffer.get_name()] = local_buffer
+        if global_buffer:
+            self.local_to_global[local_buffer.get_name()] = global_buffer
+            V.graph.removed_buffers.add(global_buffer.get_name())
 
-    def localize_buffer(
-        self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode]
+    def localize_function(
+        self,
+        fn: Callable[..., Any],
+        rewrite_index: Callable[
+            ["LocalizeBufferHandler", sympy.Expr], sympy.Expr
+        ] = rewrite_index_for_function,
+    ):
+        local_buffers = list(self.local_buffers.values())
+        global_buffers = list(self.local_to_global.values())
+        local_buf = local_buffers[0]
+        global_buf = global_buffers[0]
+
+        def inner(node, *index_vars):
+            with V.set_ops_handler(
+                LocalizeBufferHandler(
+                    V.get_ops_handler(),
+                    global_buf=global_buf,
+                    local_buf=local_buf,
+                    rewrite_index=rewrite_index,
+                )
+            ):
+                return fn(node, *index_vars)
+
+        return inner
+
+    def localize_nodes(
+        self,
+        nodes: List[ir.IRNode],
+        rewrite_index: Callable[
+            ["LocalizeBufferHandler", sympy.Expr], sympy.Expr
+        ] = rewrite_index_for_nodes,
     ) -> List[ir.IRNode]:
         """
-        Localizes the buffer `global_buf` to `local_buf` in the given `nodes` and returns
-        a new list of IR nodes that work on `local_buf` instead of `global_buf`, i.e., all
-        the loads and stores are redirected to `local_buf`. This helps the fused loops to
-        work on smaller-sized local buffers for better data locality.
+        Given `local_buf` and `global_buf` registered in the current `LocalBufferContext`
+        through the method `add_local_buffer`, localizes the `global_buf` to `local_buf`
+        for the given `nodes` and returns a new list of IR nodes that work on `local_buf`
+        instead of `global_buf`, i.e., all the loads and stores are redirected to
+        `local_buf`. This helps the fused loops to work on smaller-sized local buffers
+        for better data locality.
 
-        The `local_buf` should already be registered in the local scope and the data access
-        is assumed to be contiguous with the same order as the `global_buf`.
+        The data access of `local_buf` is assumed to be contiguous with the
+        same order as the `global_buf`.
         """
-        assert local_buf.get_name() in self.local_buffers
-        assert len(global_buf.get_size()) == len(local_buf.get_size())
+        local_buffers = list(self.local_buffers.values())
+        global_buffers = list(self.local_to_global.values())
+        assert len(global_buffers[0].get_size()) == len(local_buffers[0].get_size())
         assert len(nodes) > 0
 
-        class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined]
-            def __init__(self, inner):
-                super().__init__(inner)
-
-            def localize(self, name: str, index: sympy.Expr):
-                if name == global_buf.get_name():
-                    name = local_buf.get_name()
-                    used_vars = {
-                        s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)
-                    }
-                    index_vars = []
-                    for i in range(len(local_buf.get_size())):
-                        var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
-                        index_vars.append(var if var in used_vars else 0)
-                    index = local_buf.layout.make_indexer()(index_vars)
-                return name, index
-
-            def load(self, name: str, index: sympy.Expr):
-                return self._inner.load(*self.localize(name, index))
-
-            def store(self, name, index, value, mode=None):
-                return self._inner.store(*self.localize(name, index), value, mode)
-
-            def store_reduction(self, name, index, value):
-                return self._inner.store_reduction(*self.localize(name, index), value)
-
-        def wrap_inner_fn_for_node(node: ir.IRNode, inner_fn_wrapper):
+        def wrap_inner_fn_for_node(node: ir.IRNode):
             loops = node.data if isinstance(node, ir.ComputedBuffer) else node
             assert isinstance(loops, ir.Loops)
             new_loops = copy.copy(loops)
@@ -411,17 +508,13 @@
             else:
                 new_node = new_loops  # type: ignore[assignment]
 
-            new_loops.inner_fn = inner_fn_wrapper(new_loops.inner_fn)
+            new_loops.inner_fn = self.localize_function(
+                new_loops.inner_fn,
+                rewrite_index,
+            )
             return new_node
 
-        def inner_fn_wrapper(inner_fn):
-            def inner(index):
-                with V.set_ops_handler(LocalizeBufferHandler(V.get_ops_handler())):
-                    return inner_fn(index)
-
-            return inner
-
-        return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes]
+        return [wrap_inner_fn_for_node(node) for node in nodes]
 
 
 def unify_mask_base_type(
diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py
index fc7d0e6..5c86667 100644
--- a/torch/_inductor/metrics.py
+++ b/torch/_inductor/metrics.py
@@ -41,9 +41,15 @@
 # counters for tracking to_dtype inserted
 cpp_to_dtype_count = 0
 
+
+@dataclasses.dataclass
+class CppOuterLoopFusedCount:
+    inner_kernel_number: int
+    local_buffer_number: int = 0
+
+
 # The length counts the number of outer loop fusions.
-# Each element counts the number of inner kernels in each outer loop fusion.
-cpp_outer_loop_fused_inner_counts: List[int] = []
+cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = []
 
 num_comprehensive_padding = 0
 num_matches_for_scatter_upon_const_tensor = 0
diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py
index ac8d3c6..51ff55a 100644
--- a/torch/_inductor/virtualized.py
+++ b/torch/_inductor/virtualized.py
@@ -72,6 +72,7 @@
 
 if TYPE_CHECKING:
     import torch
+    from torch._inductor.codegen.cpp_utils import LocalBufferContext
     from torch._inductor.debug import DebugContext
     from torch._inductor.graph import GraphLowering
     from torch._inductor.ir import InterpreterShim
@@ -162,6 +163,9 @@
 _interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler)
 _aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler)
 _current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler)
+_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
+    "local_buffer_context", NullHandler
+)
 
 
 class OpsValue:
@@ -306,6 +310,8 @@
     get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
     set_current_node: Callable[[Any], Any] = _current_node._set_handler
     get_current_node: Callable[[], Any] = _current_node._get_handler
+    set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
+    get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler
 
     @property
     def ops(self) -> OpsHandler[Any]:
@@ -348,5 +354,9 @@
     def current_node(self):
         return _current_node._get_handler()
 
+    @property
+    def local_buffer_context(self):
+        return _local_buffer_context._get_handler()
+
 
 V = _V()