[Inductor CUTLASS backend] Step 4: CUDA (template) kernels (#107931)

This is step 4 of adding CUTLASS as an alternative Inductor backend.
Full tests can be found in the last PR of the stack.

Feature request: https://github.com/pytorch/pytorch/issues/106991.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107931
Approved by: https://github.com/aakhundov, https://github.com/jansel, https://github.com/kadeng
ghstack dependencies: #107802, #107847, #107901
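
At a high level, this step makes CUDA (CUTLASS) template kernels available as additional
autotuning choices next to the existing ATen and Triton template choices. A minimal sketch of
the intended flow, assuming a hypothetical `CUTLASSGemmTemplate` and `use_cutlass_template`
gate (the concrete templates and gating land in later PRs of this stack):

    # Hedged sketch -- CUTLASSGemmTemplate / use_cutlass_template are illustrative names.
    from torch._inductor.select_algorithm import autotune_select_algorithm

    def tuned_mm(mat1, mat2, layout):
        choices = []          # list of ChoiceCaller objects (ATen, Triton, CUDA templates)
        ...                   # existing ATen / Triton choices are appended as before
        if use_cutlass_template(layout):                      # assumed gate, not in this PR
            CUTLASSGemmTemplate((mat1, mat2), layout).maybe_append_choice(choices)
        return autotune_select_algorithm("mm", choices, [mat1, mat2], layout)
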
diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py
index d39e703..ce46477 100644
--- a/torch/_inductor/autotune_process.py
+++ b/torch/_inductor/autotune_process.py
@@ -17,6 +17,7 @@
 from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache
 
 if TYPE_CHECKING:
+    from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
     from torch._inductor.select_algorithm import TritonTemplateCaller
 
 from .utils import do_bench_using_profiling
@@ -379,7 +380,7 @@
 
 
 def benchmark_in_sub_process(
-    choice: "TritonTemplateCaller",
+    choice: "Union[TritonTemplateCaller, CUDATemplateCaller]",
 ) -> float:
     """
     Do benchmarking in subprocess and return the perf number (latency).
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 312f711..59105c5 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -19,8 +19,8 @@
 from .. import metrics
 from ..utils import (
     DeferredLineBase,
+    do_bench_using_profiling,
     free_symbol_startswith,
-    get_sympy_Expr_dtype,
     IndentedBuffer,
     sympy_dot,
     sympy_subs,
@@ -560,17 +560,6 @@
     def cpp_argdefs(self):
         from .cpp import DTYPE_TO_CPP, INDEX_TYPE
 
-        # TODO(jansel): replace this with data from scheduler
-        buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers}
-        for name, val in V.graph.graph_inputs.items():
-            if isinstance(val, sympy.Expr):
-                buffer_types[name] = get_sympy_Expr_dtype(val)
-            else:
-                buffer_types[name] = val.get_dtype()
-        buffer_types.update(
-            {name: val.dtype for name, val in V.graph.constants.items()}
-        )
-
         call_args = []
         arg_defs = []
         arg_types = []
@@ -579,7 +568,7 @@
                 continue
             outer = inplaced.other_names[-1]
             inner = inplaced.inner_name
-            dtype = buffer_types[outer]
+            dtype = V.graph.get_dtype(outer)
             cpp_dtype = DTYPE_TO_CPP[dtype]
             arg_defs.append(f"{cpp_dtype}* {inner}")
             call_args.append(self.wrap_ptr_arg(outer, dtype))
@@ -587,7 +576,7 @@
         for outer, inner in self.input_buffers.items():
             if outer in self.inplace_buffers:
                 continue
-            dtype = buffer_types[outer]
+            dtype = V.graph.get_dtype(outer)
             cpp_dtype = DTYPE_TO_CPP[dtype]
             arg_defs.append(f"const {cpp_dtype}* {inner}")
             call_args.append(self.wrap_ptr_arg(outer, dtype))
@@ -595,7 +584,7 @@
         for outer, inner in self.output_buffers.items():
             if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                 continue
-            dtype = buffer_types[outer]
+            dtype = V.graph.get_dtype(outer)
             cpp_dtype = DTYPE_TO_CPP[dtype]
             arg_defs.append(f"{cpp_dtype}* {inner}")
             call_args.append(self.wrap_ptr_arg(outer, dtype))
@@ -1052,3 +1041,96 @@
 
     # Load uint8 value as float32
     is_load_uint8_as_float: bool = False
+
+
+@functools.lru_cache(None)
+def jinja2_env():
+    try:
+        import jinja2
+
+        return jinja2.Environment(
+            undefined=jinja2.StrictUndefined,
+        )
+    except ImportError:
+        return None
+
+
+class ChoiceCaller:
+    """
+    Represents a possible choice used in autotune_process.py.
+    During autotuning, self.benchmark() is first called to obtain the benchmark result.
+    If this choice is selected, self.output_node() is then called to construct the output IR node.
+
+    Children classes: TritonTemplateCaller, CUDATemplateCaller.
+    """
+
+    def __init__(self, name, input_nodes, layout):
+        super().__init__()
+        self.name = name
+        self.layout = layout
+        self.input_nodes = input_nodes
+
+    def benchmark(self, *args, out) -> float:
+        algo = self.to_callable()
+        return do_bench_using_profiling(lambda: algo(*args, out=out))
+
+    def call_name(self) -> str:
+        raise NotImplementedError()
+
+    def to_callable(self):
+        raise NotImplementedError()
+
+    def hash_key(self) -> str:
+        raise NotImplementedError()
+
+    def output_node(self) -> "TensorBox":  # type: ignore[name-defined]
+        raise NotImplementedError()
+
+
+class KernelTemplate:
+    """
+    Base class for defining kernel templates.
+
+    Children classes: TritonTemplate, CUDATemplate
+    """
+
+    @staticmethod
+    def _template_from_string(source):
+        env = jinja2_env()
+        if env is not None:
+            return env.from_string(source)
+        return None
+
+    @staticmethod
+    def _fake_get_dtype(fake_out):
+        _get_dtype_real = V.graph.get_dtype
+
+        def get_dtype(name):
+            if name == fake_out.get_name():
+                return fake_out.get_dtype()
+            return _get_dtype_real(name)
+
+        return get_dtype
+
+    def __init__(self, name: str):
+        self.name = name
+
+    def maybe_append_choice(self, choices, **kwargs):
+        """
+        Maybe generates a new ChoiceCaller and appends it to the list of existing choices.
+
+        choices: A list of ChoiceCallers.
+        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
+        """
+
+        try:
+            choices.append(self.generate(**kwargs))
+        except NotImplementedError:
+            pass
+
+    def generate(self, **kwargs) -> ChoiceCaller:
+        """
+        Generates a ChoiceCaller instance from the given arguments.
+        """
+
+        raise NotImplementedError()
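
The two base classes above define the contract shared by Triton and CUDA templates: a
KernelTemplate produces ChoiceCallers via generate()/maybe_append_choice(), and the tuner later
calls benchmark() on each choice and output_node() on the winner. A hedged, illustrative sketch
of a custom pair (not part of this PR; `my_matmul` and the callable are made up):

    class MyCaller(ChoiceCaller):
        def __init__(self, name, input_nodes, layout, fn):
            super().__init__(name, input_nodes, layout)
            self.fn = fn

        def to_callable(self):
            # benchmark() above wraps this in do_bench_using_profiling(...)
            return self.fn

        def hash_key(self) -> str:
            return f"my_matmul-{self.name}"

        def output_node(self):
            # A real caller wraps the chosen kernel into an IR buffer / TensorBox here.
            raise NotImplementedError()

    class MyTemplate(KernelTemplate):
        def generate(self, *, input_nodes, layout, fn, **kwargs) -> ChoiceCaller:
            return MyCaller(self.name, input_nodes, layout, fn)

    # Usage: MyTemplate("my_matmul").maybe_append_choice(
    #            choices, input_nodes=(a, b), layout=layout, fn=some_callable)
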
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 8b1da1d..984c975 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -382,6 +382,7 @@
             return f"std::max({il})"
 
 
+# Expression printer, used to print sympy expressions (e.g. symbolic sizes/strides) as C++ code.
 cexpr = CppPrinter().doprint
 
 
diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py
new file mode 100644
index 0000000..e20a1dd
--- /dev/null
+++ b/torch/_inductor/codegen/cuda/cuda_kernel.py
@@ -0,0 +1,309 @@
+from typing import Dict, List, Optional
+
+from ...autotune_process import CUDABenchmarkRequest
+from ...ir import Callable, CUDATemplateBuffer, IRNode, Layout, TensorBox
+from ...select_algorithm import ChoiceCaller
+from ...utils import sympy_product
+from ...virtualized import V
+
+from ..common import IndentedBuffer, Kernel, OpOverrides
+from ..cpp import CppPrinter, DTYPE_TO_CPP
+
+
+cexpr = CppPrinter().doprint
+
+
+def _normalize_idx(index: int, total_length: int) -> int:
+    return index if index >= 0 else index + total_length
+
+
+class CUDAKernel(Kernel):
+    """
+    Kernels defined by C++ CUDA.
+    """
+
+    overrides = OpOverrides  # type: ignore[assignment]
+
+
+class CUDATemplateKernel(CUDAKernel):
+    """
+    Template kernels defined by C++ CUDA.
+    """
+
+    _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
+
+    def __init__(
+        self,
+        kernel_name,
+    ):
+        super().__init__()
+        self.kernel_name = kernel_name
+        # Mapping from arg name to IRNode.
+        self.named_nodes: Dict[str, IRNode] = {}
+
+    def arg_name(self, node: IRNode) -> Optional[str]:
+        """
+        Returns arg name of a given input or output node.
+        """
+
+        if node is None:
+            return None
+        return {**self.args.input_buffers, **self.args.output_buffers}.get(
+            node.get_name(), None
+        )
+
+    def check_not_null(self, node: IRNode) -> str:
+        """
+        Generates code to check that a node is not null.
+        """
+
+        if node is None:
+            return ""
+
+        size_str = self.size(node, 0, -1)
+        name_str = self.arg_name(node)
+        if name_str is None:
+            return ""
+
+        res = IndentedBuffer(initial_indent=2)
+        res.tabwidth = 1
+        res.splice(
+            f"""
+            {{
+              if (!{name_str}) {{
+                int64_t {name_str}_size = {size_str};
+                if ({name_str}_size > 0) {{
+                  throw std::runtime_error("input {name_str} is null but size is not 0!");
+                }}
+              }}
+            }}
+            """
+        )
+        return res.getvalue()
+
+    def def_kernel(
+        self,
+        inputs: List[IRNode],
+        outputs: List[IRNode],
+        names_str: str = "",
+        input_reorder: Optional[List[int]] = None,
+    ) -> str:
+        """
+        Hook called from template code to generate function def and
+        needed args.
+
+        inputs / outputs: List of input / output IRNodes. Note that IRNode can be None for optional arguments.
+        names_str: Comma separated list of input + output argument names.
+        input_reorder: The actual order of input nodes.
+                       e.g. The template might have input argument defined as [X, W, Bias],
+                       and the actual input passed into this template could be [Bias, X, W].
+                       In this case, the `input_reorder` would be [2, 0, 1].
+        """
+
+        names = [x.strip() for x in names_str.strip().split(",")]
+        if len(inputs) + len(outputs) != len(names):
+            raise RuntimeError(
+                f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
+            )
+
+        if input_reorder is not None:
+            assert len(inputs) == len(input_reorder)
+        else:
+            input_reorder = list(range(len(inputs)))
+
+        for idx in input_reorder:
+            name = names[idx]
+            node = inputs[idx]
+            if node is not None:
+                self.named_nodes[name] = node
+                self.args.input_buffers[node.get_name()] = name
+
+        for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
+            if node is not None:
+                self.named_nodes[name] = node
+                self.args.output_buffers[node.get_name()] = name
+
+        arg_defs, *_ = self.args.cpp_argdefs()
+        return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})"
+
+    def call_kernel(self, name: str, node: CUDATemplateBuffer) -> None:
+        """
+        Generates code to call the kernel through V.graph.wrapper_code.
+
+        name: Name of kernel function.
+        node: The IRNode which represents the kernel.
+        """
+
+        wrapper = V.graph.wrapper_code
+        _, call_args, _ = self.args.python_argdefs()
+        # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
+        for i in range(len(call_args)):
+            if V.graph.is_unspec_arg(call_args[i]):
+                call_args[i] = call_args[i] + ".item()"
+            else:
+                call_args[i] = f"c_void_p({call_args[i]}.data_ptr())"
+
+        # The workspace_size pointer is passed as NULL to mark that this call is not intended
+        # to retrieve the workspace size; it should have already been retrieved prior to this call.
+        call_args.append("None")
+
+        if node.get_workspace_size() > 0:
+            call_args.append(f"c_void_p({node.get_name()}_workspace.data_ptr())")
+        else:
+            call_args.append("None")
+
+        wrapper.generate_kernel_call(
+            name,
+            call_args,
+            device_index=V.graph.scheduler.current_device.index,
+            cuda=True,
+            triton=False,
+        )
+
+    def dtype(self, node: IRNode) -> str:
+        """
+        Generates code which represents dtype of a given node.
+        """
+
+        if node is None:
+            return "void"
+        return DTYPE_TO_CPP.get(node.get_layout().dtype)
+
+    def offset(self, node: IRNode) -> str:
+        """
+        Generates code which represents offset of a given node.
+        """
+
+        if node is None:
+            return "0"
+        return str(node.get_layout().offset)
+
+    def ptr(self, node: IRNode) -> str:
+        """
+        Generates code which represents pointer of a given node.
+        """
+
+        if node is None:
+            return "nullptr"
+        arg_name = self.arg_name(node)
+        if arg_name is None:
+            return "nullptr"
+        offset = self.offset(node)
+        return arg_name if offset == "0" else f"{arg_name} + {offset}"
+
+    def size(
+        self,
+        node: IRNode,
+        start_index: int,
+        end_index: Optional[int] = None,
+        default_value: int = 0,
+    ) -> str:
+        """
+        Hook called from template code to get the size of an arg.
+        Generates code which represents the product of the sizes of a given node over the inclusive range [start_index, end_index].
+        If node is None, returns default_value.
+
+        TODO: Will add needed args to pass it in if it is dynamic.
+        """
+
+        if node is None:
+            return str(default_value)
+
+        start_index = _normalize_idx(start_index, len(node.get_size()))
+        if end_index is None:
+            end_index = start_index
+        end_index = _normalize_idx(end_index, len(node.get_size()))
+
+        sizes = node.get_size()[start_index : end_index + 1]
+        if len(sizes) == 0:
+            return str(default_value)
+
+        val = sympy_product(sizes)
+        return cexpr(self.rename_indexing(val))
+
+    def stride(self, node: IRNode, index: int, default_value: int = 0) -> str:
+        """
+        Hook called from template code to get the stride of an arg.
+        Generates code which represents stride of a given node at index.
+        If node is None, returns default_value.
+
+        TODO: Will add needed args to pass it in if it is dynamic.
+        """
+
+        if node is None:
+            return str(default_value)
+
+        index = _normalize_idx(index, len(node.get_size()))
+        if index < 0:
+            return str(default_value)
+
+        stride = node.get_stride()[index]
+        return cexpr(self.rename_indexing(stride))
+
+    def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
+        """
+        Hook called from template code to get the row or column stride of an arg.
+        This is required by some CUTLASS 2.X APIs.
+        If the node is in row_major, it returns stride[-2].
+        If the node is in column_major, it returns stride[-1].
+
+        TODO: Will add needed args to pass it in if it is dynamic.
+        """
+
+        if node is None or len(node.get_stride()) < 2:
+            return str(default_value)
+
+        stride0 = node.get_stride()[-1]
+        stride1 = node.get_stride()[-2]
+        if stride0 == 1:
+            return cexpr(self.rename_indexing(stride1))
+        elif stride1 == 1:
+            return cexpr(self.rename_indexing(stride0))
+        else:
+            raise RuntimeError(
+                f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
+            )
+
+
+class CUDATemplateCaller(ChoiceCaller):
+    def __init__(
+        self,
+        name: str,
+        category: str,
+        input_nodes: List[IRNode],
+        layout: Layout,
+        make_kernel_render: Callable[[str], str],
+        bmreq: CUDABenchmarkRequest,
+    ):
+        super().__init__(name, input_nodes, layout)
+        self.category = category
+        self.make_kernel_render = make_kernel_render
+        self.bmreq = bmreq
+
+    def benchmark(self, *args, out) -> float:
+        assert self.bmreq is not None
+        return self.bmreq.benchmark(*args, output_tensor=out)
+
+    def __str__(self):
+        return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"
+
+    def call_name(self) -> str:
+        return f"cuda_template_kernels.{self.name}"
+
+    def hash_key(self) -> str:
+        return "-".join(
+            [
+                self.category,
+                self.bmreq.hash_key,
+            ]
+        )
+
+    def output_node(self) -> TensorBox:
+        return TensorBox.create(
+            CUDATemplateBuffer(
+                layout=self.layout,
+                inputs=self.input_nodes,
+                make_kernel_render=self.make_kernel_render,
+                workspace_size=self.bmreq.workspace_size,
+            )
+        )
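
The methods above (def_kernel, check_not_null, size, stride, row_or_column_stride, ...) are hooks
meant to be invoked from inside a Jinja template. A hedged sketch of what such a template fragment
could look like (illustrative only; the real CUTLASS GEMM template is added later in the stack):

    GEMM_TEMPLATE = r"""
    {{kernel.def_kernel(inputs=[X, W], outputs=[Y], names_str="X, W, Y")}} {
      {{kernel.check_not_null(X)}}
      {{kernel.check_not_null(W)}}
      int64_t M = {{kernel.size(X, 0)}};
      int64_t K = {{kernel.size(X, 1)}};
      int64_t N = {{kernel.size(W, 1)}};
      int64_t lda = {{kernel.row_or_column_stride(X)}};
      // ... instantiate and launch a CUTLASS GEMM on `stream` here ...
      return 0;
    }
    """
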
diff --git a/torch/_inductor/codegen/cuda/cuda_scheduling.py b/torch/_inductor/codegen/cuda/cuda_scheduling.py
new file mode 100644
index 0000000..a017f90
--- /dev/null
+++ b/torch/_inductor/codegen/cuda/cuda_scheduling.py
@@ -0,0 +1,43 @@
+from ... import config
+from ...codecache import code_hash, get_path
+from ...utils import get_fused_kernel_name, get_kernel_metadata
+from ...virtualized import V
+
+from ..common import IndentedBuffer
+from ..triton import TritonScheduling
+
+
+class CUDAScheduling(TritonScheduling):
+    """
+    Final codegen for CUDAKernels.
+    """
+
+    def define_kernel(self, src_code: str, node_schedule) -> str:
+        wrapper = V.graph.wrapper_code
+        if src_code in wrapper.src_to_kernel:
+            kernel_name = wrapper.src_to_kernel[src_code]
+        else:
+            fused_name = (
+                get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
+                if config.triton.descriptive_names
+                else ""
+            )
+            kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()])
+            # use the original src_code as the key
+            wrapper.src_to_kernel[src_code] = kernel_name
+            src_code = src_code.replace("KERNEL_NAME", kernel_name)
+
+            _, _, kernel_path = get_path(code_hash(src_code), "py")
+
+            compile_wrapper = IndentedBuffer()
+            compile_wrapper.writeline("async_compile.cuda(r'''")
+            compile_wrapper.splice(src_code, strip=True)
+            compile_wrapper.writeline("''', 'so')")
+
+            metadata_comment = f"# kernel path: {kernel_path}"
+            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
+            metadata_comment += "\n" + origins + "\n" + detailed_origins
+            wrapper.define_kernel(
+                kernel_name, compile_wrapper.getvalue(), metadata_comment
+            )
+        return kernel_name
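
With this path, the generated wrapper code defines a CUDA template kernel roughly as follows
(illustrative; the kernel name, path and source depend on the compiled graph):

    # kernel path: /tmp/torchinductor_user/.../cudakernel.py   (illustrative)
    cuda_fused_0 = async_compile.cuda(r'''
        ... rendered CUDA/CUTLASS source, with KERNEL_NAME replaced by cuda_fused_0 ...
    ''', 'so')
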
diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py
new file mode 100644
index 0000000..765966f
--- /dev/null
+++ b/torch/_inductor/codegen/cuda/cuda_template.py
@@ -0,0 +1,197 @@
+import functools
+import itertools
+import logging
+
+from typing import List, Optional
+from unittest.mock import patch
+
+import sympy
+
+import torch
+
+from ...autotune_process import CUDABenchmarkRequest, TensorMeta
+from ...ir import Buffer, IRNode, Layout
+from ...utils import IndentedBuffer, unique
+from ...virtualized import V
+from ..common import KernelTemplate
+
+from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
+
+log = logging.getLogger(__name__)
+
+
+class CUDATemplate(KernelTemplate):
+    index_counter = itertools.count()
+
+    def __init__(
+        self,
+        name: str,
+        input_nodes: List[IRNode],
+        layout: Layout,
+        input_reorder: Optional[List[int]] = None,
+    ):
+        super().__init__(name)
+        self.input_nodes = input_nodes
+        self.output_node = Buffer("buf_out", layout)
+        self.input_reorder = input_reorder
+
+    def generate(self, **kwargs) -> CUDATemplateCaller:
+        kernel_name = f"cuda_{self.name}"
+        with patch.object(
+            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
+        ), CUDATemplateKernel(
+            kernel_name=kernel_name,
+        ) as kernel:
+            code = self.render(kernel=kernel, **kwargs)
+            _, call_args, _ = kernel.args.python_argdefs()
+            log.debug("Generated Code:\n%s", code)
+            log.debug(
+                "Args: cpp_argdefs: %s, python_argdefs: %s",
+                kernel.args.cpp_argdefs(),
+                kernel.args.python_argdefs(),
+            )
+
+        input_reorder = (
+            self.input_reorder
+            if self.input_reorder is not None
+            else list(range(len(self.input_nodes)))
+        )
+        expected_args = list(
+            unique(self.input_nodes[idx].get_name() for idx in input_reorder)
+        )
+        expected_args.extend([self.output_node.get_name()])
+        assert list(call_args)[: len(expected_args)] == expected_args, (
+            call_args,
+            expected_args,
+        )
+        extra_args = V.graph.sizevars.size_hints(
+            map(sympy.expand, call_args[len(expected_args) :])
+        )
+
+        kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}"
+
+        # create the BenchmarkRequest
+        bmreq = CUDABenchmarkRequest(
+            kernel_name=kernel_name,
+            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
+            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
+            extra_args=extra_args,
+            source_code=code,
+        )
+
+        def make_kernel_render(output_node):
+            kernel = CUDATemplateKernel(
+                kernel_name="KERNEL_NAME",
+            )
+            render = functools.partial(
+                self.render,
+                kernel=kernel,
+                output_node=output_node,
+                **kwargs,
+            )
+            return kernel, render
+
+        return CUDATemplateCaller(
+            kernel_hash_name,
+            self.name,
+            self.input_nodes,
+            self.output_node.get_layout(),
+            make_kernel_render,
+            bmreq,
+        )
+
+    def header(self) -> IndentedBuffer:
+        res = IndentedBuffer()
+        res.splice(
+            """
+                #include <exception>
+                #include <iostream>
+                #include <memory>
+                #include <random>
+                #include <vector>
+            """
+        )
+        return res
+
+    def globals(self) -> IndentedBuffer:
+        res = IndentedBuffer()
+        res.splice(
+            """
+                // We compile all models with -fvisibility=hidden. Any symbols that need to be
+                // exposed in the final shared library must be declared with PT_EXPORT to make
+                // them visible.
+                #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
+                #define PT_EXPORT __attribute__((__visibility__("default")))
+                #else
+                #ifdef _WIN32
+                #define PT_EXPORT __declspec(dllexport)
+                #else
+                #define PT_EXPORT
+                #endif
+                #endif
+                using bfloat16 = nv_bfloat16;
+            """
+        )
+        return res
+
+    def render(self, **kwargs) -> str:
+        raise NotImplementedError
+
+
+class CUTLASSTemplate(CUDATemplate):
+    def header(self) -> IndentedBuffer:
+        res = super().header()
+        res.splice(
+            """
+                #include "cutlass/cutlass.h"
+                #include "cutlass/numeric_types.h"
+                #include "cutlass/util/host_tensor.h"
+                #include "cutlass/util/reference/host/tensor_fill.h"
+                #include "cutlass/util/reference/device/tensor_fill.h"
+                #include "cutlass/util/device_memory.h"
+            """
+        )
+        return res
+
+    def globals(self) -> IndentedBuffer:
+        res = super().globals()
+        res.splice(
+            """
+                #define CUTLASS_CHECK(status)                                                      \\
+                {                                                                                  \\
+                  cutlass::Status error = status;                                                  \\
+                  if (error != cutlass::Status::kSuccess) {                                        \\
+                    auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +             \\
+                        cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__);        \\
+                    throw std::runtime_error(msg);                                                 \\
+                  }                                                                                \\
+                }
+            """
+        )
+        return res
+
+    def cute_int(self, int_str: str, var_name: str) -> str:
+        res = ""
+        if int_str in {"1", "1L"}:
+            res = "cute::Int<1>{}"
+        else:
+            res = int_str
+
+        return f"{res} /* {var_name} */"
+
+    _DTYPE_TO_CUTLASS = {
+        torch.float32: "float",
+        torch.float64: "double",
+        torch.float16: "cutlass::half_t",
+        torch.int32: "int",
+        torch.int8: "int8_t",
+        torch.uint8: "uint8_t",
+        torch.bool: "bool",
+        torch.bfloat16: "cutlass::bfloat16_t",
+    }
+
+    def cutlass_type_cast(self, node: IRNode, ptr: str) -> str:
+        if node is None:
+            return ptr
+        else:
+            return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})"
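
A hedged sketch of how a concrete template might subclass CUTLASSTemplate (illustrative; the real
CUTLASS GEMM template, including op enumeration via the CUTLASS library, arrives later in the
stack):

    class MyCutlassGemmTemplate(CUTLASSTemplate):
        def __init__(self, input_nodes, layout):
            super().__init__("my_cutlass_gemm", input_nodes, layout)

        def render(self, kernel, **kwargs) -> str:
            # GEMM_TEMPLATE: the Jinja source sketched earlier, using the
            # kernel.def_kernel()/kernel.size() hooks; a real template would also
            # splice self.header()/self.globals() into the generated source.
            template = self._template_from_string(GEMM_TEMPLATE)
            return template.render(
                kernel=kernel,
                X=self.input_nodes[0],
                W=self.input_nodes[1],
                Y=self.output_node,
                **kwargs,
            )

    # Usage: MyCutlassGemmTemplate((mat1, mat2), layout).maybe_append_choice(choices)
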
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index d092609..24cd7fc 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -20,7 +20,7 @@
 from .. import config, ir, scheduler
 from ..codecache import code_hash, get_path
 from ..dependencies import MemoryDep, StarDep
-from ..ir import ReductionHint
+from ..ir import IRNode, ReductionHint, TritonTemplateBuffer
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..scheduler import BaseScheduling
 from ..triton_heuristics import AutotuneHint
@@ -2061,7 +2061,7 @@
 
         return f"[{', '.join(sizes)}]"
 
-    def call_kernel(self, name: str):
+    def call_kernel(self, name: str, node: IRNode = None):
         wrapper = V.graph.wrapper_code
         _, call_args, _ = self.args.python_argdefs()
         # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
@@ -2086,6 +2086,8 @@
             call_args,
             grid,
             V.graph.scheduler.current_device.index,
+            cuda=True,
+            triton=True,
         )
 
     def warn_mix_layout(self, kernel_name):
@@ -2186,7 +2188,9 @@
                 return False
 
             if node1.is_template():
-                return True  # skip checks for compatible tiling
+                # Only allow fusion for TritonTemplates for now.
+                # Fusion for CUDATemplates is not supported yet.
+                return isinstance(node1.node, TritonTemplateBuffer)
 
             # check for a bad combined tiling
             tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
@@ -2603,11 +2607,16 @@
 
         # finalize must be called after adding epilogue above
         with V.set_kernel_handler(kernel):
-            src_code = partial_code.finalize()
+            # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
+            src_code = (
+                partial_code
+                if isinstance(partial_code, str)
+                else partial_code.finalize()
+            )
             node_schedule = [template_node, *epilogue_nodes]
             kernel_name = self.define_kernel(src_code, node_schedule)
         self.codegen_comment(node_schedule)
-        kernel.call_kernel(kernel_name)
+        kernel.call_kernel(kernel_name, template_node.node)
         V.graph.removed_buffers |= kernel.removed_buffers
         self.scheduler.free_buffers()
 
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 0cd5484..74eb33b 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -733,17 +733,36 @@
         stack.enter_context(self.wrapper_call.indent())
 
     def generate_kernel_call(
-        self, name, call_args, grid=None, device_index=None, cuda=True
+        self,
+        name,
+        call_args,
+        grid=None,
+        device_index=None,
+        cuda=True,
+        triton=True,
     ):
+        """
+        Generates kernel call code.
+
+        cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.
+
+        triton: Defines whether the GPU backend uses Triton for codegen.
+                Otherwise it uses the CUDA language for codegen.
+                Only valid when cuda == True.
+        """
         if cuda:
             call_args_str = ", ".join(pexpr(item) for item in call_args)
-            grid_str = ", ".join(pexpr(item) for item in grid)
             stream_name = self.write_get_cuda_stream(
                 V.graph.scheduler.current_device.index
             )
-            self.writeline(
-                f"{name}.run({call_args_str}, grid=grid({grid_str}), stream={stream_name})"
-            )
+            if triton:
+                grid_str = ", ".join(pexpr(item) for item in grid)
+                self.writeline(
+                    f"{name}.run({call_args_str}, grid=grid({grid_str}), stream={stream_name})"
+                )
+            else:
+                stream_ptr = f"c_void_p({stream_name})"
+                self.writeline(f"{name}.{name}({call_args_str}, {stream_ptr})")
         else:
             self.writeline(self.wrap_kernel_call(name, call_args))
 
@@ -823,6 +842,10 @@
         )
 
     def codegen_allocation(self, buffer):
+        assert (
+            buffer.get_workspace_size() == 0
+        ), "Only support zero workspace size for now!"
+
         name = buffer.get_name()
 
         if name in V.graph.removed_buffers or name in self.allocated:
@@ -854,6 +877,10 @@
         )
 
     def codegen_free(self, buffer):
+        assert (
+            buffer.get_workspace_size() == 0
+        ), "Only support zero workspace size for now!"
+
         name = buffer.get_name()
 
         # can be freed but not reused
@@ -1682,11 +1709,11 @@
         return ", ".join(new_args)
 
     def generate_kernel_call(
-        self, name, call_args, grid=None, device_index=None, cuda=True
+        self, name, call_args, grid=None, device_index=None, cuda=True, triton=True
     ):
         if not cuda:
             return super().generate_kernel_call(
-                name, call_args, grid, device_index, cuda
+                name, call_args, grid, device_index, cuda, triton
             )
 
         params = CudaKernelParamCache.get(name)
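
Concretely, the two GPU paths now emit different call sites in the generated wrapper
(illustrative output; names depend on the graph):

    # triton=True (unchanged):
    #   triton_fused_add_0.run(arg0_1, buf0, 1024, grid=grid(1024), stream=stream0)
    # triton=False (CUDA template kernel compiled into a .so and loaded via DLLWrapper):
    #   cuda_fused_0.cuda_fused_0(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()),
    #                             None, None, c_void_p(stream0))
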
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index e6e5c8a..bd608dc 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -307,6 +307,12 @@
     def get_read_names(self):
         return {dep.name for dep in self.get_reads()}
 
+    def get_layout(self):
+        raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!")
+
+    def get_size(self):
+        raise NotImplementedError(f"get_size() is not implemented by {type(self)}!")
+
     def get_numel(self):
         return sympy_product(self.get_size())
 
@@ -1594,6 +1600,9 @@
     def get_dtype(self):
         return self.data.get_dtype()
 
+    def get_layout(self):
+        return self.data.get_layout()
+
     def get_device(self):
         return self.data.get_device()
 
@@ -2574,6 +2583,13 @@
     def realize(self):
         pass
 
+    def get_workspace_size(self):
+        """
+        Gets extra global memory size needed by this buffer.
+        Some algorithms (e.g. group gemm) may require extra global memory in the generated code.
+        """
+        return 0
+
 
 class InputBuffer(Buffer):
     pass
@@ -2912,6 +2928,26 @@
         )
 
 
+class TritonTemplateBuffer(TemplateBuffer):
+    pass
+
+
+class CUDATemplateBuffer(TemplateBuffer):
+    def __init__(
+        self,
+        layout,
+        inputs,
+        make_kernel_render,
+        workspace_size: int = 0,
+    ):
+        super().__init__(layout, inputs, make_kernel_render)
+        # Global memory (in bytes) needed for this template.
+        self.workspace_size = workspace_size
+
+    def get_workspace_size(self):
+        return self.workspace_size if self.workspace_size is not None else 0
+
+
 @dataclasses.dataclass
 class InputsKernel(Buffer):
     inputs: List[Buffer]
@@ -5302,6 +5338,12 @@
     def layout(self):
         return self.data.layout
 
+    def get_layout(self):
+        return self.layout
+
+    def get_size(self):
+        return self.data.get_size()
+
     def __str__(self):
         if isinstance(self.data, MutableBox):
             line0 = f"{type(self).__name__}({type(self.data).__name__}("
diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py
index 9b414b2..f578f36 100644
--- a/torch/_inductor/kernel/bmm.py
+++ b/torch/_inductor/kernel/bmm.py
@@ -95,8 +95,8 @@
         for config in mm_configs(m, n, k):
             bmm_template.maybe_append_choice(
                 choices,
-                (mat1, mat2),
-                layout,
+                input_nodes=(mat1, mat2),
+                layout=layout,
                 **mm_options(config, k, layout),
             )
 
@@ -118,8 +118,8 @@
         for config in mm_configs(m, n, k):
             bmm_template.maybe_append_choice(
                 choices,
-                (inp, mat1, mat2),
-                layout,
+                input_nodes=(inp, mat1, mat2),
+                layout=layout,
                 **mm_options(config, k, layout),
                 prefix_args=1,
                 epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py
index e0d4124..6b3c387 100644
--- a/torch/_inductor/kernel/conv.py
+++ b/torch/_inductor/kernel/conv.py
@@ -434,8 +434,8 @@
         ):
             conv2d_template.maybe_append_choice(
                 choices,
-                (x, weight),
-                layout,
+                input_nodes=(x, weight),
+                layout=layout,
                 KERNEL_H=kernel_shape[0],
                 KERNEL_W=kernel_shape[1],
                 STRIDE_H=stride[0],
diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py
index 576c195..91f763b 100644
--- a/torch/_inductor/kernel/mm.py
+++ b/torch/_inductor/kernel/mm.py
@@ -9,7 +9,7 @@
     ExternKernelChoice,
     TritonTemplate,
 )
-from ..utils import use_aten_gemm_kernels, use_triton_template
+from ..utils import use_aten_gemm_kernels, use_max_autotune, use_triton_template
 from .mm_common import (
     addmm_epilogue,
     int8_mm_configs,
@@ -116,12 +116,13 @@
 
     # options to tune from
     choices = [aten_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
+
     if m * n != 0 and use_triton_template(layout):
         for config in mm_configs(m, n, k):
             mm_template.maybe_append_choice(
                 choices,
-                (mat1, mat2),
-                layout,
+                input_nodes=(mat1, mat2),
+                layout=layout,
                 **mm_options(config, k, layout),
             )
 
@@ -142,8 +143,8 @@
         for config in int8_mm_configs(m, n, k):
             mm_template.maybe_append_choice(
                 choices,
-                (mat1, mat2),
-                layout,
+                input_nodes=(mat1, mat2),
+                layout=layout,
                 **mm_options(config, k, layout),
             )
     return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
@@ -154,7 +155,7 @@
     ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
 
     m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
-    if m * n == 0 or not use_triton_template(layout):
+    if m * n == 0 or not use_max_autotune():
         choices = (
             [
                 aten_addmm.bind(
@@ -183,8 +184,10 @@
         if use_aten_gemm_kernels()
         else []
     )
+
     if (
-        inp_expanded.get_stride()[0] == 0
+        use_aten_gemm_kernels()
+        and inp_expanded.get_stride()[0] == 0
         and inp_expanded.get_device().type == "cuda"
         and inductor_config.triton.autotune_cublasLt
     ):
@@ -196,15 +199,16 @@
             ),
         )
 
-    for config in mm_configs(m, n, k):
-        mm_template.maybe_append_choice(
-            choices,
-            (inp_expanded, mat1, mat2),
-            layout,
-            **mm_options(config, k, layout),
-            prefix_args=1,
-            epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
-        )
+    if use_triton_template(layout):
+        for config in mm_configs(m, n, k):
+            mm_template.maybe_append_choice(
+                choices,
+                input_nodes=(inp_expanded, mat1, mat2),
+                layout=layout,
+                **mm_options(config, k, layout),
+                prefix_args=1,
+                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
+            )
 
     return autotune_select_algorithm(
         "addmm", choices, [inp_expanded, mat1, mat2], layout
@@ -231,8 +235,8 @@
     for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
         mm_template.maybe_append_choice(
             choices,
-            (mat1, mat2),
-            layout,
+            input_nodes=(mat1, mat2),
+            layout=layout,
             **mm_options(config, k, layout, b_prologue_cast_type),
         )
     return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout)
diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py
index 40564b3..0d1a780 100644
--- a/torch/_inductor/kernel/mm_plus_mm.py
+++ b/torch/_inductor/kernel/mm_plus_mm.py
@@ -229,8 +229,8 @@
             if config.kwargs["BLOCK_K"] < k1:
                 mm_plus_mm_template.maybe_append_choice(
                     choices,
-                    (mat1, mat2, mat3, mat4),
-                    layout1,
+                    input_nodes=(mat1, mat2, mat3, mat4),
+                    layout=layout1,
                     **mm_options(config, k1, layout1),
                 )
 
diff --git a/torch/_inductor/kernel/unpack_mixed_mm.py b/torch/_inductor/kernel/unpack_mixed_mm.py
index a8da948..be57170 100644
--- a/torch/_inductor/kernel/unpack_mixed_mm.py
+++ b/torch/_inductor/kernel/unpack_mixed_mm.py
@@ -75,8 +75,8 @@
     for config in mm_configs(m, n, k):
         uint4x2_mixed_mm_template.maybe_append_choice(
             choices,
-            (mat1, mat2),
-            layout,
+            input_nodes=(mat1, mat2),
+            layout=layout,
             **mm_options(config, k, layout, b_prologue_cast_type),
         )
     return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index c442333..07fc83b 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -1772,7 +1772,12 @@
 
             if node.is_template():
                 node, *epilogue = node.get_nodes()
-                self.get_backend(device).codegen_template(node, epilogue)
+                if isinstance(node.node, ir.CUDATemplateBuffer):
+                    from .codegen.cuda.cuda_scheduling import CUDAScheduling
+
+                    CUDAScheduling(self).codegen_template(node, epilogue)
+                else:
+                    self.get_backend(device).codegen_template(node, epilogue)
             elif node.is_extern():
                 self.codegen_extern_call(node)
             elif node.is_foreach():
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index c527377..4d8e7ff 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -20,13 +20,18 @@
 from . import config, ir
 from .autotune_process import TensorMeta, TritonBenchmarkRequest
 from .codecache import code_hash, PersistentCache, PyCodeCache
-
-from .codegen.common import IndentedBuffer
+from .codegen.common import ChoiceCaller, IndentedBuffer, KernelTemplate
+from .codegen.cuda.cuda_kernel import CUDATemplateCaller
 from .codegen.triton import texpr, TritonKernel, TritonPrinter, TritonScheduling
-
 from .codegen.triton_utils import config_of, signature_to_meta
-
-from .utils import do_bench, Placeholder, sympy_dot, sympy_product, unique
+from .exc import CUDACompileError
+from .utils import (
+    do_bench_using_profiling,
+    Placeholder,
+    sympy_dot,
+    sympy_product,
+    unique,
+)
 from .virtualized import V
 
 log = logging.getLogger(__name__)
@@ -338,7 +343,7 @@
         self.body.clear()
         self.indexing_code.clear()
 
-    def call_kernel(self, name: str):
+    def call_kernel(self, name: str, node: ir.TritonTemplateBuffer):
         wrapper = V.graph.wrapper_code
         _, call_args, _ = self.args.python_argdefs()
         call_args = [str(a) for a in call_args]
@@ -385,54 +390,18 @@
         return None
 
 
-class TritonTemplate:
+class TritonTemplate(KernelTemplate):
     index_counter = itertools.count()
     all_templates: Dict[str, "TritonTemplate"] = dict()
 
-    @staticmethod
-    def _template_from_string(source):
-        env = _jinja2_env()
-        if env is not None:
-            return env.from_string(source)
-        return None
-
     def __init__(self, name: str, grid: Any, source: str, debug=False):
-        super().__init__()
-        self.name = name
+        super().__init__(name)
         self.grid = grid
         self.template = self._template_from_string(source)
         assert name not in self.all_templates, "duplicate template name"
         self.all_templates[name] = self
         self.debug = debug
 
-    def maybe_append_choice(
-        self,
-        choices,
-        input_nodes,
-        layout,
-        num_stages,
-        num_warps,
-        prefix_args=0,
-        suffix_args=0,
-        epilogue_fn=identity,
-        **kwargs,
-    ):
-        try:
-            choices.append(
-                self.generate(
-                    input_nodes=input_nodes,
-                    layout=layout,
-                    num_stages=num_stages,
-                    num_warps=num_warps,
-                    prefix_args=prefix_args,
-                    suffix_args=suffix_args,
-                    epilogue_fn=epilogue_fn,
-                    **kwargs,
-                )
-            )
-        except NotImplementedError:
-            pass
-
     def generate(
         self,
         input_nodes,
@@ -474,7 +443,7 @@
             index_dtype="tl.int32",
         )
         with patch.object(
-            V.graph, "get_dtype", self.fake_get_dtype(fake_out)
+            V.graph, "get_dtype", self._fake_get_dtype(fake_out)
         ), TritonTemplateKernel(
             kernel_name=kernel_name,
             output_node=fake_out,
@@ -534,14 +503,14 @@
         grid = self.grid(*V.graph.sizevars.size_hints(layout.size), kwargs)
         bmreq = TritonBenchmarkRequest(
             module_path=mod.__file__,
-            input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
-            output_tensor_meta=TensorMeta.from_irnodes(layout),
             module_cache_key=mod.key,
             kernel_name=kernel_name,
             grid=grid,
             extra_args=extra_args,
             num_stages=num_stages,
             num_warps=num_warps,
+            input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
+            output_tensor_meta=TensorMeta.from_irnodes(layout),
         )
 
         return TritonTemplateCaller(
@@ -553,17 +522,6 @@
             bmreq,
         )
 
-    @staticmethod
-    def fake_get_dtype(fake_out):
-        _get_dtype_real = V.graph.get_dtype
-
-        def get_dtype(name):
-            if name == fake_out.get_name():
-                return fake_out.get_dtype()
-            return _get_dtype_real(name)
-
-        return get_dtype
-
 
 class ExternKernelChoice:
     def __init__(
@@ -610,30 +568,6 @@
         )
 
 
-class ChoiceCaller:
-    def __init__(self, name, input_nodes, layout):
-        super().__init__()
-        self.name = name
-        self.layout = layout
-        self.input_nodes = input_nodes
-
-    def benchmark(self, *args, out):
-        algo = self.to_callable()
-        return do_bench(lambda: algo(*args, out=out))
-
-    def call_name(self):
-        raise NotImplementedError()
-
-    def to_callable(self):
-        raise NotImplementedError()
-
-    def hash_key(self):
-        raise NotImplementedError()
-
-    def output_node(self):
-        raise NotImplementedError()
-
-
 class TritonTemplateCaller(ChoiceCaller):
     def __init__(
         self, name, input_nodes, layout, make_kernel_render, debug_extra, bmreq
@@ -663,7 +597,7 @@
 
     def output_node(self):
         return ir.TensorBox.create(
-            ir.TemplateBuffer(
+            ir.TritonTemplateBuffer(
                 layout=self.layout,
                 inputs=self.input_nodes,
                 make_kernel_render=self.make_kernel_render,
@@ -699,7 +633,7 @@
                 out_new, tuple(out.size()), tuple(out.stride())
             )
             out.copy_(out_new)  # for correctness checking
-            return do_bench(lambda: algo(*args))
+            return do_bench_using_profiling(lambda: algo(*args))
 
     def to_callable(self):
         fn = self.choice.to_callable()
@@ -766,9 +700,12 @@
                 "No choices to select, please consider adding ATEN into max_autotune_gemm_backends "
                 "config (defined in torch/_inductor/config.py) to allow at least one choice. "
             )
+        log.info("Max autotune selects from %s choices.", str(len(choices)))
 
         if len(choices) == 1:
-            return choices[0].output_node()
+            if not isinstance(choices[0], CUDATemplateCaller):
+                # CUDATemplateCaller still needs to go through the autotuning process to retrieve its workspace size.
+                return choices[0].output_node()
 
         @functools.lru_cache(None)
         def make_benchmark_fn():
@@ -780,6 +717,11 @@
                 timing = benchmark_fn(
                     choice,
                 )
+            except CUDACompileError as e:
+                log.warning(
+                    "CUDA compilation error: \n%s. \nIgnore this choice.", str(e)
+                )
+                return float("inf")
             except RuntimeError as e:
                 msg = str(e)
                 if "invalid argument" in msg:
@@ -813,8 +755,14 @@
 
         if make_benchmark_fn.cache_info().currsize:
             counters["inductor"]["select_algorithm_autotune"] += 1
+        if (
+            make_benchmark_fn.cache_info().currsize
+            or log.getEffectiveLevel() == logging.DEBUG
+        ):
             self.log_results(name, input_nodes, timings, autotune_elapse)
-        return builtins.min(timings, key=timings.__getitem__).output_node()
+        selected_choice = builtins.min(timings, key=timings.__getitem__).output_node()
+        log.debug("selected choice: %s", str(selected_choice))
+        return selected_choice
 
     @classmethod
     def make_benchmark_fn(
@@ -922,7 +870,8 @@
                 for n in input_nodes
             ]
         )
-        top_k = sorted(timings, key=timings.__getitem__)[:10]
+        n = None if log.getEffectiveLevel() == logging.DEBUG else 10
+        top_k = sorted(timings, key=timings.__getitem__)[:n]
         best = top_k[0]
         best_time = timings[best]
         sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 47371ee..6974566 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -734,25 +734,38 @@
     return True
 
 
-def use_triton_template(layout, *, enable_int32=False):
-    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
-    if enable_int32:
-        layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
+def use_max_autotune() -> bool:
     return (
-        (
-            config.max_autotune
-            or config.max_autotune_gemm
-            or config.search_autotune_cache
-        )
-        and "TRITON" in config.max_autotune_gemm_backends.upper().split(",")
+        config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache
+    )
+
+
+def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
+    return (
+        use_max_autotune()
         and layout.device.type == "cuda"
-        and layout.dtype in layout_dtypes
+        and layout.dtype in allowed_layout_dtypes
         and is_big_gpu(layout.device.index or 0)
     )
 
 
+def _use_autotune_backend(backend: str) -> bool:
+    return backend.upper() in [
+        x.strip() for x in config.max_autotune_gemm_backends.upper().split(",")
+    ]
+
+
+def use_triton_template(layout, *, enable_int32=False):
+    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+    if enable_int32:
+        layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
+    return _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
+        "TRITON"
+    )
+
+
 def use_aten_gemm_kernels():
-    return "ATEN" in config.max_autotune_gemm_backends.upper().split(",")
+    return not use_max_autotune() or _use_autotune_backend("ATEN")
 
 
 class DebugDirManager:
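
With this refactor, GEMM backend gating is derived from the comma-separated
config.max_autotune_gemm_backends string whenever max-autotune is enabled. A small hedged
example (assuming the other autotune flags are left at their default, disabled, values):

    from torch._inductor import config
    from torch._inductor.utils import use_aten_gemm_kernels, use_max_autotune

    config.max_autotune = False
    assert not use_max_autotune()
    assert use_aten_gemm_kernels()      # ATen stays the fallback outside max-autotune

    config.max_autotune = True
    config.max_autotune_gemm_backends = "ATEN,TRITON"
    assert use_max_autotune()
    assert use_aten_gemm_kernels()      # "ATEN" is listed as an allowed backend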