[Inductor][CPP] Enable Local Buffer for Outer loop fusion (#126967)
**Summary**
Currently, the Inductor CPP backend's [generated code](https://gist.github.com/leslie-fang-intel/98f91d43dabed581a1ffe23daf133a65#file-bf16-softmax-generated-code-wo-local-buffer-py) for `Softmax` with the BF16 data type is significantly slower than the [ATen implementation](https://github.com/pytorch/pytorch/blob/9a2beb862d9c30f037c9b2eac878ec3f9227a5e2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L149). Comparing the generated code with ATen, the performance bottleneck appears to come from the use of a [local buffer in ATen](https://github.com/pytorch/pytorch/blob/9a2beb862d9c30f037c9b2eac878ec3f9227a5e2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159-L160).
In the current implementation, Inductor uses an output buffer from the kernel group args to store and load temporary results (such as `exp`), since that buffer corresponds to a `SchedulerNode`. Each thread accesses a portion of this output buffer via indexing. However, since this buffer (take `exp` as an example) is only used internally within the decomposed `softmax`, it can be replaced with a thread-local buffer, similar to ATen's approach.
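For reference, the snippet below is a minimal sketch of the decomposed `softmax` discussed above, written in plain PyTorch (the helper name is illustrative, not taken from this PR). It shows why the `exp` intermediate is a local-buffer candidate: it is only consumed by the final normalization inside the same fused loop nest.
```python
import torch

# Decomposed softmax, spelled out to expose the `exp` intermediate.
# `exp_buf` is only read by the final division, so it never escapes the
# fused outer loop and can live in a per-thread local buffer.
def softmax_decomposed(x: torch.Tensor) -> torch.Tensor:
    max_val = torch.amax(x, dim=-1, keepdim=True)
    exp_buf = torch.exp(x - max_val)  # temporary, internal to the decomposition
    sum_val = torch.sum(exp_buf, dim=-1, keepdim=True)
    return exp_buf / sum_val
```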
This PR introduces the `LocalBuffer` optimization. With this enhancement, the [new generated Inductor code with local buffer](https://gist.github.com/leslie-fang-intel/98f91d43dabed581a1ffe23daf133a65#file-bf16-softmax-generated-code-w-local-buffer-py) for BF16 `Softmax` shows significantly improved performance. Running the benchmark [here](https://gist.github.com/leslie-fang-intel/37d81441237b5139c8295f5e6c4cd31a) for this BF16 `Softmax` case on an 8480 Xeon server shows comparable performance between the Inductor CPP backend and the ATen implementation.
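For a quick reproduction of the case measured above, the following sketch compiles BF16 `Softmax` with Inductor on CPU (the shape is borrowed from the new test below; this is not the linked benchmark gist):
```python
import torch

def fn(x):
    return torch.nn.functional.softmax(x, dim=-1)

# Shape taken from the new test case; BF16 is the data type discussed above.
x = torch.randn(4, 12, 1023, 1022).to(torch.bfloat16)
compiled_fn = torch.compile(fn)
with torch.no_grad():
    y = compiled_fn(x)  # routed through the Inductor CPP backend on CPU
```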
**Test Plan**
```
python -u -m pytest -s -v inductor/test_cpu_repro.py -k test_local_buffer_in_outer_loop_fusion
```
**Next Step**
- [ ] Support more than one Local Buffer/Global Buffer
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126967
Approved by: https://github.com/jgong5, https://github.com/peterbell10
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index fcbac3a..2889e3d 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -2556,6 +2556,7 @@
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 0
+ @config.patch(fx_graph_cache=False)
def test_outer_loop_fusion(self):
def fn(x):
max = torch.amax(x, dim=-1, keepdim=True)
@@ -2567,8 +2568,47 @@
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
- assert len(metrics.cpp_outer_loop_fused_inner_counts) == 1
- assert metrics.cpp_outer_loop_fused_inner_counts[0] == 2
+ self.assertEqual(
+ len(metrics.cpp_outer_loop_fused_inner_counts),
+ 1,
+ )
+ self.assertEqual(
+ metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number,
+ 2,
+ )
+
+ @config.patch(fx_graph_cache=False)
+ def test_local_buffer_in_outer_loop_fusion(self):
+ def fn(x):
+ max = torch.nn.functional.softmax(x, dim=-1)
+ return x - max
+
+ x = torch.randn(4, 12, 1023, 1022)
+
+ with config.patch({"cpp.simdlen": None}):
+ torch._dynamo.reset()
+ metrics.reset()
+ self.common(fn, (x,))
+ self.assertEqual(
+ len(metrics.cpp_outer_loop_fused_inner_counts),
+ 1,
+ )
+ self.assertEqual(
+ metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number,
+ 3,
+ )
+ self.assertEqual(
+ metrics.cpp_outer_loop_fused_inner_counts[0].local_buffer_number,
+ 1,
+ )
+ # Check the number of global buffer allocations
+ torch._dynamo.reset()
+ metrics.reset()
+ _, code = run_and_get_cpp_code(
+ torch._dynamo.optimize("inductor")(fn),
+ x,
+ )
+ self.assertEqual(code.count("empty_strided_cpu("), 3)
def test_argmin(self):
def fn(x):
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 1c08853..caaaf03 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -7,6 +7,7 @@
import math
import re
import sys
+from collections import namedtuple
from copy import copy, deepcopy
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Sequence, Set, Tuple, Union
@@ -69,6 +70,7 @@
cexpr_index,
DTYPE_TO_CPP,
INDEX_TYPE,
+ LocalBufferContext,
unify_mask_base_type,
value_to_cpp,
)
@@ -435,8 +437,6 @@
loop_nest_list: List[LoopNestWithSplit] = [
kernel.loop_nest for kernel in cpp_kernel_proxy_list
]
- metrics.cpp_outer_loop_fused_inner_counts.append(len(loop_nest_list))
-
kernel_group = cpp_kernel_proxy_list[0].kernel_group
def _merge_outer_fusion_loop_levels(
@@ -1915,7 +1915,10 @@
threads = parallel_num_threads()
assert self.call_ranges is not None
kernels = loop_nest.get_kernels()
- if any(isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels):
+ has_outer_loop_kernel = any(
+ isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels
+ )
+ if has_outer_loop_kernel:
assert len(kernels) == 1
assert isinstance(kernels[0], OuterLoopFusedKernel)
par_depth = kernels[0].decide_parallel_depth(
@@ -2045,6 +2048,31 @@
stack.enter_context(code.indent())
if loop_nest.root:
+ if (
+ has_outer_loop_kernel
+ and isinstance(V.local_buffer_context, LocalBufferContext)
+ and V.local_buffer_context.local_buffers
+ ):
+ # Allocate local buffer
+ local_buffers = V.local_buffer_context.local_buffers
+ assert len(local_buffers.items()) == 1
+ local_buffer = next(iter(local_buffers.items()))[1]
+ # For dynamic size, rename s to ks
+ local_buf_size = sympy_product(
+ [
+ self.rename_indexing(size_val)
+ for size_val in local_buffer.get_layout().size
+ ]
+ )
+ local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype]
+ allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})"
+ code.splice(
+ f"std::unique_ptr<{local_buf_dtype} []> local_buffer = {allocate};"
+ )
+ local_buffer_name = local_buffer.get_name()
+ code.splice(
+ f"{local_buf_dtype}* {local_buffer_name} = local_buffer.get();"
+ )
gen_loops(loop_nest.root)
else:
gen_kernel(loop_nest.kernel)
@@ -3500,6 +3528,18 @@
return node.codegen(index_vars)
fn_list = [functools.partial(fn, node) for node in nodes]
+
+ if (
+ isinstance(V.local_buffer_context, LocalBufferContext)
+ and V.local_buffer_context.local_buffers
+ ):
+ fn_list = [
+ V.local_buffer_context.localize_function(
+ fn,
+ )
+ for fn in fn_list
+ ]
+
var_sizes_list = [node.group[1] for node in nodes]
self.codegen_functions(fn_list, var_sizes_list, vec_dtype)
@@ -3807,6 +3847,159 @@
self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
) or self.can_fuse_vertical_outer_loop(node1, node2)
+ def codegen_outer_loop_node(
+ self,
+ node: OuterLoopFusedSchedulerNode,
+ ):
+ """
+ Generate the code for the outer loop fused scheduler node.
+ 1. Codegen with fused outer loop: depends on the analysis of
+ the outer loop fused scheduler node, with or without the local buffer.
+ 2. If that fails, fall back to standard codegen.
+ """
+ kernel_group = self.kernel_group
+ generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count
+ cpp_kernel_proxy_list: List[CppKernelProxy] = []
+ nodes_list: List[List[SchedulerNode]] = []
+ assert isinstance(node, OuterLoopFusedSchedulerNode)
+
+ def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode):
+ """
+ Codegen code with fused outer loop and local Buffer.
+ """
+ assert isinstance(node, OuterLoopFusedSchedulerNode)
+ cpp_kernel_proxy_list.clear()
+ nodes_list.clear()
+
+ def get_call_ranges(node: BaseSchedulerNode):
+ assert isinstance(node, (SchedulerNode, FusedSchedulerNode))
+ nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment]
+ _, (group, reduction_group) = max(
+ nodes, key=lambda x: int(x.is_reduction())
+ ).group
+ call_ranges = tuple(group) + tuple(reduction_group)
+ return call_ranges
+
+ LocalBuffer = namedtuple("LocalBuffer", ["local_buf", "global_buf"])
+ local_buffers: List[LocalBuffer] = []
+ if all(
+ len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1
+ for _node in node.get_outer_nodes()
+ ):
+ # Refer to the typical local-buffer case
+ # in https://github.com/pytorch/pytorch/blob/
+ # 1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159
+ # where the buffer has the size of the last dim and is contiguous.
+ # Only this typical case is supported for now.
+ for scheduler_node in node.get_nodes():
+ # All users are within the same OuterLoopFusedSchedulerNode
+ if not scheduler_node.is_reduction() and all(
+ user.node in node.get_nodes() for user in scheduler_node.users
+ ):
+ global_buffer = scheduler_node.node
+ assert isinstance(global_buffer, ir.ComputedBuffer)
+ global_buffer_layout = global_buffer.get_layout()
+ size_offset = node.outer_loop_fusion_depth - len(
+ get_call_ranges(scheduler_node)
+ )
+
+ def is_all_write_read_contiguous(scheduler_node):
+ contiguous_index_expr = 0
+ stride = 1
+ for var, range in reversed(
+ scheduler_node._body.var_ranges.items()
+ ):
+ contiguous_index_expr += stride * var
+ stride *= range
+ write_index_expr = scheduler_node._body.writes_name2expr[
+ scheduler_node.get_name()
+ ]
+
+ def is_contiguous_index(x):
+ return x == contiguous_index_expr
+
+ return is_contiguous_index(write_index_expr) and all(
+ is_contiguous_index(
+ user.node._body.reads_name2expr[
+ scheduler_node.get_name()
+ ],
+ )
+ for user in scheduler_node.users
+ )
+
+ if not (
+ global_buffer_layout.is_contiguous()
+ and not scheduler_node.is_reduction()
+ and is_all_write_read_contiguous(scheduler_node)
+ ):
+ continue
+ # Local Buffer is a view of global buffer
+ local_buffer_layout = ir.FixedLayout(
+ global_buffer_layout.device,
+ global_buffer_layout.dtype,
+ global_buffer_layout.size[size_offset:],
+ global_buffer_layout.stride[size_offset:],
+ )
+ local_buffers.append(
+ LocalBuffer(
+ local_buf=ir.Buffer(
+ "local_buffer_data", local_buffer_layout
+ ),
+ global_buf=global_buffer,
+ )
+ )
+ # At most 1 node with local buf for each OuterLoopFusedSchedulerNode
+ break
+ assert len(local_buffers) in [0, 1]
+
+ with LocalBufferContext(kernel_group.args) as scope:
+ if len(local_buffers) > 0:
+ scope.add_local_buffer(
+ local_buffers[0].local_buf, local_buffers[0].global_buf
+ )
+ for _node in node.get_outer_nodes():
+ assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
+ cpp_kernel_proxy = CppKernelProxy(kernel_group)
+ cpp_kernel_proxy.codegen_nodes(_node.get_nodes()) # type: ignore[arg-type]
+ cpp_kernel_proxy_list.append(cpp_kernel_proxy)
+ nodes_list.append(_node.get_nodes()) # type: ignore[arg-type]
+
+ if not node.check_outer_fusion_loop_level_attr(
+ cpp_kernel_proxy_list, node.outer_loop_fusion_depth
+ ):
+ return False
+ metrics.cpp_outer_loop_fused_inner_counts.append(
+ metrics.CppOuterLoopFusedCount(
+ len(cpp_kernel_proxy_list),
+ local_buffer_number=len(local_buffers),
+ )
+ )
+ outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
+ cpp_kernel_proxy_list,
+ )
+ kernel_group.finalize_kernel(
+ outer_fusion_cpp_kernel_proxy,
+ [_node for _nodes in nodes_list for _node in _nodes],
+ )
+
+ return True
+
+ if not try_outer_loop_fusion_with_local_buf(node):
+ # Reset generated_cpp_vec_kernel_count to codegen again
+ metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count
+ cpp_kernel_proxy_list.clear()
+ nodes_list.clear()
+ # Similar to the comment in
+ # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272
+ # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args.
+ with torch._inductor.config.patch(inplace_buffers=False):
+ for _node in node.get_outer_nodes():
+ assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
+ _nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment]
+ cpp_kernel_proxy = CppKernelProxy(kernel_group)
+ cpp_kernel_proxy.codegen_nodes(_nodes)
+ kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes)
+
def codegen_node(
self,
node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode],
@@ -3817,38 +4010,7 @@
kernel_group = self.kernel_group
if isinstance(node, OuterLoopFusedSchedulerNode):
- cpp_kernel_proxy_list: List[CppKernelProxy] = []
- nodes_list: List[List[SchedulerNode]] = []
-
- for _node in node.get_outer_nodes():
- assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
- _nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment]
- cpp_kernel_proxy = CppKernelProxy(kernel_group)
- cpp_kernel_proxy.codegen_nodes(_nodes)
-
- cpp_kernel_proxy_list.append(cpp_kernel_proxy)
- nodes_list.append(_nodes)
-
- # Note that, in the future, when every kernel can be vectorized,
- # the function select_tiling will be much easier, and we'll be able to lift
- # check_outer_fusion_loop_level_attr to the fusion phase,
- # avoiding grouping kernels at fusion time that "look like we'll be able to fuse them"
- # but then we actually won't.
- if node.check_outer_fusion_loop_level_attr(
- cpp_kernel_proxy_list, node.outer_loop_fusion_depth
- ):
- # Merge the cpp_kernel_proxy_list into cpp_kernel_proxy
- outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
- cpp_kernel_proxy_list,
- )
- kernel_group.finalize_kernel(
- outer_fusion_cpp_kernel_proxy,
- [_node for _nodes in nodes_list for _node in _nodes],
- )
- else:
- # Fall back to standard loop codegen
- for _kernel_proxy, _nodes in zip(cpp_kernel_proxy_list, nodes_list):
- kernel_group.finalize_kernel(_kernel_proxy, _nodes)
+ self.codegen_outer_loop_node(node)
else:
nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment]
cpp_kernel_proxy = CppKernelProxy(kernel_group)
diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py
index 58adcda..a27ed7b 100644
--- a/torch/_inductor/codegen/cpp_template_kernel.py
+++ b/torch/_inductor/codegen/cpp_template_kernel.py
@@ -14,7 +14,7 @@
from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
from ..virtualized import V
from .cpp import CppKernel, CppKernelProxy, KernelGroup
-from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferScope
+from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext
def parse_expr_with_index_symbols(expr):
@@ -270,13 +270,11 @@
if offsets:
offsets = parse_expr_with_index_symbols(offsets)
if epilogue_nodes:
- with LocalBufferScope(self) as scope:
+ with LocalBufferContext(self.args) as scope:
assert orig_src is not None
if orig_src.get_name() != src.get_name():
- scope.add_local_buffer(src)
- epilogue_nodes = scope.localize_buffer(
- orig_src, src, epilogue_nodes
- )
+ scope.add_local_buffer(src, orig_src)
+ epilogue_nodes = scope.localize_nodes(epilogue_nodes)
return self.store_pointwise_nodes(
dst, epilogue_nodes, offsets, reindexers # type: ignore[arg-type]
)
@@ -284,7 +282,7 @@
if dst.get_name() != src.get_name():
# src is local
copy = L.copy(dst, src).data.data
- with LocalBufferScope(self) as scope:
+ with LocalBufferContext(self.args) as scope:
scope.add_local_buffer(src)
return self.store_pointwise_nodes(dst, [copy])
else:
diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py
index 886a296..b27975c 100644
--- a/torch/_inductor/codegen/cpp_utils.py
+++ b/torch/_inductor/codegen/cpp_utils.py
@@ -4,7 +4,7 @@
import math
from collections import namedtuple
-from typing import Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
import sympy
@@ -12,11 +12,10 @@
import torch
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import ir
-from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix
+from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
from ..virtualized import V
-from .common import CSEVariable, ExprPrinter, Kernel
-
+from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs
DTYPE_TO_CPP = {
torch.float32: "float",
@@ -304,7 +303,88 @@
return f"static_cast<{cpp_type}>({repr(value)})"
-class LocalBufferScope:
+def rewrite_index_for_function(
+ localize_buffer_handler: "LocalizeBufferHandler",
+ index: sympy.Expr,
+):
+ # Local buffer at the inner dimensions
+ snode = V.graph.scheduler.name_to_node.get(
+ localize_buffer_handler.global_buf.get_name()
+ )
+ assert snode is not None
+ scheduler_nodes = snode.get_nodes()
+ _, (group, reduction_group) = max(
+ scheduler_nodes, key=lambda x: int(x.is_reduction())
+ ).group
+ call_ranges = tuple(group) + tuple(reduction_group)
+ indices_to_keep = [
+ f"x{len(call_ranges) - (idx + 1)}"
+ for idx in range(len(localize_buffer_handler.local_buf.get_layout().size))
+ ]
+ sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined]
+ replacements = {}
+ for x in sorted_symbols:
+ if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined]
+ # Only keep index used by local buffer
+ replacements[x] = sympy.core.numbers.Zero()
+ index = sympy_subs(index, replacements) # type: ignore[arg-type]
+ return index
+
+
+def rewrite_index_for_nodes(
+ localize_buffer_handler: "LocalizeBufferHandler",
+ index: sympy.Expr,
+):
+ used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
+ index_vars = []
+ for i in range(len(localize_buffer_handler.local_buf.get_size())):
+ var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
+ index_vars.append(var if var in used_vars else 0)
+ index = localize_buffer_handler.local_buf.layout.make_indexer()(index_vars)
+ return index
+
+
+class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
+ def __init__(
+ self,
+ inner,
+ global_buf: ir.Buffer,
+ local_buf: ir.Buffer,
+ rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr], sympy.Expr],
+ ):
+ super().__init__(inner)
+ self.global_buf = global_buf
+ self.local_buf = local_buf
+ self.rewrite_index = rewrite_index
+
+ def localize(self, name: str, index: sympy.Expr):
+ if self.global_buf and name == self.global_buf.get_name():
+ assert self.rewrite_index is not None
+ name = self.local_buf.get_name()
+ index = self.rewrite_index(self, index)
+ return name, index
+
+ def load(self, name: str, index: sympy.Expr):
+ return self._inner.load(*self.localize(name, index))
+
+ def store(self, name, index, value, mode=None):
+ local_buffer_name, local_buffer_index = self.localize(name, index)
+ res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
+ if (
+ self.global_buf
+ and name == self.global_buf.get_name()
+ and isinstance(V.kernel, Kernel)
+ ):
+ # Remove name of local buffer from Kernel.store_buffer_names
+ # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store.
+ V.kernel.store_buffer_names.discard(local_buffer_name)
+ return res
+
+ def store_reduction(self, name, index, value):
+ return self._inner.store_reduction(*self.localize(name, index), value)
+
+
+class LocalBufferContext:
"""
This class creates a context that helps to generate code involving Inductor IR with
function local buffers. These buffers are constructed during the codegen process and
@@ -314,10 +394,13 @@
these buffers without exposure to the outside world.
"""
- def __init__(self, kernel: Kernel):
- self.kernel = kernel
+ def __init__(self, kernel_args: KernelArgs):
+ self.kernel_args = kernel_args
self.exit_stack = contextlib.ExitStack()
+ # Map Local Buffer name to Local Buffer
self.local_buffers: Dict[str, ir.Buffer] = {}
+ # Map Local Buffer name to Global Buffer
+ self.local_to_global: Dict[str, ir.Buffer] = {}
def __enter__(self):
self.exit_stack.__enter__()
@@ -330,23 +413,26 @@
self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype))
- original_input = self.kernel.args.input
+ original_input = self.kernel_args.input
def input(name):
if name in self.local_buffers:
return name
return original_input(name)
- self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input))
+ self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input))
- original_output = self.kernel.args.output
+ original_output = self.kernel_args.output
def output(name):
if name in self.local_buffers:
return name
return original_output(name)
- self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output))
+ self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output))
+
+ # Set current LocalBufferContext into V
+ self.exit_stack.enter_context(V.set_local_buffer_context(self))
return self
@@ -354,53 +440,64 @@
self.local_buffers.clear()
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
- def add_local_buffer(self, buffer: ir.Buffer):
- assert buffer.get_name() not in self.local_buffers
- self.local_buffers[buffer.get_name()] = buffer
+ def add_local_buffer(
+ self, local_buffer: ir.Buffer, global_buffer: Optional[ir.Buffer] = None
+ ):
+ assert local_buffer.get_name() not in self.local_buffers
+ self.local_buffers[local_buffer.get_name()] = local_buffer
+ if global_buffer:
+ self.local_to_global[local_buffer.get_name()] = global_buffer
+ V.graph.removed_buffers.add(global_buffer.get_name())
- def localize_buffer(
- self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode]
+ def localize_function(
+ self,
+ fn: Callable[..., Any],
+ rewrite_index: Callable[
+ ["LocalizeBufferHandler", sympy.Expr], sympy.Expr
+ ] = rewrite_index_for_function,
+ ):
+ local_buffers = list(self.local_buffers.values())
+ global_buffers = list(self.local_to_global.values())
+ local_buf = local_buffers[0]
+ global_buf = global_buffers[0]
+
+ def inner(node, *index_vars):
+ with V.set_ops_handler(
+ LocalizeBufferHandler(
+ V.get_ops_handler(),
+ global_buf=global_buf,
+ local_buf=local_buf,
+ rewrite_index=rewrite_index,
+ )
+ ):
+ return fn(node, *index_vars)
+
+ return inner
+
+ def localize_nodes(
+ self,
+ nodes: List[ir.IRNode],
+ rewrite_index: Callable[
+ ["LocalizeBufferHandler", sympy.Expr], sympy.Expr
+ ] = rewrite_index_for_nodes,
) -> List[ir.IRNode]:
"""
- Localizes the buffer `global_buf` to `local_buf` in the given `nodes` and returns
- a new list of IR nodes that work on `local_buf` instead of `global_buf`, i.e., all
- the loads and stores are redirected to `local_buf`. This helps the fused loops to
- work on smaller-sized local buffers for better data locality.
+ Given the `local_buf` and `global_buf` registered in the current `LocalBufferContext`
+ through `add_local_buffer`, localizes the `global_buf` to `local_buf`
+ for the given `nodes` and returns a new list of IR nodes that work on `local_buf`
+ instead of `global_buf`, i.e., all the loads and stores are redirected to
+ `local_buf`. This helps the fused loops to work on smaller-sized local buffers
+ for better data locality.
- The `local_buf` should already be registered in the local scope and the data access
- is assumed to be contiguous with the same order as the `global_buf`.
+ The data access of `local_buf` is assumed to be contiguous with the
+ same order as the `global_buf`.
"""
- assert local_buf.get_name() in self.local_buffers
- assert len(global_buf.get_size()) == len(local_buf.get_size())
+ local_buffers = list(self.local_buffers.values())
+ global_buffers = list(self.local_to_global.values())
+ assert len(global_buffers[0].get_size()) == len(local_buffers[0].get_size())
assert len(nodes) > 0
- class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
- def __init__(self, inner):
- super().__init__(inner)
-
- def localize(self, name: str, index: sympy.Expr):
- if name == global_buf.get_name():
- name = local_buf.get_name()
- used_vars = {
- s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)
- }
- index_vars = []
- for i in range(len(local_buf.get_size())):
- var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
- index_vars.append(var if var in used_vars else 0)
- index = local_buf.layout.make_indexer()(index_vars)
- return name, index
-
- def load(self, name: str, index: sympy.Expr):
- return self._inner.load(*self.localize(name, index))
-
- def store(self, name, index, value, mode=None):
- return self._inner.store(*self.localize(name, index), value, mode)
-
- def store_reduction(self, name, index, value):
- return self._inner.store_reduction(*self.localize(name, index), value)
-
- def wrap_inner_fn_for_node(node: ir.IRNode, inner_fn_wrapper):
+ def wrap_inner_fn_for_node(node: ir.IRNode):
loops = node.data if isinstance(node, ir.ComputedBuffer) else node
assert isinstance(loops, ir.Loops)
new_loops = copy.copy(loops)
@@ -411,17 +508,13 @@
else:
new_node = new_loops # type: ignore[assignment]
- new_loops.inner_fn = inner_fn_wrapper(new_loops.inner_fn)
+ new_loops.inner_fn = self.localize_function(
+ new_loops.inner_fn,
+ rewrite_index,
+ )
return new_node
- def inner_fn_wrapper(inner_fn):
- def inner(index):
- with V.set_ops_handler(LocalizeBufferHandler(V.get_ops_handler())):
- return inner_fn(index)
-
- return inner
-
- return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes]
+ return [wrap_inner_fn_for_node(node) for node in nodes]
def unify_mask_base_type(
diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py
index fc7d0e6..5c86667 100644
--- a/torch/_inductor/metrics.py
+++ b/torch/_inductor/metrics.py
@@ -41,9 +41,15 @@
# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0
+
+@dataclasses.dataclass
+class CppOuterLoopFusedCount:
+ inner_kernel_number: int
+ local_buffer_number: int = 0
+
+
# The length counts the number of outer loop fusions.
-# Each element counts the number of inner kernels in each outer loop fusion.
-cpp_outer_loop_fused_inner_counts: List[int] = []
+cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = []
num_comprehensive_padding = 0
num_matches_for_scatter_upon_const_tensor = 0
diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py
index ac8d3c6..51ff55a 100644
--- a/torch/_inductor/virtualized.py
+++ b/torch/_inductor/virtualized.py
@@ -72,6 +72,7 @@
if TYPE_CHECKING:
import torch
+ from torch._inductor.codegen.cpp_utils import LocalBufferContext
from torch._inductor.debug import DebugContext
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import InterpreterShim
@@ -162,6 +163,9 @@
_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler)
_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler)
_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler)
+_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
+ "local_buffer_context", NullHandler
+)
class OpsValue:
@@ -306,6 +310,8 @@
get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
set_current_node: Callable[[Any], Any] = _current_node._set_handler
get_current_node: Callable[[], Any] = _current_node._get_handler
+ set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
+ get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler
@property
def ops(self) -> OpsHandler[Any]:
@@ -348,5 +354,9 @@
def current_node(self):
return _current_node._get_handler()
+ @property
+ def local_buffer_context(self):
+ return _local_buffer_context._get_handler()
+
V = _V()