Revert "[inductor][cpp] support bf16/fp16 gemm template epilogue fusion (#126545)"

This reverts commit 43baabe9b94c86bd36ba4a00f501e52d833d7ec8.

Reverted https://github.com/pytorch/pytorch/pull/126545 on behalf of https://github.com/DanilBaibak due to breaking an internal build ([comment](https://github.com/pytorch/pytorch/pull/124021#issuecomment-2133002071))
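
In addition to reverting the bf16/fp16 test coverage, this restores the pre-#126545 state of the CPP GEMM template and its lowerings: `CppTemplate`/`CppPackedGemmTemplate` lose the `epilogue_creator` and `has_bias` plumbing, `create_epilogue_with_attr` and the GEMM-template path of the `linear_binary` lowering are removed from `mkldnn_lowerings.py` (the lowering falls back to `ir.LinearBinary.create`), `LocalBufferScope.localize_buffer` is removed from `cpp_utils.py`, and `legalize_lowp_fp_dtype_loopbody` is folded back into `legalize_lowp_fp_dtype` in `cpp.py`.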
diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py
index 70cb4ea..0147ca8 100644
--- a/test/inductor/test_cpu_select_algorithm.py
+++ b/test/inductor/test_cpu_select_algorithm.py
@@ -3,7 +3,6 @@
 
 import sys
 import unittest
-from typing import Optional
 from unittest.mock import patch
 
 import torch
@@ -70,37 +69,6 @@
     return wrapped
 
 
-def _get_epilogue(epilogue: str, other: Optional[torch.Tensor] = None):
-    if epilogue == "none":
-        return lambda x: x
-    elif epilogue == "relu":
-        return torch.nn.ReLU()
-    elif epilogue == "gelu":
-        return torch.nn.GELU()
-    elif epilogue == "silu":
-        return torch.nn.SiLU()
-    elif epilogue == "sigmoid":
-        return torch.nn.Sigmoid()
-    elif epilogue == "tanh":
-        return torch.nn.Tanh()
-    elif epilogue == "hardswish":
-        return torch.nn.Hardswish()
-    elif epilogue == "hardsigmoid":
-        return torch.nn.Hardsigmoid()
-    elif epilogue == "leaky_relu":
-        return torch.nn.LeakyReLU()
-    elif epilogue == "hardtanh":
-        return torch.nn.Hardtanh()
-    elif epilogue == "add":
-        return lambda x: x + other
-    elif epilogue == "sub":
-        return lambda x: x - other
-    elif epilogue == "mul":
-        return lambda x: x * other
-    elif epilogue == "div":
-        return lambda x: x / other
-
-
 class TestSelectAlgorithm(TestCase):
     common = check_model
 
@@ -196,7 +164,7 @@
             "div",
         ),
     )
-    @dtypes(torch.float, torch.bfloat16, torch.half)
+    @dtypes(torch.float)
     def test_linear_with_pointwise(self, bias, epilogue, dtype):
         batch_size = 384
         in_features = 196
@@ -206,7 +174,32 @@
             def __init__(self, bias, epilogue, other):
                 super().__init__()
                 self.linear = torch.nn.Linear(in_features, out_features, bias)
-                self.epilogue = _get_epilogue(epilogue, other)
+                if epilogue == "relu":
+                    self.epilogue = torch.nn.ReLU()
+                elif epilogue == "gelu":
+                    self.epilogue = torch.nn.GELU()
+                elif epilogue == "silu":
+                    self.epilogue = torch.nn.SiLU()
+                elif epilogue == "sigmoid":
+                    self.epilogue = torch.nn.Sigmoid()
+                elif epilogue == "tanh":
+                    self.epilogue = torch.nn.Tanh()
+                elif epilogue == "hardswish":
+                    self.epilogue = torch.nn.Hardswish()
+                elif epilogue == "hardsigmoid":
+                    self.epilogue = torch.nn.Hardsigmoid()
+                elif epilogue == "leaky_relu":
+                    self.epilogue = torch.nn.LeakyReLU()
+                elif epilogue == "hardtanh":
+                    self.epilogue = torch.nn.Hardtanh()
+                elif epilogue == "add":
+                    self.epilogue = lambda x: x + other
+                elif epilogue == "sub":
+                    self.epilogue = lambda x: x - other
+                elif epilogue == "mul":
+                    self.epilogue = lambda x: x * other
+                elif epilogue == "div":
+                    self.epilogue = lambda x: x / other
 
             def forward(self, x):
                 return self.epilogue(self.linear(x))
@@ -215,69 +208,7 @@
         v = torch.randn(batch_size, in_features).to(dtype=dtype)
         u = torch.randn(batch_size, out_features).to(dtype=dtype)
         mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
-        atol, rtol = 1e-4, 1e-4
-        if dtype == torch.half or dtype == torch.bfloat16:
-            atol, rtol = 1e-2, 1e-2
-        with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
-            self.common(mod, (v,), atol=atol, rtol=rtol)
-        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
-        if (
-            dtype == torch.bfloat16
-            and epilogue != "mul"
-            and epilogue != "div"
-            or (dtype == torch.half and epilogue == "add" and not bias)
-        ):
-            # Several scenarios where epilogue fusion is not counted in:
-            # 1. For bfloat16, the epilogue fusion is part of the template,
-            #    not fused via scheduler. This will also be true for float16 but
-            #    float16 oneDNN linear is not supported right now. The exception
-            #    is mul or div fusion which is not supported for oneDNN linear.
-            # 2. For float16, since oneDNN linear is not applied, linear w/o bias
-            #    plus epilogue add is treated as linear w/ bias.
-            self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0)
-        else:
-            self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
-
-    @inductor_config.patch({"freezing": True})
-    @patches
-    @torch.no_grad
-    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
-    @parametrize("bias", (True, False))
-    @parametrize(
-        "epilogue",
-        (
-            "none",
-            "relu",
-            "add",
-            "sub",
-            "mul",
-        ),
-    )
-    @dtypes(torch.float, torch.bfloat16, torch.half)
-    def test_linear_with_transpose(self, bias, epilogue, dtype):
-        batch_size = 384
-        in_features = 196
-        out_features = 128
-
-        class M(torch.nn.Module):
-            def __init__(self, bias, epilogue, other):
-                super().__init__()
-                self.epilogue = _get_epilogue(epilogue, other)
-                self.linear = torch.nn.Linear(in_features, out_features, bias)
-
-            def forward(self, x, y):
-                return self.epilogue(self.linear(x)).transpose(0, 1) + y
-
-        counters.clear()
-        v = torch.randn(batch_size, in_features).to(dtype=dtype)
-        u = torch.randn(out_features, batch_size).to(dtype=dtype)
-        other = torch.randn(batch_size, out_features).to(dtype=dtype)
-        mod = M(bias=bias, epilogue=epilogue, other=other).to(dtype=dtype).eval()
-        atol, rtol = 1e-4, 1e-4
-        if dtype == torch.half or dtype == torch.bfloat16:
-            atol, rtol = 1e-2, 1e-2
-        with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
-            self.common(mod, (v, u), atol=atol, rtol=rtol)
+        self.common(mod, (v,))
         self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
         self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
 
@@ -286,44 +217,25 @@
     @torch.no_grad
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
     @parametrize("bias", (True, False))
-    @parametrize(
-        "unary",
-        ("relu",),
-    )
-    @parametrize(
-        "binary",
-        (
-            "add",
-            "sub",
-            "mul",
-            "div",
-        ),
-    )
-    @dtypes(torch.float, torch.bfloat16, torch.half)
-    def test_linear_with_unary_binary(self, bias, unary, binary, dtype):
+    @dtypes(torch.float)
+    def test_linear_with_transpose(self, bias, dtype):
         batch_size = 384
         in_features = 196
-        out_features = 384
+        out_features = 128
 
         class M(torch.nn.Module):
-            def __init__(self, bias, unary, binary, other):
+            def __init__(self, bias):
                 super().__init__()
                 self.linear = torch.nn.Linear(in_features, out_features, bias)
-                self.unary = _get_epilogue(unary)
-                self.binary = _get_epilogue(binary, other)
 
-            def forward(self, x):
-                return self.binary(self.unary(self.linear(x)))
+            def forward(self, x, y):
+                return self.linear(x).transpose(0, 1) + y
 
         counters.clear()
+        mod = M(bias=bias).to(dtype=dtype).eval()
         v = torch.randn(batch_size, in_features).to(dtype=dtype)
-        u = torch.randn(batch_size, out_features).to(dtype=dtype)
-        mod = M(bias=bias, unary=unary, binary=binary, other=u).to(dtype=dtype).eval()
-        atol, rtol = 1e-4, 1e-4
-        if dtype == torch.half or dtype == torch.bfloat16:
-            atol, rtol = 1e-2, 1e-2
-        with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
-            self.common(mod, (v,), atol=atol, rtol=rtol)
+        u = torch.randn(out_features, batch_size).to(dtype=dtype)
+        self.common(mod, (v, u))
         self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
         self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
 
@@ -342,9 +254,6 @@
     test_linear_with_transpose_dynamic_shapes = (
         TestSelectAlgorithm.test_linear_with_transpose
     )
-    test_linear_with_unary_binary_dynamic_shapes = (
-        TestSelectAlgorithm.test_linear_with_unary_binary
-    )
 
 
 instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 5046b6b..258d088 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -3029,7 +3029,7 @@
         scheduler_node._lowp_fp_type = _lowp_fp_type  # type: ignore[attr-defined]
         return True
 
-    def legalize_lowp_fp_dtype_loopbody(self, loop_body: ir.LoopBody):
+    def legalize_lowp_fp_dtype(self, nodes):
         def add_to_dtype(sub_graph: torch.fx.Graph):
             def is_lowp_fp_load(node: torch.fx.Node):
                 if node.target not in ["load"]:
@@ -3167,11 +3167,11 @@
 
             eliminate_to_dtype(sub_graph)
 
-        sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values())
-        for sub_block in sub_blocks:
-            add_to_dtype(sub_block.graph)
+        def _legalize_lowp_fp(loop_body: ir.LoopBody):
+            sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values())
+            for sub_block in sub_blocks:
+                add_to_dtype(sub_block.graph)
 
-    def legalize_lowp_fp_dtype(self, nodes):
         if all(
             isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node)
             for _node in nodes
@@ -3208,7 +3208,7 @@
             should_legalize = not is_memory_copy_scheduler_node(node)
             if should_legalize:
                 body: ir.LoopBody = node._body
-                self.legalize_lowp_fp_dtype_loopbody(body)
+                _legalize_lowp_fp(body)
 
     def codegen_functions(self, fn_list, var_sizes_list, vec_dtype=torch.float):
         # TODO(jgong5): remove vec_dtype arg with alternative tiling factors for various dtypes
@@ -3373,8 +3373,8 @@
                 inner_tail_loop.set_kernel(vec_kernel)
 
     def codegen_loop_bodies(self, loop_bodies, var_sizes_list):
+        # TODO(jgong5): support lowp legalization
         for body in loop_bodies:
-            self.legalize_lowp_fp_dtype_loopbody(body)
             DataTypePropagation.propagate_loopbody(body)
         self.codegen_functions(loop_bodies, var_sizes_list)
 
@@ -3779,6 +3779,7 @@
         kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
         with kernel:
             for node in [template_node, *epilogue_nodes]:
+                node.decide_inplace_update()
                 node.mark_run()  # type: ignore[attr-defined]
             src_code = render()
 
diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py
index c28d766..18d6301 100644
--- a/torch/_inductor/codegen/cpp_gemm_template.py
+++ b/torch/_inductor/codegen/cpp_gemm_template.py
@@ -1,4 +1,4 @@
-from typing import Callable, cast, List, Optional
+from typing import cast, List, Optional
 
 import torch
 import torch.utils
@@ -20,7 +20,7 @@
 {{micro_gemm.codegen_define(kernel)}}
 
 extern "C"
-{{kernel.def_kernel(inputs={"X": X, "W": W, "inp": inp}, outputs={"Y": Y}, aliases=buffer_aliases)}}
+{{kernel.def_kernel(inputs={"X": X, "W": W, "inp": inp}, outputs={"Y": Y})}}
 {
     {{kernel.maybe_codegen_profile()}}
     constexpr int64_t num_threads = {{num_threads}};
@@ -91,8 +91,8 @@
                 const int64_t n_start = nc * N0;
                 const int64_t n_size = N0;
                 {%- if use_local_acc %}
-                {{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"]) }}
-                {%- set acc = kernel.local_buffers[acc_buf_name] %}
+                {{ kernel.define_buffer("acc_local_buf", ["m_end - m_start", "N0"]) }}
+                {%- set acc = kernel.local_buffers["acc_local_buf"] %}
                 {%- else %}
                 {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
                 {%- endif %}
@@ -127,7 +127,7 @@
                 {%- endif %}
                 {%- set tile_Y = kernel.slice_nd(Y_maybe_transposed, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
                 {{ kernel.store_output(
-                      tile_Y, acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexer=reindexer
+                      tile_Y, acc, epilogue_nodes, offsets=("m_start", "n_start"), reindexer=reindexer
                    )|indent(16, false)
                 }}
             }
@@ -146,16 +146,11 @@
         register_blocking: GemmBlocking,
         beta=1,
         alpha=1,
-        has_bias=False,
-        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
     ):
         assert layout.dtype in [torch.float, torch.bfloat16, torch.half]
-        super().__init__(
-            "packed_gemm", input_nodes, layout, epilogue_creator=epilogue_creator
-        )
+        super().__init__("packed_gemm", input_nodes, layout)
         self.beta = beta
         self.alpha = alpha
-        self.has_bias = has_bias
         self.num_threads = num_threads
         self.register_blocking = register_blocking
         m, n = layout.size
@@ -223,30 +218,26 @@
         input_nodes,
         beta=1,
         alpha=1,
-        has_bias=False,
         trans_w=False,
         input_indices=None,
-        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
     ):
         if input_indices is None:
             input_indices = list(range(len(input_nodes)))
 
         def reorder_and_filter(inputs, layout_or_out):
-            if has_bias:
-                assert len(input_indices) >= 3
+            if len(input_indices) == 2:
+                x_idx = input_indices[0]
+                w_idx = input_indices[1]
+                return [inputs[x_idx], inputs[w_idx]], layout_or_out
+            else:
+                assert (
+                    len(input_indices) == 3
+                ), "Cpp Packed GEMM template requires 2 or 3 input nodes."
                 # assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
                 inp_idx = input_indices[0]
                 x_idx = input_indices[1]
                 w_idx = input_indices[2]
-                return [
-                    inputs[x_idx],
-                    inputs[w_idx],
-                    inputs[inp_idx],
-                    *[inputs[idx] for idx in input_indices[3:]],
-                ], layout_or_out
-            else:
-                assert len(input_indices) >= 2
-                return [inputs[idx] for idx in input_indices], layout_or_out
+                return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out
 
         def maybe_to_dense(inputs, layout_or_out):
             new_inputs = list(inputs)
@@ -382,8 +373,6 @@
             register_blocking=micro_gemm.register_blocking,
             beta=beta,
             alpha=alpha,
-            has_bias=has_bias,
-            epilogue_creator=epilogue_creator,
         )
         template.maybe_append_choice(choices)
         return template
@@ -398,7 +387,7 @@
         assert len(self.input_nodes) >= 2
 
         X, W = self.input_nodes[0], self.input_nodes[1]
-        inp = self.input_nodes[2] if self.has_bias else None
+        inp = self.input_nodes[2] if len(self.input_nodes) > 2 else None
         Y = self.output_node
 
         if template_buffer_node is not None:
@@ -407,26 +396,11 @@
             Y = template_buffer_node
 
         template_buffer = Y
-        gemm_output_buffer = template_buffer
-
-        epilogues: List[ir.IRNode] = []
-        if self.epilogue_creator is not None:
-            gemm_output_name = "GemmOut"
-            gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout)
-            epilogues.append(
-                ir.ComputedBuffer(
-                    name=template_buffer.get_name(),
-                    layout=template_buffer.layout,
-                    data=self.epilogue_creator(gemm_output_buffer),
-                )
-            )
-
         Y_is_transposed = False
         use_local_acc = self.layout.dtype != torch.float
-        acc_buf_name = "local_acc_buf"
         if epilogue_nodes:
-            epilogues.extend(epilogue_nodes)
             Y = cast(ir.Buffer, epilogue_nodes[-1])
+            assert Y.get_name() in V.kernel.inplace_update_buffers
             if Y.get_size() == list(
                 reversed(template_buffer.get_size())
             ) and Y.get_stride() == list(reversed(template_buffer.get_stride())):
@@ -450,10 +424,7 @@
             W=W,
             inp=inp,
             Y=Y,
-            GemmOut=gemm_output_buffer,
-            buffer_aliases=[(gemm_output_buffer, Y)]
-            if gemm_output_buffer is not Y
-            else None,
+            GemmOut=template_buffer,
             beta=self.beta,
             alpha=self.alpha,
             num_threads=self.num_threads,
@@ -461,9 +432,8 @@
             is_dynamic_M=self.is_dynamic_M,
             template=self,
             kernel=kernel,
-            epilogue_nodes=epilogues,
+            epilogue_nodes=epilogue_nodes,
             reindexer=(lambda x: list(reversed(x))) if Y_is_transposed else None,
             use_local_acc=use_local_acc,
-            acc_buf_name=acc_buf_name,
         )
         return self._template_from_string(GEMM_TEMPLATE).render(**options)
diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py
index 85f6a61..aeebd26 100644
--- a/torch/_inductor/codegen/cpp_template.py
+++ b/torch/_inductor/codegen/cpp_template.py
@@ -3,7 +3,7 @@
 import logging
 
 import sys
-from typing import Callable, List, Optional
+from typing import List, Optional
 from unittest.mock import patch
 
 import sympy
@@ -26,13 +26,11 @@
         name: str,
         input_nodes,
         layout: ir.Layout,
-        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
     ):
         super().__init__(name)
         self.input_nodes = input_nodes
         self.output_node: ir.Buffer = ir.Buffer("buf_out", layout)
         self.layout = layout
-        self.epilogue_creator = epilogue_creator
 
     def generate(self, **kwargs):
         kernel_name = f"cpp_{self.name}"
diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py
index 25d8a22..1c704cc 100644
--- a/torch/_inductor/codegen/cpp_template_kernel.py
+++ b/torch/_inductor/codegen/cpp_template_kernel.py
@@ -5,12 +5,11 @@
 from sympy.parsing.sympy_parser import parse_expr
 
 import torch
-from torch.utils._sympy.symbol import SymT
 from .. import codecache, config, ir, lowering as L
 
 from ..autotune_process import CppBenchmarkRequest
 from ..select_algorithm import PartialRender
-from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
+from ..utils import sympy_index_symbol
 from ..virtualized import V
 from .common import Kernel, OpOverrides
 from .cpp import CppKernelProxy, KernelGroup
@@ -52,26 +51,13 @@
         self,
         inputs: Dict[str, ir.Buffer],
         outputs: Dict[str, ir.Buffer],
-        aliases: Optional[List[Tuple[ir.Buffer, ir.Buffer]]] = None,
     ) -> str:
         for name, inp in inputs.items():
             if inp is not None:
                 self.args.input_buffers[inp.get_name()] = name
         for name, out in outputs.items():
-            self.args.output_buffers[out.get_name()] = name
-        if aliases is not None:
-            for alias, orig in aliases:
-                orig_name = orig.get_name()
-                alias_name = alias.get_name()
-                if orig_name in self.args.input_buffers:
-                    self.args.input_buffers[alias_name] = self.args.input_buffers[
-                        orig_name
-                    ]
-                if orig_name in self.args.output_buffers:
-                    self.args.output_buffers[alias_name] = self.args.output_buffers[
-                        orig_name
-                    ]
-
+            if out.get_name() not in self.args.inplace_buffers:
+                self.args.output_buffers[out.get_name()] = name
         unique_sizevars = {
             s
             for input in inputs.values()
@@ -92,14 +78,6 @@
             self.args.sizevars[sizevar] = f"k{sizevar}"
 
         def hook():
-            # remove all aliases before generate function definition
-            if aliases is not None:
-                for alias, _ in aliases:
-                    alias_name = alias.get_name()
-                    if alias_name in self.args.input_buffers:
-                        self.args.input_buffers[alias_name] = "REMOVED"
-                    if alias_name in self.args.output_buffers:
-                        self.args.output_buffers[alias_name] = "REMOVED"
             cpp_argdefs, _, _ = self.args.cpp_argdefs()
             return f"void {self.kernel_name}({', '.join(cpp_argdefs)})"
 
@@ -205,10 +183,7 @@
         reindexer: Optional[Callable[[List[Any]], List[Any]]] = None,
     ) -> str:
         var_sizes = (tuple(dst.get_size()), ())
-        var_ranges = {
-            sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
-            for i, sz in enumerate(var_sizes[0])
-        }
+        var_ranges = {sympy.Symbol(f"z{i}"): sz for i, sz in enumerate(var_sizes[0])}
         if not offsets:
             offsets = [sympy.Integer(0)] * len(var_sizes[0])
         assert len(offsets) == len(var_sizes[0])
@@ -228,7 +203,7 @@
                 assert len(args[0]) == len(var_sizes[0])
                 assert len(args[1]) == 0
                 new_args = [arg + offset for arg, offset in zip(args[0], offsets)]  # type: ignore[arg-type]
-                if reindexer is not None and i == len(nodes) - 1:
+                if reindexer is not None:
                     new_args = reindexer(new_args)
                 V.ops.store(
                     output_name,
@@ -248,7 +223,6 @@
         self,
         dst: ir.Buffer,
         src: ir.Buffer,
-        orig_src: Optional[ir.Buffer] = None,
         epilogue_nodes: Optional[List[ir.IRNode]] = None,
         offsets: Optional[List[Any]] = None,
         reindexer: Optional[Callable[[List[Any]], List[Any]]] = None,
@@ -270,23 +244,12 @@
               the sizes of `src` and `dst`.
            b) `dst` might be indexed differently from the `epilogue_nodes`, hence a `reindexer` is
               needed on the indices to `epilogue_nodes` to match the indexing of `dst`.
-           c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer
-              in `epilogue_nodes` with `src`.
         """
         assert dst.get_size() == src.get_size()
         if offsets:
             offsets = parse_expr_with_index_symbols(offsets)
         if epilogue_nodes:
-            with LocalBufferScope(self) as scope:
-                assert orig_src is not None
-                if orig_src.get_name() != src.get_name():
-                    scope.add_local_buffer(src)
-                    epilogue_nodes = scope.localize_buffer(
-                        orig_src, src, epilogue_nodes
-                    )
-                return self.store_pointwise_nodes(
-                    dst, epilogue_nodes, offsets, reindexer  # type: ignore[arg-type]
-                )
+            return self.store_pointwise_nodes(dst, epilogue_nodes, offsets, reindexer)
         else:
             if dst.get_name() != src.get_name():
                 # src is local
diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py
index e681766..dbe3daf 100644
--- a/torch/_inductor/codegen/cpp_utils.py
+++ b/torch/_inductor/codegen/cpp_utils.py
@@ -1,15 +1,11 @@
 import contextlib
-import copy
 import math
 
 from collections import namedtuple
-from typing import Dict, List
+from typing import Dict
 from unittest.mock import patch
 
-import sympy
-
 import torch
-from torch.utils._sympy.symbol import symbol_is_type, SymT
 from .. import ir
 from ..virtualized import V
 
@@ -305,69 +301,3 @@
     def add_local_buffer(self, buffer: ir.Buffer):
         assert buffer.get_name() not in self.local_buffers
         self.local_buffers[buffer.get_name()] = buffer
-
-    def localize_buffer(
-        self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode]
-    ) -> List[ir.IRNode]:
-        """
-        Localizes the buffer `global_buf` to `local_buf` in the given `nodes` and returns
-        a new list of IR nodes that work on `local_buf` instead of `global_buf`, i.e., all
-        the loads and stores are redirected to `local_buf`. This helps the fused loops to
-        work on smaller-sized local buffers for better data locality.
-
-        The `local_buf` should already be registered in the local scope and the data access
-        is assumed to be contiguous with the same order as the `global_buf`.
-        """
-        assert local_buf.get_name() in self.local_buffers
-        assert len(global_buf.get_size()) == len(local_buf.get_size())
-        assert len(nodes) > 0
-
-        class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined]
-            def __init__(self, inner):
-                super().__init__(inner)
-
-            def localize(self, name: str, index: sympy.Expr):
-                if name == global_buf.get_name():
-                    name = local_buf.get_name()
-                    index_vars = sorted(
-                        [
-                            s
-                            for s in index.free_symbols
-                            if symbol_is_type(s, SymT.INDEX)
-                        ],
-                        key=str,
-                    )
-                    index = local_buf.layout.make_indexer()(index_vars)
-                return name, index
-
-            def load(self, name: str, index: sympy.Expr):
-                return self._inner.load(*self.localize(name, index))
-
-            def store(self, name, index, value, mode=None):
-                return self._inner.store(*self.localize(name, index), value, mode)
-
-            def store_reduction(self, name, index, value):
-                return self._inner.store_reduction(*self.localize(name, index), value)
-
-        def wrap_inner_fn_for_node(node: ir.IRNode, inner_fn_wrapper):
-            loops = node.data if isinstance(node, ir.ComputedBuffer) else node
-            assert isinstance(loops, ir.Loops)
-            new_loops = copy.copy(loops)
-            if isinstance(node, ir.ComputedBuffer):
-                new_node = ir.ComputedBuffer(
-                    node.get_name(), node.get_layout(), new_loops
-                )
-            else:
-                new_node = new_loops  # type: ignore[assignment]
-
-            new_loops.inner_fn = inner_fn_wrapper(new_loops.inner_fn)
-            return new_node
-
-        def inner_fn_wrapper(inner_fn):
-            def inner(index):
-                with V.set_ops_handler(LocalizeBufferHandler(V.get_ops_handler())):
-                    return inner_fn(index)
-
-            return inner
-
-        return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes]
diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py
index 919da9f..fa14b44 100644
--- a/torch/_inductor/kernel/mm.py
+++ b/torch/_inductor/kernel/mm.py
@@ -327,7 +327,6 @@
             [inp_expanded, mat1, mat2],
             alpha=alpha,
             beta=beta,
-            has_bias=True,
         )
 
     add_aten_fallback = False
diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py
index 116941d..f9f2e66 100644
--- a/torch/_inductor/mkldnn_lowerings.py
+++ b/torch/_inductor/mkldnn_lowerings.py
@@ -21,149 +21,7 @@
     ExternKernelChoice,
 )
 from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune
-from .virtualized import ops, V
-
-
-def create_epilogue_with_attr(input_buffer, attr, **kwargs):
-    input_loader = input_buffer.make_loader()
-    dtype = input_buffer.get_dtype()
-    if attr == "relu":
-
-        def inner_fn(index):
-            input = input_loader(index)
-            zero = ops.constant(0, dtype)
-            return ops.maximum(input, zero)
-
-    elif attr == "gelu":
-        assert "algorithm" in kwargs
-        if kwargs["algorithm"] == "none":
-
-            def inner_fn(index):
-                input = input_loader(index)
-                if dtype != torch.float:
-                    input = ops.to_dtype(input, torch.float)
-                half = ops.constant(0.5, torch.float)
-                one = ops.constant(1.0, torch.float)
-                const = ops.constant(0.7071067811865476, torch.float)
-                result = input * half * (ops.erf(input * const) + one)
-                if dtype != torch.float:
-                    result = ops.to_dtype(result, dtype)
-                return result
-
-        else:
-            assert kwargs["algorithm"] == "tanh"
-
-            def inner_fn(index):
-                input = input_loader(index)
-                if dtype != torch.float:
-                    input = ops.to_dtype(input, torch.float)
-                half = ops.constant(0.5, torch.float)
-                one = ops.constant(1.0, torch.float)
-                const1 = ops.constant(0.7978845608028654, torch.float)
-                const2 = ops.constant(0.044715, torch.float)
-                result = (
-                    half
-                    * input
-                    * (
-                        one
-                        + ops.tanh(const1 * (input + const2 * input * input * input))
-                    )
-                )
-                if dtype != torch.float:
-                    result = ops.to_dtype(result, dtype)
-                return result
-
-    elif attr == "swish":
-
-        def inner_fn(index):
-            input = input_loader(index)
-            result = input * ops.sigmoid(input)
-            return result
-
-    elif attr == "sigmoid":
-
-        def inner_fn(index):
-            return ops.sigmoid(input_loader(index))
-
-    elif attr == "tanh":
-
-        def inner_fn(index):
-            return ops.tanh(input_loader(index))
-
-    elif attr == "hardswish" or attr == "hardsigmoid":
-
-        def hardsigmoid_float(input):
-            zero = ops.constant(0, torch.float)
-            six = ops.constant(6, torch.float)
-            three = ops.constant(3, torch.float)
-            one_over_six = ops.constant(0.16666666666666666, torch.float)
-            max = ops.maximum(input + three, zero)
-            min = ops.minimum(max, six)
-            return min * one_over_six
-
-        def inner_fn(index):
-            input = input_loader(index)
-            if dtype != torch.float:
-                input = ops.to_dtype(input, torch.float)
-            result = hardsigmoid_float(input)
-            if attr == "hardswish":
-                result = input * result
-            if dtype != torch.float:
-                result = ops.to_dtype(result, dtype)
-            return result
-
-    elif attr == "leaky_relu":
-        assert "scalars" in kwargs
-        assert len(kwargs["scalars"]) == 1
-        negative_slope = kwargs["scalars"][0]
-
-        def inner_fn(index):
-            input = input_loader(index)
-            if dtype != torch.float:
-                input = ops.to_dtype(input, torch.float)
-            zero = ops.constant(0, torch.float)
-            result = ops.where(
-                input > zero, input, input * ops.constant(negative_slope, torch.float)
-            )
-            if dtype != torch.float:
-                result = ops.to_dtype(result, dtype)
-            return result
-
-    elif attr == "hardtanh":
-        assert "scalars" in kwargs
-        assert len(kwargs["scalars"]) == 2
-        min_value = kwargs["scalars"][0]
-        max_value = kwargs["scalars"][1]
-
-        def inner_fn(index):
-            input = input_loader(index)
-            if dtype != torch.float:
-                input = ops.to_dtype(input, torch.float)
-            result = ops.minimum(
-                ops.maximum(input, ops.constant(min_value, torch.float)),
-                ops.constant(max_value, torch.float),
-            )
-            if dtype != torch.float:
-                result = ops.to_dtype(result, dtype)
-            return result
-
-    elif attr == "add" or attr == "sub":
-        assert "other" in kwargs
-        other = kwargs["other"]
-        other_loader = other.make_loader()
-
-        def inner_fn(index):
-            op = getattr(ops, attr)
-            return op(input_loader(index), other_loader(index))
-
-    else:
-        raise ValueError(f"Unsupported epilogue attribute: {attr}")
-    return ir.Pointwise(
-        device=input_buffer.get_device(),
-        dtype=dtype,
-        inner_fn=inner_fn,
-        ranges=input_buffer.get_size(),
-    )
+from .virtualized import V
 
 
 def register_onednn_fusion_ops():
@@ -174,12 +32,6 @@
             has_out_variant=False,
             kernel_creator=ir.LinearUnary.create,
         )
-        aten_mkldnn_linear_binary = ExternKernelChoice(
-            torch.ops.mkldnn._linear_pointwise.binary,
-            "mkldnn::_linear_pointwise",
-            has_out_variant=False,
-            kernel_creator=ir.LinearBinary.create,
-        )
         cpu_needs_realized_inputs = [
             torch.ops.mkldnn._convolution_pointwise,
             torch.ops.mkldnn._convolution_pointwise_,
@@ -299,44 +151,51 @@
             if len(x_size) > 2:
                 # GEMM template needs 2D input, normalize input shape here
                 x = view(x, [-1, x_size[-1]])
-            if b is not None:
-                b = ir.ExternKernel.realize_input(b)
-            inputs = [x, w] if b is None else [x, w, b]
             choices: List[ChoiceCaller] = []
             if len(choices) == 0 or use_aten_gemm_kernels():
-                kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm)
-                if b is None:
-                    kwargs["B"] = None
                 choices.append(
                     aten_mkldnn_linear_unary.bind(
-                        inputs,
+                        (x, w),
                         layout,
-                        **kwargs,
+                        B=None,
+                        attr=attr,
+                        scalars=scalars,
+                        algorithm=algorithm,
+                    )
+                    if b is None
+                    else aten_mkldnn_linear_unary.bind(
+                        (x, w, b),
+                        layout,
+                        attr=attr,
+                        scalars=scalars,
+                        algorithm=algorithm,
                     )
                 )
             if use_max_autotune():
                 transposed_w = permute(w, [1, 0])
                 *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
-                if use_cpp_packed_gemm_template(layout, x, transposed_w):
-
-                    def epilogue_creator(buf):
-                        return create_epilogue_with_attr(
-                            buf, attr, scalars=scalars, algorithm=algorithm
+                if b is not None:
+                    b = ir.ExternKernel.realize_input(b)
+                # TODO(jgong5): support epilogue fusion
+                if (
+                    use_cpp_packed_gemm_template(layout, x, transposed_w)
+                    and attr == "none"
+                ):
+                    if b is None:
+                        CppPackedGemmTemplate.add_choices(
+                            choices,
+                            layout,
+                            [x, w],
+                            trans_w=True,
                         )
-
-                    kwargs = dict(
-                        has_bias=b is not None,
-                        trans_w=True,
-                        epilogue_creator=None if attr == "none" else epilogue_creator,
-                    )
-                    if b is not None:
-                        kwargs["input_indices"] = [2, 0, 1]
-                    CppPackedGemmTemplate.add_choices(
-                        choices,
-                        layout,
-                        inputs,
-                        **kwargs,
-                    )
+                    else:
+                        CppPackedGemmTemplate.add_choices(
+                            choices,
+                            layout,
+                            [x, w, b],
+                            trans_w=True,
+                            input_indices=[2, 0, 1],
+                        )
             assert w.get_name() in V.graph.constants
             input_gen_fns = {
                 1: lambda x: V.graph.constants[x.get_name()],
@@ -344,7 +203,7 @@
             result = autotune_select_algorithm(
                 "linear_unary",
                 choices,
-                inputs,
+                [x, w] if b is None else [x, w, b],
                 layout,
                 input_gen_fns=input_gen_fns,
             )
@@ -353,67 +212,8 @@
             return result
 
         @register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
-        def linear_binary(
-            x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None
-        ):
-            x_size = x.get_size()
-            if len(x_size) > 2:
-                # GEMM template needs 2D input, normalize input shape here
-                x = view(x, [-1, x_size[-1]])
-            y_size = y.get_size()
-            if len(y_size) > 2:
-                y = view(y, [-1, y_size[-1]])
-            if b is not None:
-                b = ir.ExternKernel.realize_input(b)
-            inputs = [x, y, w] if b is None else [x, y, w, b]
-            choices: List[ChoiceCaller] = []
-            if len(choices) == 0 or use_aten_gemm_kernels():
-                kwargs = dict(attr=attr)
-                if b is None:
-                    kwargs["B"] = None
-                choices.append(
-                    aten_mkldnn_linear_binary.bind(
-                        inputs,
-                        layout,
-                        **kwargs,
-                    )
-                )
-            if use_max_autotune():
-                transposed_w = permute(w, [1, 0])
-                *_, layout, x, transposed_w, y = mm_args(
-                    x, transposed_w, y, layout=layout
-                )
-                if use_cpp_packed_gemm_template(layout, x, transposed_w):
-
-                    def epilogue_creator(buf):
-                        return create_epilogue_with_attr(buf, attr, other=y)
-
-                    kwargs = dict(
-                        has_bias=b is not None,
-                        trans_w=True,
-                        epilogue_creator=epilogue_creator,
-                    )
-                    kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
-                    CppPackedGemmTemplate.add_choices(
-                        choices,
-                        layout,
-                        inputs,
-                        **kwargs,
-                    )
-            assert w.get_name() in V.graph.constants
-            input_gen_fns = {
-                2: lambda x: V.graph.constants[x.get_name()],
-            }
-            result = autotune_select_algorithm(
-                "linear_binary",
-                choices,
-                inputs,
-                layout,
-                input_gen_fns=input_gen_fns,
-            )
-            if len(x_size) > 2:
-                result = view(result, (*x_size[:-1], result.get_size()[-1]))
-            return result
+        def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr):
+            return TensorBox.create(ir.LinearBinary.create(x, y, w, b, attr))
 
         @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
         def convolution_transpose_unary(