Revert "[inductor][cpp] support bf16/fp16 gemm template epilogue fusion (#126545)"
This reverts commit 43baabe9b94c86bd36ba4a00f501e52d833d7ec8.
Reverted https://github.com/pytorch/pytorch/pull/126545 on behalf of https://github.com/DanilBaibak because it breaks an internal build ([comment](https://github.com/pytorch/pytorch/pull/124021#issuecomment-2133002071))
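
For context, the reverted change taught Inductor's CPP packed GEMM template to fuse pointwise epilogues (relu, gelu, add, mul, ...) into the templated bf16/fp16 linear kernel; this revert drops that support along with the `_get_epilogue` test helper and the `create_epilogue_with_attr` lowering helper shown in the diff below. The sketch below is only an illustration of the linear-plus-pointwise pattern exercised by the removed `test_linear_with_pointwise` test; shapes and the use of plain `torch.compile` are illustrative assumptions (the real test additionally patches freezing/max-autotune config), not part of this diff.

```python
# Illustrative sketch (not from the diff): the kind of module whose epilogue
# the reverted change could fuse into the bf16/fp16 CPP GEMM template.
import torch


class LinearWithEpilogue(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias)
        # Any pointwise op from the removed test parametrization would do here.
        self.epilogue = torch.nn.ReLU()

    def forward(self, x):
        # Before the revert, the elementwise epilogue could be fused into the
        # templated matmul's output store for bfloat16/float16; after the
        # revert, the template test only covers float32.
        return self.epilogue(self.linear(x))


with torch.no_grad():
    mod = LinearWithEpilogue(196, 384).to(torch.bfloat16).eval()
    x = torch.randn(384, 196, dtype=torch.bfloat16)
    out = torch.compile(mod)(x)  # hypothetical usage; sizes chosen to mirror the test
```
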
diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py
index 70cb4ea..0147ca8 100644
--- a/test/inductor/test_cpu_select_algorithm.py
+++ b/test/inductor/test_cpu_select_algorithm.py
@@ -3,7 +3,6 @@
import sys
import unittest
-from typing import Optional
from unittest.mock import patch
import torch
@@ -70,37 +69,6 @@
return wrapped
-def _get_epilogue(epilogue: str, other: Optional[torch.Tensor] = None):
- if epilogue == "none":
- return lambda x: x
- elif epilogue == "relu":
- return torch.nn.ReLU()
- elif epilogue == "gelu":
- return torch.nn.GELU()
- elif epilogue == "silu":
- return torch.nn.SiLU()
- elif epilogue == "sigmoid":
- return torch.nn.Sigmoid()
- elif epilogue == "tanh":
- return torch.nn.Tanh()
- elif epilogue == "hardswish":
- return torch.nn.Hardswish()
- elif epilogue == "hardsigmoid":
- return torch.nn.Hardsigmoid()
- elif epilogue == "leaky_relu":
- return torch.nn.LeakyReLU()
- elif epilogue == "hardtanh":
- return torch.nn.Hardtanh()
- elif epilogue == "add":
- return lambda x: x + other
- elif epilogue == "sub":
- return lambda x: x - other
- elif epilogue == "mul":
- return lambda x: x * other
- elif epilogue == "div":
- return lambda x: x / other
-
-
class TestSelectAlgorithm(TestCase):
common = check_model
@@ -196,7 +164,7 @@
"div",
),
)
- @dtypes(torch.float, torch.bfloat16, torch.half)
+ @dtypes(torch.float)
def test_linear_with_pointwise(self, bias, epilogue, dtype):
batch_size = 384
in_features = 196
@@ -206,7 +174,32 @@
def __init__(self, bias, epilogue, other):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
- self.epilogue = _get_epilogue(epilogue, other)
+ if epilogue == "relu":
+ self.epilogue = torch.nn.ReLU()
+ elif epilogue == "gelu":
+ self.epilogue = torch.nn.GELU()
+ elif epilogue == "silu":
+ self.epilogue = torch.nn.SiLU()
+ elif epilogue == "sigmoid":
+ self.epilogue = torch.nn.Sigmoid()
+ elif epilogue == "tanh":
+ self.epilogue = torch.nn.Tanh()
+ elif epilogue == "hardswish":
+ self.epilogue = torch.nn.Hardswish()
+ elif epilogue == "hardsigmoid":
+ self.epilogue = torch.nn.Hardsigmoid()
+ elif epilogue == "leaky_relu":
+ self.epilogue = torch.nn.LeakyReLU()
+ elif epilogue == "hardtanh":
+ self.epilogue = torch.nn.Hardtanh()
+ elif epilogue == "add":
+ self.epilogue = lambda x: x + other
+ elif epilogue == "sub":
+ self.epilogue = lambda x: x - other
+ elif epilogue == "mul":
+ self.epilogue = lambda x: x * other
+ elif epilogue == "div":
+ self.epilogue = lambda x: x / other
def forward(self, x):
return self.epilogue(self.linear(x))
@@ -215,69 +208,7 @@
v = torch.randn(batch_size, in_features).to(dtype=dtype)
u = torch.randn(batch_size, out_features).to(dtype=dtype)
mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
- atol, rtol = 1e-4, 1e-4
- if dtype == torch.half or dtype == torch.bfloat16:
- atol, rtol = 1e-2, 1e-2
- with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
- self.common(mod, (v,), atol=atol, rtol=rtol)
- self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
- if (
- dtype == torch.bfloat16
- and epilogue != "mul"
- and epilogue != "div"
- or (dtype == torch.half and epilogue == "add" and not bias)
- ):
- # Several scenarios where epilogue fusion is not counted in:
- # 1. For bfloat16, the epilogue fusion is part of the template,
- # not fused via scheduler. This will also be true for float16 but
- # float16 oneDNN linear is not supported right now. The exception
- # is mul or div fusion which is not supported for oneDNN linear.
- # 2. For float16, since oneDNN linear is not applied, linear w/o bias
- # plus epilogue add is treated as linear w/ bias.
- self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0)
- else:
- self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
-
- @inductor_config.patch({"freezing": True})
- @patches
- @torch.no_grad
- @unittest.skipIf(not TEST_MKL, "Test requires MKL")
- @parametrize("bias", (True, False))
- @parametrize(
- "epilogue",
- (
- "none",
- "relu",
- "add",
- "sub",
- "mul",
- ),
- )
- @dtypes(torch.float, torch.bfloat16, torch.half)
- def test_linear_with_transpose(self, bias, epilogue, dtype):
- batch_size = 384
- in_features = 196
- out_features = 128
-
- class M(torch.nn.Module):
- def __init__(self, bias, epilogue, other):
- super().__init__()
- self.epilogue = _get_epilogue(epilogue, other)
- self.linear = torch.nn.Linear(in_features, out_features, bias)
-
- def forward(self, x, y):
- return self.epilogue(self.linear(x)).transpose(0, 1) + y
-
- counters.clear()
- v = torch.randn(batch_size, in_features).to(dtype=dtype)
- u = torch.randn(out_features, batch_size).to(dtype=dtype)
- other = torch.randn(batch_size, out_features).to(dtype=dtype)
- mod = M(bias=bias, epilogue=epilogue, other=other).to(dtype=dtype).eval()
- atol, rtol = 1e-4, 1e-4
- if dtype == torch.half or dtype == torch.bfloat16:
- atol, rtol = 1e-2, 1e-2
- with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
- self.common(mod, (v, u), atol=atol, rtol=rtol)
+ self.common(mod, (v,))
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@@ -286,44 +217,25 @@
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bias", (True, False))
- @parametrize(
- "unary",
- ("relu",),
- )
- @parametrize(
- "binary",
- (
- "add",
- "sub",
- "mul",
- "div",
- ),
- )
- @dtypes(torch.float, torch.bfloat16, torch.half)
- def test_linear_with_unary_binary(self, bias, unary, binary, dtype):
+ @dtypes(torch.float)
+ def test_linear_with_transpose(self, bias, dtype):
batch_size = 384
in_features = 196
- out_features = 384
+ out_features = 128
class M(torch.nn.Module):
- def __init__(self, bias, unary, binary, other):
+ def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
- self.unary = _get_epilogue(unary)
- self.binary = _get_epilogue(binary, other)
- def forward(self, x):
- return self.binary(self.unary(self.linear(x)))
+ def forward(self, x, y):
+ return self.linear(x).transpose(0, 1) + y
counters.clear()
+ mod = M(bias=bias).to(dtype=dtype).eval()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
- u = torch.randn(batch_size, out_features).to(dtype=dtype)
- mod = M(bias=bias, unary=unary, binary=binary, other=u).to(dtype=dtype).eval()
- atol, rtol = 1e-4, 1e-4
- if dtype == torch.half or dtype == torch.bfloat16:
- atol, rtol = 1e-2, 1e-2
- with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
- self.common(mod, (v,), atol=atol, rtol=rtol)
+ u = torch.randn(out_features, batch_size).to(dtype=dtype)
+ self.common(mod, (v, u))
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@@ -342,9 +254,6 @@
test_linear_with_transpose_dynamic_shapes = (
TestSelectAlgorithm.test_linear_with_transpose
)
- test_linear_with_unary_binary_dynamic_shapes = (
- TestSelectAlgorithm.test_linear_with_unary_binary
- )
instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 5046b6b..258d088 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -3029,7 +3029,7 @@
scheduler_node._lowp_fp_type = _lowp_fp_type # type: ignore[attr-defined]
return True
- def legalize_lowp_fp_dtype_loopbody(self, loop_body: ir.LoopBody):
+ def legalize_lowp_fp_dtype(self, nodes):
def add_to_dtype(sub_graph: torch.fx.Graph):
def is_lowp_fp_load(node: torch.fx.Node):
if node.target not in ["load"]:
@@ -3167,11 +3167,11 @@
eliminate_to_dtype(sub_graph)
- sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values())
- for sub_block in sub_blocks:
- add_to_dtype(sub_block.graph)
+ def _legalize_lowp_fp(loop_body: ir.LoopBody):
+ sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values())
+ for sub_block in sub_blocks:
+ add_to_dtype(sub_block.graph)
- def legalize_lowp_fp_dtype(self, nodes):
if all(
isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node)
for _node in nodes
@@ -3208,7 +3208,7 @@
should_legalize = not is_memory_copy_scheduler_node(node)
if should_legalize:
body: ir.LoopBody = node._body
- self.legalize_lowp_fp_dtype_loopbody(body)
+ _legalize_lowp_fp(body)
def codegen_functions(self, fn_list, var_sizes_list, vec_dtype=torch.float):
# TODO(jgong5): remove vec_dtype arg with alternative tiling factors for various dtypes
@@ -3373,8 +3373,8 @@
inner_tail_loop.set_kernel(vec_kernel)
def codegen_loop_bodies(self, loop_bodies, var_sizes_list):
+ # TODO(jgong5): support lowp legalization
for body in loop_bodies:
- self.legalize_lowp_fp_dtype_loopbody(body)
DataTypePropagation.propagate_loopbody(body)
self.codegen_functions(loop_bodies, var_sizes_list)
@@ -3779,6 +3779,7 @@
kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
with kernel:
for node in [template_node, *epilogue_nodes]:
+ node.decide_inplace_update()
node.mark_run() # type: ignore[attr-defined]
src_code = render()
diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py
index c28d766..18d6301 100644
--- a/torch/_inductor/codegen/cpp_gemm_template.py
+++ b/torch/_inductor/codegen/cpp_gemm_template.py
@@ -1,4 +1,4 @@
-from typing import Callable, cast, List, Optional
+from typing import cast, List, Optional
import torch
import torch.utils
@@ -20,7 +20,7 @@
{{micro_gemm.codegen_define(kernel)}}
extern "C"
-{{kernel.def_kernel(inputs={"X": X, "W": W, "inp": inp}, outputs={"Y": Y}, aliases=buffer_aliases)}}
+{{kernel.def_kernel(inputs={"X": X, "W": W, "inp": inp}, outputs={"Y": Y})}}
{
{{kernel.maybe_codegen_profile()}}
constexpr int64_t num_threads = {{num_threads}};
@@ -91,8 +91,8 @@
const int64_t n_start = nc * N0;
const int64_t n_size = N0;
{%- if use_local_acc %}
- {{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"]) }}
- {%- set acc = kernel.local_buffers[acc_buf_name] %}
+ {{ kernel.define_buffer("acc_local_buf", ["m_end - m_start", "N0"]) }}
+ {%- set acc = kernel.local_buffers["acc_local_buf"] %}
{%- else %}
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{%- endif %}
@@ -127,7 +127,7 @@
{%- endif %}
{%- set tile_Y = kernel.slice_nd(Y_maybe_transposed, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{{ kernel.store_output(
- tile_Y, acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexer=reindexer
+ tile_Y, acc, epilogue_nodes, offsets=("m_start", "n_start"), reindexer=reindexer
)|indent(16, false)
}}
}
@@ -146,16 +146,11 @@
register_blocking: GemmBlocking,
beta=1,
alpha=1,
- has_bias=False,
- epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
):
assert layout.dtype in [torch.float, torch.bfloat16, torch.half]
- super().__init__(
- "packed_gemm", input_nodes, layout, epilogue_creator=epilogue_creator
- )
+ super().__init__("packed_gemm", input_nodes, layout)
self.beta = beta
self.alpha = alpha
- self.has_bias = has_bias
self.num_threads = num_threads
self.register_blocking = register_blocking
m, n = layout.size
@@ -223,30 +218,26 @@
input_nodes,
beta=1,
alpha=1,
- has_bias=False,
trans_w=False,
input_indices=None,
- epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
):
if input_indices is None:
input_indices = list(range(len(input_nodes)))
def reorder_and_filter(inputs, layout_or_out):
- if has_bias:
- assert len(input_indices) >= 3
+ if len(input_indices) == 2:
+ x_idx = input_indices[0]
+ w_idx = input_indices[1]
+ return [inputs[x_idx], inputs[w_idx]], layout_or_out
+ else:
+ assert (
+ len(input_indices) == 3
+ ), "Cpp Packed GEMM template requires 2 or 3 input nodes."
# assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
inp_idx = input_indices[0]
x_idx = input_indices[1]
w_idx = input_indices[2]
- return [
- inputs[x_idx],
- inputs[w_idx],
- inputs[inp_idx],
- *[inputs[idx] for idx in input_indices[3:]],
- ], layout_or_out
- else:
- assert len(input_indices) >= 2
- return [inputs[idx] for idx in input_indices], layout_or_out
+ return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out
def maybe_to_dense(inputs, layout_or_out):
new_inputs = list(inputs)
@@ -382,8 +373,6 @@
register_blocking=micro_gemm.register_blocking,
beta=beta,
alpha=alpha,
- has_bias=has_bias,
- epilogue_creator=epilogue_creator,
)
template.maybe_append_choice(choices)
return template
@@ -398,7 +387,7 @@
assert len(self.input_nodes) >= 2
X, W = self.input_nodes[0], self.input_nodes[1]
- inp = self.input_nodes[2] if self.has_bias else None
+ inp = self.input_nodes[2] if len(self.input_nodes) > 2 else None
Y = self.output_node
if template_buffer_node is not None:
@@ -407,26 +396,11 @@
Y = template_buffer_node
template_buffer = Y
- gemm_output_buffer = template_buffer
-
- epilogues: List[ir.IRNode] = []
- if self.epilogue_creator is not None:
- gemm_output_name = "GemmOut"
- gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout)
- epilogues.append(
- ir.ComputedBuffer(
- name=template_buffer.get_name(),
- layout=template_buffer.layout,
- data=self.epilogue_creator(gemm_output_buffer),
- )
- )
-
Y_is_transposed = False
use_local_acc = self.layout.dtype != torch.float
- acc_buf_name = "local_acc_buf"
if epilogue_nodes:
- epilogues.extend(epilogue_nodes)
Y = cast(ir.Buffer, epilogue_nodes[-1])
+ assert Y.get_name() in V.kernel.inplace_update_buffers
if Y.get_size() == list(
reversed(template_buffer.get_size())
) and Y.get_stride() == list(reversed(template_buffer.get_stride())):
@@ -450,10 +424,7 @@
W=W,
inp=inp,
Y=Y,
- GemmOut=gemm_output_buffer,
- buffer_aliases=[(gemm_output_buffer, Y)]
- if gemm_output_buffer is not Y
- else None,
+ GemmOut=template_buffer,
beta=self.beta,
alpha=self.alpha,
num_threads=self.num_threads,
@@ -461,9 +432,8 @@
is_dynamic_M=self.is_dynamic_M,
template=self,
kernel=kernel,
- epilogue_nodes=epilogues,
+ epilogue_nodes=epilogue_nodes,
reindexer=(lambda x: list(reversed(x))) if Y_is_transposed else None,
use_local_acc=use_local_acc,
- acc_buf_name=acc_buf_name,
)
return self._template_from_string(GEMM_TEMPLATE).render(**options)
diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py
index 85f6a61..aeebd26 100644
--- a/torch/_inductor/codegen/cpp_template.py
+++ b/torch/_inductor/codegen/cpp_template.py
@@ -3,7 +3,7 @@
import logging
import sys
-from typing import Callable, List, Optional
+from typing import List, Optional
from unittest.mock import patch
import sympy
@@ -26,13 +26,11 @@
name: str,
input_nodes,
layout: ir.Layout,
- epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
):
super().__init__(name)
self.input_nodes = input_nodes
self.output_node: ir.Buffer = ir.Buffer("buf_out", layout)
self.layout = layout
- self.epilogue_creator = epilogue_creator
def generate(self, **kwargs):
kernel_name = f"cpp_{self.name}"
diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py
index 25d8a22..1c704cc 100644
--- a/torch/_inductor/codegen/cpp_template_kernel.py
+++ b/torch/_inductor/codegen/cpp_template_kernel.py
@@ -5,12 +5,11 @@
from sympy.parsing.sympy_parser import parse_expr
import torch
-from torch.utils._sympy.symbol import SymT
from .. import codecache, config, ir, lowering as L
from ..autotune_process import CppBenchmarkRequest
from ..select_algorithm import PartialRender
-from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
+from ..utils import sympy_index_symbol
from ..virtualized import V
from .common import Kernel, OpOverrides
from .cpp import CppKernelProxy, KernelGroup
@@ -52,26 +51,13 @@
self,
inputs: Dict[str, ir.Buffer],
outputs: Dict[str, ir.Buffer],
- aliases: Optional[List[Tuple[ir.Buffer, ir.Buffer]]] = None,
) -> str:
for name, inp in inputs.items():
if inp is not None:
self.args.input_buffers[inp.get_name()] = name
for name, out in outputs.items():
- self.args.output_buffers[out.get_name()] = name
- if aliases is not None:
- for alias, orig in aliases:
- orig_name = orig.get_name()
- alias_name = alias.get_name()
- if orig_name in self.args.input_buffers:
- self.args.input_buffers[alias_name] = self.args.input_buffers[
- orig_name
- ]
- if orig_name in self.args.output_buffers:
- self.args.output_buffers[alias_name] = self.args.output_buffers[
- orig_name
- ]
-
+ if out.get_name() not in self.args.inplace_buffers:
+ self.args.output_buffers[out.get_name()] = name
unique_sizevars = {
s
for input in inputs.values()
@@ -92,14 +78,6 @@
self.args.sizevars[sizevar] = f"k{sizevar}"
def hook():
- # remove all aliases before generate function definition
- if aliases is not None:
- for alias, _ in aliases:
- alias_name = alias.get_name()
- if alias_name in self.args.input_buffers:
- self.args.input_buffers[alias_name] = "REMOVED"
- if alias_name in self.args.output_buffers:
- self.args.output_buffers[alias_name] = "REMOVED"
cpp_argdefs, _, _ = self.args.cpp_argdefs()
return f"void {self.kernel_name}({', '.join(cpp_argdefs)})"
@@ -205,10 +183,7 @@
reindexer: Optional[Callable[[List[Any]], List[Any]]] = None,
) -> str:
var_sizes = (tuple(dst.get_size()), ())
- var_ranges = {
- sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
- for i, sz in enumerate(var_sizes[0])
- }
+ var_ranges = {sympy.Symbol(f"z{i}"): sz for i, sz in enumerate(var_sizes[0])}
if not offsets:
offsets = [sympy.Integer(0)] * len(var_sizes[0])
assert len(offsets) == len(var_sizes[0])
@@ -228,7 +203,7 @@
assert len(args[0]) == len(var_sizes[0])
assert len(args[1]) == 0
new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type]
- if reindexer is not None and i == len(nodes) - 1:
+ if reindexer is not None:
new_args = reindexer(new_args)
V.ops.store(
output_name,
@@ -248,7 +223,6 @@
self,
dst: ir.Buffer,
src: ir.Buffer,
- orig_src: Optional[ir.Buffer] = None,
epilogue_nodes: Optional[List[ir.IRNode]] = None,
offsets: Optional[List[Any]] = None,
reindexer: Optional[Callable[[List[Any]], List[Any]]] = None,
@@ -270,23 +244,12 @@
the sizes of `src` and `dst`.
b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is
needed on the indices to `epilogue_nodes` to match the indexing of `dst`.
- c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer
- in `epilogue_nodes` with `src`.
"""
assert dst.get_size() == src.get_size()
if offsets:
offsets = parse_expr_with_index_symbols(offsets)
if epilogue_nodes:
- with LocalBufferScope(self) as scope:
- assert orig_src is not None
- if orig_src.get_name() != src.get_name():
- scope.add_local_buffer(src)
- epilogue_nodes = scope.localize_buffer(
- orig_src, src, epilogue_nodes
- )
- return self.store_pointwise_nodes(
- dst, epilogue_nodes, offsets, reindexer # type: ignore[arg-type]
- )
+ return self.store_pointwise_nodes(dst, epilogue_nodes, offsets, reindexer)
else:
if dst.get_name() != src.get_name():
# src is local
diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py
index e681766..dbe3daf 100644
--- a/torch/_inductor/codegen/cpp_utils.py
+++ b/torch/_inductor/codegen/cpp_utils.py
@@ -1,15 +1,11 @@
import contextlib
-import copy
import math
from collections import namedtuple
-from typing import Dict, List
+from typing import Dict
from unittest.mock import patch
-import sympy
-
import torch
-from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import ir
from ..virtualized import V
@@ -305,69 +301,3 @@
def add_local_buffer(self, buffer: ir.Buffer):
assert buffer.get_name() not in self.local_buffers
self.local_buffers[buffer.get_name()] = buffer
-
- def localize_buffer(
- self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode]
- ) -> List[ir.IRNode]:
- """
- Localizes the buffer `global_buf` to `local_buf` in the given `nodes` and returns
- a new list of IR nodes that work on `local_buf` instead of `global_buf`, i.e., all
- the loads and stores are redirected to `local_buf`. This helps the fused loops to
- work on smaller-sized local buffers for better data locality.
-
- The `local_buf` should already be registered in the local scope and the data access
- is assumed to be contiguous with the same order as the `global_buf`.
- """
- assert local_buf.get_name() in self.local_buffers
- assert len(global_buf.get_size()) == len(local_buf.get_size())
- assert len(nodes) > 0
-
- class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
- def __init__(self, inner):
- super().__init__(inner)
-
- def localize(self, name: str, index: sympy.Expr):
- if name == global_buf.get_name():
- name = local_buf.get_name()
- index_vars = sorted(
- [
- s
- for s in index.free_symbols
- if symbol_is_type(s, SymT.INDEX)
- ],
- key=str,
- )
- index = local_buf.layout.make_indexer()(index_vars)
- return name, index
-
- def load(self, name: str, index: sympy.Expr):
- return self._inner.load(*self.localize(name, index))
-
- def store(self, name, index, value, mode=None):
- return self._inner.store(*self.localize(name, index), value, mode)
-
- def store_reduction(self, name, index, value):
- return self._inner.store_reduction(*self.localize(name, index), value)
-
- def wrap_inner_fn_for_node(node: ir.IRNode, inner_fn_wrapper):
- loops = node.data if isinstance(node, ir.ComputedBuffer) else node
- assert isinstance(loops, ir.Loops)
- new_loops = copy.copy(loops)
- if isinstance(node, ir.ComputedBuffer):
- new_node = ir.ComputedBuffer(
- node.get_name(), node.get_layout(), new_loops
- )
- else:
- new_node = new_loops # type: ignore[assignment]
-
- new_loops.inner_fn = inner_fn_wrapper(new_loops.inner_fn)
- return new_node
-
- def inner_fn_wrapper(inner_fn):
- def inner(index):
- with V.set_ops_handler(LocalizeBufferHandler(V.get_ops_handler())):
- return inner_fn(index)
-
- return inner
-
- return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes]
diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py
index 919da9f..fa14b44 100644
--- a/torch/_inductor/kernel/mm.py
+++ b/torch/_inductor/kernel/mm.py
@@ -327,7 +327,6 @@
[inp_expanded, mat1, mat2],
alpha=alpha,
beta=beta,
- has_bias=True,
)
add_aten_fallback = False
diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py
index 116941d..f9f2e66 100644
--- a/torch/_inductor/mkldnn_lowerings.py
+++ b/torch/_inductor/mkldnn_lowerings.py
@@ -21,149 +21,7 @@
ExternKernelChoice,
)
from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune
-from .virtualized import ops, V
-
-
-def create_epilogue_with_attr(input_buffer, attr, **kwargs):
- input_loader = input_buffer.make_loader()
- dtype = input_buffer.get_dtype()
- if attr == "relu":
-
- def inner_fn(index):
- input = input_loader(index)
- zero = ops.constant(0, dtype)
- return ops.maximum(input, zero)
-
- elif attr == "gelu":
- assert "algorithm" in kwargs
- if kwargs["algorithm"] == "none":
-
- def inner_fn(index):
- input = input_loader(index)
- if dtype != torch.float:
- input = ops.to_dtype(input, torch.float)
- half = ops.constant(0.5, torch.float)
- one = ops.constant(1.0, torch.float)
- const = ops.constant(0.7071067811865476, torch.float)
- result = input * half * (ops.erf(input * const) + one)
- if dtype != torch.float:
- result = ops.to_dtype(result, dtype)
- return result
-
- else:
- assert kwargs["algorithm"] == "tanh"
-
- def inner_fn(index):
- input = input_loader(index)
- if dtype != torch.float:
- input = ops.to_dtype(input, torch.float)
- half = ops.constant(0.5, torch.float)
- one = ops.constant(1.0, torch.float)
- const1 = ops.constant(0.7978845608028654, torch.float)
- const2 = ops.constant(0.044715, torch.float)
- result = (
- half
- * input
- * (
- one
- + ops.tanh(const1 * (input + const2 * input * input * input))
- )
- )
- if dtype != torch.float:
- result = ops.to_dtype(result, dtype)
- return result
-
- elif attr == "swish":
-
- def inner_fn(index):
- input = input_loader(index)
- result = input * ops.sigmoid(input)
- return result
-
- elif attr == "sigmoid":
-
- def inner_fn(index):
- return ops.sigmoid(input_loader(index))
-
- elif attr == "tanh":
-
- def inner_fn(index):
- return ops.tanh(input_loader(index))
-
- elif attr == "hardswish" or attr == "hardsigmoid":
-
- def hardsigmoid_float(input):
- zero = ops.constant(0, torch.float)
- six = ops.constant(6, torch.float)
- three = ops.constant(3, torch.float)
- one_over_six = ops.constant(0.16666666666666666, torch.float)
- max = ops.maximum(input + three, zero)
- min = ops.minimum(max, six)
- return min * one_over_six
-
- def inner_fn(index):
- input = input_loader(index)
- if dtype != torch.float:
- input = ops.to_dtype(input, torch.float)
- result = hardsigmoid_float(input)
- if attr == "hardswish":
- result = input * result
- if dtype != torch.float:
- result = ops.to_dtype(result, dtype)
- return result
-
- elif attr == "leaky_relu":
- assert "scalars" in kwargs
- assert len(kwargs["scalars"]) == 1
- negative_slope = kwargs["scalars"][0]
-
- def inner_fn(index):
- input = input_loader(index)
- if dtype != torch.float:
- input = ops.to_dtype(input, torch.float)
- zero = ops.constant(0, torch.float)
- result = ops.where(
- input > zero, input, input * ops.constant(negative_slope, torch.float)
- )
- if dtype != torch.float:
- result = ops.to_dtype(result, dtype)
- return result
-
- elif attr == "hardtanh":
- assert "scalars" in kwargs
- assert len(kwargs["scalars"]) == 2
- min_value = kwargs["scalars"][0]
- max_value = kwargs["scalars"][1]
-
- def inner_fn(index):
- input = input_loader(index)
- if dtype != torch.float:
- input = ops.to_dtype(input, torch.float)
- result = ops.minimum(
- ops.maximum(input, ops.constant(min_value, torch.float)),
- ops.constant(max_value, torch.float),
- )
- if dtype != torch.float:
- result = ops.to_dtype(result, dtype)
- return result
-
- elif attr == "add" or attr == "sub":
- assert "other" in kwargs
- other = kwargs["other"]
- other_loader = other.make_loader()
-
- def inner_fn(index):
- op = getattr(ops, attr)
- return op(input_loader(index), other_loader(index))
-
- else:
- raise ValueError(f"Unsupported epilogue attribute: {attr}")
- return ir.Pointwise(
- device=input_buffer.get_device(),
- dtype=dtype,
- inner_fn=inner_fn,
- ranges=input_buffer.get_size(),
- )
+from .virtualized import V
def register_onednn_fusion_ops():
@@ -174,12 +32,6 @@
has_out_variant=False,
kernel_creator=ir.LinearUnary.create,
)
- aten_mkldnn_linear_binary = ExternKernelChoice(
- torch.ops.mkldnn._linear_pointwise.binary,
- "mkldnn::_linear_pointwise",
- has_out_variant=False,
- kernel_creator=ir.LinearBinary.create,
- )
cpu_needs_realized_inputs = [
torch.ops.mkldnn._convolution_pointwise,
torch.ops.mkldnn._convolution_pointwise_,
@@ -299,44 +151,51 @@
if len(x_size) > 2:
# GEMM template needs 2D input, normalize input shape here
x = view(x, [-1, x_size[-1]])
- if b is not None:
- b = ir.ExternKernel.realize_input(b)
- inputs = [x, w] if b is None else [x, w, b]
choices: List[ChoiceCaller] = []
if len(choices) == 0 or use_aten_gemm_kernels():
- kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm)
- if b is None:
- kwargs["B"] = None
choices.append(
aten_mkldnn_linear_unary.bind(
- inputs,
+ (x, w),
layout,
- **kwargs,
+ B=None,
+ attr=attr,
+ scalars=scalars,
+ algorithm=algorithm,
+ )
+ if b is None
+ else aten_mkldnn_linear_unary.bind(
+ (x, w, b),
+ layout,
+ attr=attr,
+ scalars=scalars,
+ algorithm=algorithm,
)
)
if use_max_autotune():
transposed_w = permute(w, [1, 0])
*_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
- if use_cpp_packed_gemm_template(layout, x, transposed_w):
-
- def epilogue_creator(buf):
- return create_epilogue_with_attr(
- buf, attr, scalars=scalars, algorithm=algorithm
+ if b is not None:
+ b = ir.ExternKernel.realize_input(b)
+ # TODO(jgong5): support epilogue fusion
+ if (
+ use_cpp_packed_gemm_template(layout, x, transposed_w)
+ and attr == "none"
+ ):
+ if b is None:
+ CppPackedGemmTemplate.add_choices(
+ choices,
+ layout,
+ [x, w],
+ trans_w=True,
)
-
- kwargs = dict(
- has_bias=b is not None,
- trans_w=True,
- epilogue_creator=None if attr == "none" else epilogue_creator,
- )
- if b is not None:
- kwargs["input_indices"] = [2, 0, 1]
- CppPackedGemmTemplate.add_choices(
- choices,
- layout,
- inputs,
- **kwargs,
- )
+ else:
+ CppPackedGemmTemplate.add_choices(
+ choices,
+ layout,
+ [x, w, b],
+ trans_w=True,
+ input_indices=[2, 0, 1],
+ )
assert w.get_name() in V.graph.constants
input_gen_fns = {
1: lambda x: V.graph.constants[x.get_name()],
@@ -344,7 +203,7 @@
result = autotune_select_algorithm(
"linear_unary",
choices,
- inputs,
+ [x, w] if b is None else [x, w, b],
layout,
input_gen_fns=input_gen_fns,
)
@@ -353,67 +212,8 @@
return result
@register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
- def linear_binary(
- x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None
- ):
- x_size = x.get_size()
- if len(x_size) > 2:
- # GEMM template needs 2D input, normalize input shape here
- x = view(x, [-1, x_size[-1]])
- y_size = y.get_size()
- if len(y_size) > 2:
- y = view(y, [-1, y_size[-1]])
- if b is not None:
- b = ir.ExternKernel.realize_input(b)
- inputs = [x, y, w] if b is None else [x, y, w, b]
- choices: List[ChoiceCaller] = []
- if len(choices) == 0 or use_aten_gemm_kernels():
- kwargs = dict(attr=attr)
- if b is None:
- kwargs["B"] = None
- choices.append(
- aten_mkldnn_linear_binary.bind(
- inputs,
- layout,
- **kwargs,
- )
- )
- if use_max_autotune():
- transposed_w = permute(w, [1, 0])
- *_, layout, x, transposed_w, y = mm_args(
- x, transposed_w, y, layout=layout
- )
- if use_cpp_packed_gemm_template(layout, x, transposed_w):
-
- def epilogue_creator(buf):
- return create_epilogue_with_attr(buf, attr, other=y)
-
- kwargs = dict(
- has_bias=b is not None,
- trans_w=True,
- epilogue_creator=epilogue_creator,
- )
- kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
- CppPackedGemmTemplate.add_choices(
- choices,
- layout,
- inputs,
- **kwargs,
- )
- assert w.get_name() in V.graph.constants
- input_gen_fns = {
- 2: lambda x: V.graph.constants[x.get_name()],
- }
- result = autotune_select_algorithm(
- "linear_binary",
- choices,
- inputs,
- layout,
- input_gen_fns=input_gen_fns,
- )
- if len(x_size) > 2:
- result = view(result, (*x_size[:-1], result.get_size()[-1]))
- return result
+ def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr):
+ return TensorBox.create(ir.LinearBinary.create(x, y, w, b, attr))
@register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
def convolution_transpose_unary(