[Inductor CUTLASS backend] Step 4: CUDA (template) kernels (#107931)
This is step 4 of adding CUTLASS as an alternative Inductor backend.
Full tests can be found from the last PR in the stack.
Feature request: https://github.com/pytorch/pytorch/issues/106991.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107931
Approved by: https://github.com/aakhundov, https://github.com/jansel, https://github.com/kadeng
ghstack dependencies: #107802, #107847, #107901
diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py
index d39e703..ce46477 100644
--- a/torch/_inductor/autotune_process.py
+++ b/torch/_inductor/autotune_process.py
@@ -17,6 +17,7 @@
from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache
if TYPE_CHECKING:
+ from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
from torch._inductor.select_algorithm import TritonTemplateCaller
from .utils import do_bench_using_profiling
@@ -379,7 +380,7 @@
def benchmark_in_sub_process(
- choice: "TritonTemplateCaller",
+ choice: "Union[TritonTemplateCaller, CUDATemplateCaller]",
) -> float:
"""
Do benchmarking in subprocess and return the perf number (latency).
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 312f711..59105c5 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -19,8 +19,8 @@
from .. import metrics
from ..utils import (
DeferredLineBase,
+ do_bench_using_profiling,
free_symbol_startswith,
- get_sympy_Expr_dtype,
IndentedBuffer,
sympy_dot,
sympy_subs,
@@ -560,17 +560,6 @@
def cpp_argdefs(self):
from .cpp import DTYPE_TO_CPP, INDEX_TYPE
- # TODO(jansel): replace this with data from scheduler
- buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers}
- for name, val in V.graph.graph_inputs.items():
- if isinstance(val, sympy.Expr):
- buffer_types[name] = get_sympy_Expr_dtype(val)
- else:
- buffer_types[name] = val.get_dtype()
- buffer_types.update(
- {name: val.dtype for name, val in V.graph.constants.items()}
- )
-
call_args = []
arg_defs = []
arg_types = []
@@ -579,7 +568,7 @@
continue
outer = inplaced.other_names[-1]
inner = inplaced.inner_name
- dtype = buffer_types[outer]
+ dtype = V.graph.get_dtype(outer)
cpp_dtype = DTYPE_TO_CPP[dtype]
arg_defs.append(f"{cpp_dtype}* {inner}")
call_args.append(self.wrap_ptr_arg(outer, dtype))
@@ -587,7 +576,7 @@
for outer, inner in self.input_buffers.items():
if outer in self.inplace_buffers:
continue
- dtype = buffer_types[outer]
+ dtype = V.graph.get_dtype(outer)
cpp_dtype = DTYPE_TO_CPP[dtype]
arg_defs.append(f"const {cpp_dtype}* {inner}")
call_args.append(self.wrap_ptr_arg(outer, dtype))
@@ -595,7 +584,7 @@
for outer, inner in self.output_buffers.items():
if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
continue
- dtype = buffer_types[outer]
+ dtype = V.graph.get_dtype(outer)
cpp_dtype = DTYPE_TO_CPP[dtype]
arg_defs.append(f"{cpp_dtype}* {inner}")
call_args.append(self.wrap_ptr_arg(outer, dtype))
@@ -1052,3 +1041,96 @@
# Load uint8 value as float32
is_load_uint8_as_float: bool = False
+
+
@functools.lru_cache(None)
def jinja2_env():
    """Return a process-wide jinja2 Environment (StrictUndefined), or None.

    The result is cached via lru_cache so the environment is built at most
    once. Returns None when jinja2 is not installed, letting callers degrade
    gracefully.
    """
    try:
        import jinja2
    except ImportError:
        return None
    return jinja2.Environment(undefined=jinja2.StrictUndefined)
+
+
class ChoiceCaller:
    """
    Represents a possible choice used in autotune_process.py.
    During autotuning, self.benchmark() is first called to get benchmark result,
    and if this choice is selected, self.output_node() is called to get the output_node.

    Children classes: TritonTemplateCaller, CUDATemplateCaller.
    """

    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.layout = layout
        self.input_nodes = input_nodes

    def benchmark(self, *args, out) -> float:
        """Time this choice's compiled callable on the given args and output buffer."""
        # Bind the callable once so the benchmark loop does not re-resolve it.
        fn = self.to_callable()
        return do_bench_using_profiling(lambda: fn(*args, out=out))

    def call_name(self) -> str:
        """Subclass hook: name used to invoke the kernel from wrapper code."""
        raise NotImplementedError()

    def to_callable(self):
        """Subclass hook: return a Python callable that runs this choice."""
        raise NotImplementedError()

    def hash_key(self) -> str:
        """Subclass hook: stable key identifying this choice for caching."""
        raise NotImplementedError()

    def output_node(self) -> "TensorBox":  # type: ignore[name-defined]
        """Subclass hook: wrap this choice into an output IR node."""
        raise NotImplementedError()
+
+
class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def _template_from_string(source):
        """Compile *source* into a jinja2 template, or None if jinja2 is unavailable."""
        env = jinja2_env()
        return env.from_string(source) if env is not None else None

    @staticmethod
    def _fake_get_dtype(fake_out):
        """Build a V.graph.get_dtype substitute that also resolves *fake_out*.

        Lookups for fake_out's own name answer from the fake node; everything
        else falls through to the real graph.
        """
        real_get_dtype = V.graph.get_dtype

        def get_dtype(name):
            return (
                fake_out.get_dtype()
                if name == fake_out.get_name()
                else real_get_dtype(name)
            )

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """
        # A template that cannot handle this case signals so by raising
        # NotImplementedError from generate(); simply skip it then.
        try:
            choice = self.generate(**kwargs)
        except NotImplementedError:
            return
        choices.append(choice)

    def generate(self, **kwargs) -> ChoiceCaller:
        """
        Generates a ChoiceCaller instance from the given arguments.
        """
        raise NotImplementedError()
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 8b1da1d..984c975 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -382,6 +382,7 @@
return f"std::max({il})"
+# A function to print, useful for printing sympy symbols.
cexpr = CppPrinter().doprint
diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py
new file mode 100644
index 0000000..e20a1dd
--- /dev/null
+++ b/torch/_inductor/codegen/cuda/cuda_kernel.py
@@ -0,0 +1,309 @@
+from typing import Dict, List, Optional
+
+from ...autotune_process import CUDABenchmarkRequest
+from ...ir import Callable, CUDATemplateBuffer, IRNode, Layout, TensorBox
+from ...select_algorithm import ChoiceCaller
+from ...utils import sympy_product
+from ...virtualized import V
+
+from ..common import IndentedBuffer, Kernel, OpOverrides
+from ..cpp import CppPrinter, DTYPE_TO_CPP
+
+
+cexpr = CppPrinter().doprint
+
+
+def _normalize_idx(index: int, total_length: int) -> int:
+ return index if index >= 0 else index + total_length
+
+
class CUDAKernel(Kernel):
    """
    Kernels defined by C++ CUDA.
    """

    # Reuse the generic op overrides; CUDA C++ codegen does not use
    # Triton-specific ones.
    overrides = OpOverrides  # type: ignore[assignment]
+
+
class CUDATemplateKernel(CUDAKernel):
    """
    Template kernels defined by C++ CUDA.

    Instances of this class are used as the `kernel` hook object inside
    CUDA template source (see CUDATemplate.render); its methods generate
    C++ snippets for signatures, sizes, strides, and pointers.
    """

    # Extra arguments appended to every generated kernel signature:
    # used for workspace-size queries and stream-based launches.
    _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"

    def __init__(
        self,
        kernel_name,
    ):
        """
        kernel_name: name of the kernel function emitted into the C++ source.
        """
        super().__init__()
        self.kernel_name = kernel_name
        # Mapping from arg name to IRNode.
        self.named_nodes: Dict[str, IRNode] = {}

    def arg_name(self, node: IRNode) -> Optional[str]:
        """
        Returns arg name of a given input or output node.

        Returns None when the node is None or not registered as an
        input/output buffer of this kernel.
        """

        if node is None:
            return None
        return {**self.args.input_buffers, **self.args.output_buffers}.get(
            node.get_name(), None
        )

    def check_not_null(self, node: IRNode) -> str:
        """
        Generates code to check that a node is not null.

        Emits a C++ runtime check that throws when the pointer is null but
        the node's total size is nonzero. Returns "" for None / unregistered
        nodes.
        """

        if node is None:
            return ""

        size_str = self.size(node, 0, -1)
        name_str = self.arg_name(node)
        if name_str is None:
            return ""

        res = IndentedBuffer(initial_indent=2)
        res.tabwidth = 1
        res.splice(
            f"""
            {{
              if (!{name_str}) {{
                int64_t {name_str}_size = {size_str};
                if ({name_str}_size > 0) {{
                  throw std::runtime_error("input {name_str} is null but size is not 0!");
                }}
              }}
            }}
            """
        )
        return res.getvalue()

    def def_kernel(
        self,
        inputs: List[IRNode],
        outputs: List[IRNode],
        names_str: str = "",
        input_reorder: Optional[List[int]] = None,
    ) -> str:
        """
        Hook called from template code to generate function def and
        needed args.

        inputs / outputs: List of input / output IRNodes. Note that IRNode can be None for optional arguments.
        names_str: Comma separated list of input + output argument names.
        input_reorder: The actual order of input nodes.
                       e.g. The template might have input argument defined as [X, W, Bias],
                       and the actual input passed into this template could be [Bias, X, W].
                       In this case, the `input_reorder` would be [2, 0, 1].
        """

        names = [x.strip() for x in names_str.strip().split(",")]
        if len(inputs) + len(outputs) != len(names):
            raise RuntimeError(
                f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
            )

        if input_reorder is not None:
            assert len(inputs) == len(input_reorder)
        else:
            input_reorder = list(range(len(inputs)))

        # Register inputs (in actual-argument order) and record the
        # name -> IRNode mapping for later size/stride/ptr lookups.
        for idx in input_reorder:
            name = names[idx]
            node = inputs[idx]
            if node is not None:
                self.named_nodes[name] = node
                self.args.input_buffers[node.get_name()] = name

        # The remaining names correspond (positionally) to outputs.
        for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
            if node is not None:
                self.named_nodes[name] = node
                self.args.output_buffers[node.get_name()] = name

        arg_defs, *_ = self.args.cpp_argdefs()
        return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})"

    def call_kernel(self, name: str, node: CUDATemplateBuffer) -> None:
        """
        Generates code to call the kernel through V.graph.wrapper_code.

        name: Name of kernel function.
        node: The IRNode which represents the kernel.
        """

        wrapper = V.graph.wrapper_code
        _, call_args, _ = self.args.python_argdefs()
        # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"
            else:
                call_args[i] = f"c_void_p({call_args[i]}.data_ptr())"

        # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size.
        # workspace_size should have already been retrieved prior to this call.
        call_args.append("None")

        if node.get_workspace_size() > 0:
            call_args.append(f"c_void_p({node.get_name()}_workspace.data_ptr())")
        else:
            call_args.append("None")

        wrapper.generate_kernel_call(
            name,
            call_args,
            device_index=V.graph.scheduler.current_device.index,
            cuda=True,
            triton=False,
        )

    def dtype(self, node: IRNode) -> str:
        """
        Generates code which represents dtype of a given node.
        """

        if node is None:
            return "void"
        # NOTE(review): dict.get returns None for dtypes missing from
        # DTYPE_TO_CPP, which would violate the declared -> str contract;
        # presumably all reachable dtypes are mapped — verify.
        return DTYPE_TO_CPP.get(node.get_layout().dtype)

    def offset(self, node: IRNode) -> str:
        """
        Generates code which represents offset of a given node.
        """

        if node is None:
            return "0"
        return str(node.get_layout().offset)

    def ptr(self, node: IRNode) -> str:
        """
        Generates code which represents pointer of a given node.

        Returns "nullptr" for None / unregistered nodes; otherwise the arg
        name, with "+ offset" appended when the layout offset is nonzero.
        """

        if node is None:
            return "nullptr"
        arg_name = self.arg_name(node)
        if arg_name is None:
            return "nullptr"
        offset = self.offset(node)
        return arg_name if offset == "0" else f"{arg_name} + {offset}"

    def size(
        self,
        node: IRNode,
        start_index: int,
        end_index: Optional[int] = None,
        default_value: int = 0,
    ) -> str:
        """
        Hook called from template code to get the size of an arg.
        Generates code which represents size of a given node in [start_index, end_index).
        If node is None, returns default_value.

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None:
            return str(default_value)

        start_index = _normalize_idx(start_index, len(node.get_size()))
        if end_index is None:
            end_index = start_index
        end_index = _normalize_idx(end_index, len(node.get_size()))

        # NOTE(review): the slice below includes end_index (end_index + 1),
        # so the product covers [start_index, end_index] inclusive despite
        # the half-open notation in the docstring — confirm intended.
        sizes = node.get_size()[start_index : end_index + 1]
        if len(sizes) == 0:
            return str(default_value)

        val = sympy_product(sizes)
        return cexpr(self.rename_indexing(val))

    def stride(self, node: IRNode, index: int, default_value: int = 0) -> str:
        """
        Hook called from template code to get the stride of an arg.
        Generates code which represents stride of a given node at index.
        If node is None, returns default_value.

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None:
            return str(default_value)

        index = _normalize_idx(index, len(node.get_size()))
        if index < 0:
            return str(default_value)

        stride = node.get_stride()[index]
        return cexpr(self.rename_indexing(stride))

    def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
        """
        Hook called from template code to get the row or column stride of an arg.
        This is required by some CUTLASS 2.X APIs.
        If the node is in row_major, it returns stride[-2].
        If the node is in column_major, it returns stride[-1].

        TODO: Will add needed args to pass it in if it is dynamic.
        """

        if node is None or len(node.get_stride()) < 2:
            return str(default_value)

        stride0 = node.get_stride()[-1]
        stride1 = node.get_stride()[-2]
        # Exactly one of the two innermost strides must be 1 (contiguous
        # dim); return the other, which is the leading dimension.
        if stride0 == 1:
            return cexpr(self.rename_indexing(stride1))
        elif stride1 == 1:
            return cexpr(self.rename_indexing(stride0))
        else:
            raise RuntimeError(
                f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
            )
+
+
class CUDATemplateCaller(ChoiceCaller):
    """
    ChoiceCaller backed by a CUDA template kernel.

    Holds the CUDABenchmarkRequest used during autotuning and the deferred
    render callback used to emit the final kernel if this choice wins.
    """

    def __init__(
        self,
        name: str,
        category: str,
        input_nodes: List[IRNode],
        layout: Layout,
        make_kernel_render: Callable[[str], str],
        bmreq: CUDABenchmarkRequest,
    ):
        super().__init__(name, input_nodes, layout)
        # Op category (used as part of hash_key).
        self.category = category
        self.make_kernel_render = make_kernel_render
        self.bmreq = bmreq

    def benchmark(self, *args, out) -> float:
        """Benchmark via the stored CUDABenchmarkRequest; returns the measured latency."""
        assert self.bmreq is not None
        return self.bmreq.benchmark(*args, output_tensor=out)

    def __str__(self):
        return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"

    def call_name(self) -> str:
        """Name used to invoke the compiled kernel from generated wrapper code."""
        return f"cuda_template_kernels.{self.name}"

    def hash_key(self) -> str:
        """Stable key (category + benchmark-request hash) identifying this choice."""
        return "-".join(
            [
                self.category,
                self.bmreq.hash_key,
            ]
        )

    def output_node(self) -> TensorBox:
        """Wrap this winning choice into a CUDATemplateBuffer IR node."""
        return TensorBox.create(
            CUDATemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                workspace_size=self.bmreq.workspace_size,
            )
        )
diff --git a/torch/_inductor/codegen/cuda/cuda_scheduling.py b/torch/_inductor/codegen/cuda/cuda_scheduling.py
new file mode 100644
index 0000000..a017f90
--- /dev/null
+++ b/torch/_inductor/codegen/cuda/cuda_scheduling.py
@@ -0,0 +1,43 @@
+from ... import config
+from ...codecache import code_hash, get_path
+from ...utils import get_fused_kernel_name, get_kernel_metadata
+from ...virtualized import V
+
+from ..common import IndentedBuffer
+from ..triton import TritonScheduling
+
+
class CUDAScheduling(TritonScheduling):
    """
    Final codegen for CUDAKernels.
    """

    def define_kernel(self, src_code: str, node_schedule) -> str:
        """
        Registers the rendered CUDA source with the wrapper code and returns
        the kernel name, deduplicating identical sources.

        src_code: rendered CUDA C++ source for the kernel (may contain the
                  KERNEL_NAME placeholder).
        node_schedule: scheduler nodes for this kernel, used for descriptive
                       naming and metadata comments.
        """
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            # Identical source already defined: reuse the existing kernel.
            kernel_name = wrapper.src_to_kernel[src_code]
        else:
            fused_name = (
                get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
                if config.triton.descriptive_names
                else ""
            )
            kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()])
            # use the original src_code as the key
            wrapper.src_to_kernel[src_code] = kernel_name
            src_code = src_code.replace("KERNEL_NAME", kernel_name)

            _, _, kernel_path = get_path(code_hash(src_code), "py")

            # Wrap the source in an async_compile.cuda(...) call so the
            # wrapper compiles it to a shared object at load time.
            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline("async_compile.cuda(r'''")
            compile_wrapper.splice(src_code, strip=True)
            compile_wrapper.writeline("''', 'so')")

            metadata_comment = f"# kernel path: {kernel_path}"
            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment += "\n" + origins + "\n" + detailed_origins
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )
        return kernel_name
diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py
new file mode 100644
index 0000000..765966f
--- /dev/null
+++ b/torch/_inductor/codegen/cuda/cuda_template.py
@@ -0,0 +1,197 @@
+import functools
+import itertools
+import logging
+
+from typing import List, Optional
+from unittest.mock import patch
+
+import sympy
+
+import torch
+
+from ...autotune_process import CUDABenchmarkRequest, TensorMeta
+from ...ir import Buffer, IRNode, Layout
+from ...utils import IndentedBuffer, unique
+from ...virtualized import V
+from ..common import KernelTemplate
+
+from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
+
+log = logging.getLogger(__name__)
+
+
class CUDATemplate(KernelTemplate):
    """
    Base class for CUDA C++ kernel templates. Subclasses implement
    render() to produce the kernel source.
    """

    # Monotonic counter used to produce unique per-instantiation hash names.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        input_nodes: List[IRNode],
        layout: Layout,
        input_reorder: Optional[List[int]] = None,
    ):
        """
        name: name of this template (used to derive kernel names).
        input_nodes: IR nodes passed as kernel inputs.
        layout: layout of the output buffer.
        input_reorder: optional mapping from template argument order to the
                       actual order of input_nodes (see
                       CUDATemplateKernel.def_kernel).
        """
        super().__init__(name)
        self.input_nodes = input_nodes
        self.output_node = Buffer("buf_out", layout)
        self.input_reorder = input_reorder

    def generate(self, **kwargs) -> CUDATemplateCaller:
        """
        Renders the template once (with a fake output dtype lookup) to
        discover its argument list, then builds a CUDATemplateCaller that
        carries a CUDABenchmarkRequest plus a deferred render callback for
        final codegen.
        """
        kernel_name = f"cuda_{self.name}"
        # Patch get_dtype so the not-yet-materialized output buffer resolves
        # during this trial render.
        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
        ), CUDATemplateKernel(
            kernel_name=kernel_name,
        ) as kernel:
            code = self.render(kernel=kernel, **kwargs)
            _, call_args, _ = kernel.args.python_argdefs()
            log.debug("Generated Code:\n%s", code)
            log.debug(
                "Args: cpp_argdefs: %s, python_argdefs: %s",
                kernel.args.cpp_argdefs(),
                kernel.args.python_argdefs(),
            )

        input_reorder = (
            self.input_reorder
            if self.input_reorder is not None
            else list(range(len(self.input_nodes)))
        )
        expected_args = list(
            unique(self.input_nodes[idx].get_name() for idx in input_reorder)
        )
        expected_args.extend([self.output_node.get_name()])
        # Sanity check: the rendered kernel's leading args must be exactly
        # the input/output buffer names in the expected order.
        assert list(call_args)[: len(expected_args)] == expected_args, (
            call_args,
            expected_args,
        )
        # Any remaining args are symbolic sizes; resolve to concrete hints
        # for benchmarking.
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :])
        )

        kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}"

        # create the BenchmarkRequest
        bmreq = CUDABenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
            extra_args=extra_args,
            source_code=code,
        )

        def make_kernel_render(output_node):
            # Deferred render with a placeholder kernel name; the scheduler
            # substitutes the real name at final codegen time.
            kernel = CUDATemplateKernel(
                kernel_name="KERNEL_NAME",
            )
            render = functools.partial(
                self.render,
                kernel=kernel,
                output_node=output_node,
                **kwargs,
            )
            return kernel, render

        return CUDATemplateCaller(
            kernel_hash_name,
            self.name,
            self.input_nodes,
            self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
        )

    def header(self) -> IndentedBuffer:
        """Common C++ header includes emitted into every generated kernel."""
        res = IndentedBuffer()
        res.splice(
            """
                #include <exception>
                #include <iostream>
                #include <memory>
                #include <random>
                #include <vector>
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Common global definitions (export macro, type aliases)."""
        res = IndentedBuffer()
        res.splice(
            """
                // We compile all models with -fvisibility=hidden. Any symbols that need to be
                // exposed in the final shared library must be declared with PT_EXPORT to make
                // them visible.
                #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
                #define PT_EXPORT __attribute__((__visibility__("default")))
                #else
                #ifdef _WIN32
                #define PT_EXPORT __declspec(dllexport)
                #else
                #define PT_EXPORT
                #endif
                #endif
                using bfloat16 = nv_bfloat16;
            """
        )
        return res

    def render(self, **kwargs) -> str:
        """Subclass hook: return the complete CUDA C++ source for this template."""
        raise NotImplementedError
+
+
class CUTLASSTemplate(CUDATemplate):
    """
    CUDATemplate specialization that adds CUTLASS includes, the
    CUTLASS_CHECK error macro, and dtype/int helpers for CUTLASS templates.
    """

    def header(self) -> IndentedBuffer:
        """Extend the base headers with CUTLASS includes."""
        res = super().header()
        res.splice(
            """
                #include "cutlass/cutlass.h"
                #include "cutlass/numeric_types.h"
                #include "cutlass/util/host_tensor.h"
                #include "cutlass/util/reference/host/tensor_fill.h"
                #include "cutlass/util/reference/device/tensor_fill.h"
                #include "cutlass/util/device_memory.h"
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Extend the base globals with a status-check macro that throws on CUTLASS errors."""
        res = super().globals()
        res.splice(
            """
                #define CUTLASS_CHECK(status)                                                      \\
                {                                                                                  \\
                  cutlass::Status error = status;                                                  \\
                  if (error != cutlass::Status::kSuccess) {                                        \\
                    auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +             \\
                        cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__);        \\
                    throw std::runtime_error(msg);                                                 \\
                  }                                                                                \\
                }
            """
        )
        return res

    def cute_int(self, int_str: str, var_name: str) -> str:
        """
        Renders an integer for CuTe: "1"/"1L" become the compile-time
        cute::Int<1>{}; anything else is passed through. The variable name is
        appended as a C comment for readability of the generated source.
        """
        res = ""
        if int_str in {"1", "1L"}:
            res = "cute::Int<1>{}"
        else:
            res = int_str

        return f"{res} /* {var_name} */"

    # Mapping from torch dtype to the corresponding CUTLASS/C++ type name.
    _DTYPE_TO_CUTLASS = {
        torch.float32: "float",
        torch.float64: "double",
        torch.float16: "cutlass::half_t",
        torch.int32: "int",
        torch.int8: "int8_t",
        torch.uint8: "uint8_t",
        torch.bool: "bool",
        torch.bfloat16: "cutlass::bfloat16_t",
    }

    def cutlass_type_cast(self, node: IRNode, ptr: str) -> str:
        """Cast a raw pointer expression to the CUTLASS element type of the node's dtype."""
        if node is None:
            return ptr
        else:
            # NOTE(review): .get() yields None for unmapped dtypes, which
            # would emit "(None*)" in generated code — presumably all
            # reachable dtypes are in the table; verify.
            return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})"
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index d092609..24cd7fc 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -20,7 +20,7 @@
from .. import config, ir, scheduler
from ..codecache import code_hash, get_path
from ..dependencies import MemoryDep, StarDep
-from ..ir import ReductionHint
+from ..ir import IRNode, ReductionHint, TritonTemplateBuffer
from ..optimize_indexing import indexing_dtype_strength_reduction
from ..scheduler import BaseScheduling
from ..triton_heuristics import AutotuneHint
@@ -2061,7 +2061,7 @@
return f"[{', '.join(sizes)}]"
- def call_kernel(self, name: str):
+ def call_kernel(self, name: str, node: IRNode = None):
wrapper = V.graph.wrapper_code
_, call_args, _ = self.args.python_argdefs()
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
@@ -2086,6 +2086,8 @@
call_args,
grid,
V.graph.scheduler.current_device.index,
+ cuda=True,
+ triton=True,
)
def warn_mix_layout(self, kernel_name):
@@ -2186,7 +2188,9 @@
return False
if node1.is_template():
- return True # skip checks for compatible tiling
+ # Only allow fusion for TritonTemplates for now.
+ # Fusion for CUDATemplates are not supported.
+ return isinstance(node1.node, TritonTemplateBuffer)
# check for a bad combined tiling
tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
@@ -2603,11 +2607,16 @@
# finalize must be called after adding epilogue above
with V.set_kernel_handler(kernel):
- src_code = partial_code.finalize()
+ # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
+ src_code = (
+ partial_code
+ if isinstance(partial_code, str)
+ else partial_code.finalize()
+ )
node_schedule = [template_node, *epilogue_nodes]
kernel_name = self.define_kernel(src_code, node_schedule)
self.codegen_comment(node_schedule)
- kernel.call_kernel(kernel_name)
+ kernel.call_kernel(kernel_name, template_node.node)
V.graph.removed_buffers |= kernel.removed_buffers
self.scheduler.free_buffers()
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 0cd5484..74eb33b 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -733,17 +733,36 @@
stack.enter_context(self.wrapper_call.indent())
def generate_kernel_call(
- self, name, call_args, grid=None, device_index=None, cuda=True
+ self,
+ name,
+ call_args,
+ grid=None,
+ device_index=None,
+ cuda=True,
+ triton=True,
):
+ """
+ Generates kernel call code.
+
+ cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.
+
+ triton: Defines whether the GPU backend uses Triton for codegen.
+ Otherwise it uses the CUDA language for codegen.
+ Only valid when cuda == True.
+ """
if cuda:
call_args_str = ", ".join(pexpr(item) for item in call_args)
- grid_str = ", ".join(pexpr(item) for item in grid)
stream_name = self.write_get_cuda_stream(
V.graph.scheduler.current_device.index
)
- self.writeline(
- f"{name}.run({call_args_str}, grid=grid({grid_str}), stream={stream_name})"
- )
+ if triton:
+ grid_str = ", ".join(pexpr(item) for item in grid)
+ self.writeline(
+ f"{name}.run({call_args_str}, grid=grid({grid_str}), stream={stream_name})"
+ )
+ else:
+ stream_ptr = f"c_void_p({stream_name})"
+ self.writeline(f"{name}.{name}({call_args_str}, {stream_ptr})")
else:
self.writeline(self.wrap_kernel_call(name, call_args))
@@ -823,6 +842,10 @@
)
def codegen_allocation(self, buffer):
+ assert (
+ buffer.get_workspace_size() == 0
+ ), "Only support zero workspace size for now!"
+
name = buffer.get_name()
if name in V.graph.removed_buffers or name in self.allocated:
@@ -854,6 +877,10 @@
)
def codegen_free(self, buffer):
+ assert (
+ buffer.get_workspace_size() == 0
+ ), "Only support zero workspace size for now!"
+
name = buffer.get_name()
# can be freed but not reused
@@ -1682,11 +1709,11 @@
return ", ".join(new_args)
def generate_kernel_call(
- self, name, call_args, grid=None, device_index=None, cuda=True
+ self, name, call_args, grid=None, device_index=None, cuda=True, triton=True
):
if not cuda:
return super().generate_kernel_call(
- name, call_args, grid, device_index, cuda
+ name, call_args, grid, device_index, cuda, triton
)
params = CudaKernelParamCache.get(name)
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index e6e5c8a..bd608dc 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -307,6 +307,12 @@
def get_read_names(self):
return {dep.name for dep in self.get_reads()}
+ def get_layout(self):
+ raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!")
+
+ def get_size(self):
+ raise NotImplementedError(f"get_size() is not implemented by {type(self)}!")
+
def get_numel(self):
return sympy_product(self.get_size())
@@ -1594,6 +1600,9 @@
def get_dtype(self):
return self.data.get_dtype()
+ def get_layout(self):
+ return self.data.get_layout()
+
def get_device(self):
return self.data.get_device()
@@ -2574,6 +2583,13 @@
def realize(self):
pass
+ def get_workspace_size(self):
+ """
+ Gets extra global memory size needed by this buffer.
+ Some algorithms (e.g. group gemm) may require extra global memory in the generated code.
+ """
+ return 0
+
class InputBuffer(Buffer):
pass
@@ -2912,6 +2928,26 @@
)
+class TritonTemplateBuffer(TemplateBuffer):
+ pass
+
+
class CUDATemplateBuffer(TemplateBuffer):
    """TemplateBuffer produced by a CUDA template kernel, tracking its extra workspace needs."""

    def __init__(
        self,
        layout,
        inputs,
        make_kernel_render,
        workspace_size: int = 0,
    ):
        super().__init__(layout, inputs, make_kernel_render)
        # Global memory (in bytes) needed for this template.
        self.workspace_size = workspace_size

    def get_workspace_size(self):
        # Treat a None workspace_size as "no extra global memory needed".
        return self.workspace_size if self.workspace_size is not None else 0
+
+
@dataclasses.dataclass
class InputsKernel(Buffer):
inputs: List[Buffer]
@@ -5302,6 +5338,12 @@
def layout(self):
return self.data.layout
+ def get_layout(self):
+ return self.layout
+
+ def get_size(self):
+ return self.data.get_size()
+
def __str__(self):
if isinstance(self.data, MutableBox):
line0 = f"{type(self).__name__}({type(self.data).__name__}("
diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py
index 9b414b2..f578f36 100644
--- a/torch/_inductor/kernel/bmm.py
+++ b/torch/_inductor/kernel/bmm.py
@@ -95,8 +95,8 @@
for config in mm_configs(m, n, k):
bmm_template.maybe_append_choice(
choices,
- (mat1, mat2),
- layout,
+ input_nodes=(mat1, mat2),
+ layout=layout,
**mm_options(config, k, layout),
)
@@ -118,8 +118,8 @@
for config in mm_configs(m, n, k):
bmm_template.maybe_append_choice(
choices,
- (inp, mat1, mat2),
- layout,
+ input_nodes=(inp, mat1, mat2),
+ layout=layout,
**mm_options(config, k, layout),
prefix_args=1,
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py
index e0d4124..6b3c387 100644
--- a/torch/_inductor/kernel/conv.py
+++ b/torch/_inductor/kernel/conv.py
@@ -434,8 +434,8 @@
):
conv2d_template.maybe_append_choice(
choices,
- (x, weight),
- layout,
+ input_nodes=(x, weight),
+ layout=layout,
KERNEL_H=kernel_shape[0],
KERNEL_W=kernel_shape[1],
STRIDE_H=stride[0],
diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py
index 576c195..91f763b 100644
--- a/torch/_inductor/kernel/mm.py
+++ b/torch/_inductor/kernel/mm.py
@@ -9,7 +9,7 @@
ExternKernelChoice,
TritonTemplate,
)
-from ..utils import use_aten_gemm_kernels, use_triton_template
+from ..utils import use_aten_gemm_kernels, use_max_autotune, use_triton_template
from .mm_common import (
addmm_epilogue,
int8_mm_configs,
@@ -116,12 +116,13 @@
# options to tune from
choices = [aten_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
+
if m * n != 0 and use_triton_template(layout):
for config in mm_configs(m, n, k):
mm_template.maybe_append_choice(
choices,
- (mat1, mat2),
- layout,
+ input_nodes=(mat1, mat2),
+ layout=layout,
**mm_options(config, k, layout),
)
@@ -142,8 +143,8 @@
for config in int8_mm_configs(m, n, k):
mm_template.maybe_append_choice(
choices,
- (mat1, mat2),
- layout,
+ input_nodes=(mat1, mat2),
+ layout=layout,
**mm_options(config, k, layout),
)
return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
@@ -154,7 +155,7 @@
ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
- if m * n == 0 or not use_triton_template(layout):
+ if m * n == 0 or not use_max_autotune():
choices = (
[
aten_addmm.bind(
@@ -183,8 +184,10 @@
if use_aten_gemm_kernels()
else []
)
+
if (
- inp_expanded.get_stride()[0] == 0
+ use_aten_gemm_kernels()
+ and inp_expanded.get_stride()[0] == 0
and inp_expanded.get_device().type == "cuda"
and inductor_config.triton.autotune_cublasLt
):
@@ -196,15 +199,16 @@
),
)
- for config in mm_configs(m, n, k):
- mm_template.maybe_append_choice(
- choices,
- (inp_expanded, mat1, mat2),
- layout,
- **mm_options(config, k, layout),
- prefix_args=1,
- epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
- )
+ if use_triton_template(layout):
+ for config in mm_configs(m, n, k):
+ mm_template.maybe_append_choice(
+ choices,
+ input_nodes=(inp_expanded, mat1, mat2),
+ layout=layout,
+ **mm_options(config, k, layout),
+ prefix_args=1,
+ epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
+ )
return autotune_select_algorithm(
"addmm", choices, [inp_expanded, mat1, mat2], layout
@@ -231,8 +235,8 @@
for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
mm_template.maybe_append_choice(
choices,
- (mat1, mat2),
- layout,
+ input_nodes=(mat1, mat2),
+ layout=layout,
**mm_options(config, k, layout, b_prologue_cast_type),
)
return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout)
diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py
index 40564b3..0d1a780 100644
--- a/torch/_inductor/kernel/mm_plus_mm.py
+++ b/torch/_inductor/kernel/mm_plus_mm.py
@@ -229,8 +229,8 @@
if config.kwargs["BLOCK_K"] < k1:
mm_plus_mm_template.maybe_append_choice(
choices,
- (mat1, mat2, mat3, mat4),
- layout1,
+ input_nodes=(mat1, mat2, mat3, mat4),
+ layout=layout1,
**mm_options(config, k1, layout1),
)
diff --git a/torch/_inductor/kernel/unpack_mixed_mm.py b/torch/_inductor/kernel/unpack_mixed_mm.py
index a8da948..be57170 100644
--- a/torch/_inductor/kernel/unpack_mixed_mm.py
+++ b/torch/_inductor/kernel/unpack_mixed_mm.py
@@ -75,8 +75,8 @@
for config in mm_configs(m, n, k):
uint4x2_mixed_mm_template.maybe_append_choice(
choices,
- (mat1, mat2),
- layout,
+ input_nodes=(mat1, mat2),
+ layout=layout,
**mm_options(config, k, layout, b_prologue_cast_type),
)
return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index c442333..07fc83b 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -1772,7 +1772,12 @@
if node.is_template():
node, *epilogue = node.get_nodes()
- self.get_backend(device).codegen_template(node, epilogue)
+ if isinstance(node.node, ir.CUDATemplateBuffer):
+ from .codegen.cuda.cuda_scheduling import CUDAScheduling
+
+ CUDAScheduling(self).codegen_template(node, epilogue)
+ else:
+ self.get_backend(device).codegen_template(node, epilogue)
elif node.is_extern():
self.codegen_extern_call(node)
elif node.is_foreach():
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index c527377..4d8e7ff 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -20,13 +20,18 @@
from . import config, ir
from .autotune_process import TensorMeta, TritonBenchmarkRequest
from .codecache import code_hash, PersistentCache, PyCodeCache
-
-from .codegen.common import IndentedBuffer
+from .codegen.common import ChoiceCaller, IndentedBuffer, KernelTemplate
+from .codegen.cuda.cuda_kernel import CUDATemplateCaller
from .codegen.triton import texpr, TritonKernel, TritonPrinter, TritonScheduling
-
from .codegen.triton_utils import config_of, signature_to_meta
-
-from .utils import do_bench, Placeholder, sympy_dot, sympy_product, unique
+from .exc import CUDACompileError
+from .utils import (
+ do_bench_using_profiling,
+ Placeholder,
+ sympy_dot,
+ sympy_product,
+ unique,
+)
from .virtualized import V
log = logging.getLogger(__name__)
@@ -338,7 +343,7 @@
self.body.clear()
self.indexing_code.clear()
- def call_kernel(self, name: str):
+ def call_kernel(self, name: str, node: ir.TritonTemplateBuffer):
wrapper = V.graph.wrapper_code
_, call_args, _ = self.args.python_argdefs()
call_args = [str(a) for a in call_args]
@@ -385,54 +390,18 @@
return None
-class TritonTemplate:
+class TritonTemplate(KernelTemplate):
index_counter = itertools.count()
all_templates: Dict[str, "TritonTemplate"] = dict()
- @staticmethod
- def _template_from_string(source):
- env = _jinja2_env()
- if env is not None:
- return env.from_string(source)
- return None
-
def __init__(self, name: str, grid: Any, source: str, debug=False):
- super().__init__()
- self.name = name
+ super().__init__(name)
self.grid = grid
self.template = self._template_from_string(source)
assert name not in self.all_templates, "duplicate template name"
self.all_templates[name] = self
self.debug = debug
- def maybe_append_choice(
- self,
- choices,
- input_nodes,
- layout,
- num_stages,
- num_warps,
- prefix_args=0,
- suffix_args=0,
- epilogue_fn=identity,
- **kwargs,
- ):
- try:
- choices.append(
- self.generate(
- input_nodes=input_nodes,
- layout=layout,
- num_stages=num_stages,
- num_warps=num_warps,
- prefix_args=prefix_args,
- suffix_args=suffix_args,
- epilogue_fn=epilogue_fn,
- **kwargs,
- )
- )
- except NotImplementedError:
- pass
-
def generate(
self,
input_nodes,
@@ -474,7 +443,7 @@
index_dtype="tl.int32",
)
with patch.object(
- V.graph, "get_dtype", self.fake_get_dtype(fake_out)
+ V.graph, "get_dtype", self._fake_get_dtype(fake_out)
), TritonTemplateKernel(
kernel_name=kernel_name,
output_node=fake_out,
@@ -534,14 +503,14 @@
grid = self.grid(*V.graph.sizevars.size_hints(layout.size), kwargs)
bmreq = TritonBenchmarkRequest(
module_path=mod.__file__,
- input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
- output_tensor_meta=TensorMeta.from_irnodes(layout),
module_cache_key=mod.key,
kernel_name=kernel_name,
grid=grid,
extra_args=extra_args,
num_stages=num_stages,
num_warps=num_warps,
+ input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
+ output_tensor_meta=TensorMeta.from_irnodes(layout),
)
return TritonTemplateCaller(
@@ -553,17 +522,6 @@
bmreq,
)
- @staticmethod
- def fake_get_dtype(fake_out):
- _get_dtype_real = V.graph.get_dtype
-
- def get_dtype(name):
- if name == fake_out.get_name():
- return fake_out.get_dtype()
- return _get_dtype_real(name)
-
- return get_dtype
-
class ExternKernelChoice:
def __init__(
@@ -610,30 +568,6 @@
)
-class ChoiceCaller:
- def __init__(self, name, input_nodes, layout):
- super().__init__()
- self.name = name
- self.layout = layout
- self.input_nodes = input_nodes
-
- def benchmark(self, *args, out):
- algo = self.to_callable()
- return do_bench(lambda: algo(*args, out=out))
-
- def call_name(self):
- raise NotImplementedError()
-
- def to_callable(self):
- raise NotImplementedError()
-
- def hash_key(self):
- raise NotImplementedError()
-
- def output_node(self):
- raise NotImplementedError()
-
-
class TritonTemplateCaller(ChoiceCaller):
def __init__(
self, name, input_nodes, layout, make_kernel_render, debug_extra, bmreq
@@ -663,7 +597,7 @@
def output_node(self):
return ir.TensorBox.create(
- ir.TemplateBuffer(
+ ir.TritonTemplateBuffer(
layout=self.layout,
inputs=self.input_nodes,
make_kernel_render=self.make_kernel_render,
@@ -699,7 +633,7 @@
out_new, tuple(out.size()), tuple(out.stride())
)
out.copy_(out_new) # for correctness checking
- return do_bench(lambda: algo(*args))
+ return do_bench_using_profiling(lambda: algo(*args))
def to_callable(self):
fn = self.choice.to_callable()
@@ -766,9 +700,12 @@
"No choices to select, please consider adding ATEN into max_autotune_gemm_backends "
"config (defined in torch/_inductor/config.py) to allow at least one choice. "
)
+ log.info("Max autotune selects from %s choices.", str(len(choices)))
if len(choices) == 1:
- return choices[0].output_node()
+ if not isinstance(choices[0], CUDATemplateCaller):
+ # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
+ return choices[0].output_node()
@functools.lru_cache(None)
def make_benchmark_fn():
@@ -780,6 +717,11 @@
timing = benchmark_fn(
choice,
)
+ except CUDACompileError as e:
+ log.warning(
+ "CUDA compilation error: \n%s. \nIgnore this choice.", str(e)
+ )
+ return float("inf")
except RuntimeError as e:
msg = str(e)
if "invalid argument" in msg:
@@ -813,8 +755,14 @@
if make_benchmark_fn.cache_info().currsize:
counters["inductor"]["select_algorithm_autotune"] += 1
+ if (
+ make_benchmark_fn.cache_info().currsize
+ or log.getEffectiveLevel() == logging.DEBUG
+ ):
self.log_results(name, input_nodes, timings, autotune_elapse)
- return builtins.min(timings, key=timings.__getitem__).output_node()
+ selected_choice = builtins.min(timings, key=timings.__getitem__).output_node()
+ log.debug("selected choice: %s", str(selected_choice))
+ return selected_choice
@classmethod
def make_benchmark_fn(
@@ -922,7 +870,8 @@
for n in input_nodes
]
)
- top_k = sorted(timings, key=timings.__getitem__)[:10]
+ n = None if log.getEffectiveLevel() == logging.DEBUG else 10
+ top_k = sorted(timings, key=timings.__getitem__)[:n]
best = top_k[0]
best_time = timings[best]
sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 47371ee..6974566 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -734,25 +734,38 @@
return True
-def use_triton_template(layout, *, enable_int32=False):
- layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
- if enable_int32:
- layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
+def use_max_autotune() -> bool:
return (
- (
- config.max_autotune
- or config.max_autotune_gemm
- or config.search_autotune_cache
- )
- and "TRITON" in config.max_autotune_gemm_backends.upper().split(",")
+ config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache
+ )
+
+
+def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
+ return (
+ use_max_autotune()
and layout.device.type == "cuda"
- and layout.dtype in layout_dtypes
+ and layout.dtype in allowed_layout_dtypes
and is_big_gpu(layout.device.index or 0)
)
+def _use_autotune_backend(backend: str) -> bool:
+ return backend.upper() in [
+ x.strip() for x in config.max_autotune_gemm_backends.upper().split(",")
+ ]
+
+
+def use_triton_template(layout, *, enable_int32=False):
+ layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ if enable_int32:
+ layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
+ return _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
+ "TRITON"
+ )
+
+
def use_aten_gemm_kernels():
- return "ATEN" in config.max_autotune_gemm_backends.upper().split(",")
+ return not use_max_autotune() or _use_autotune_backend("ATEN")
class DebugDirManager: