[Cutlass 3.2.2 submodule upgrade] Adapt Inductor cutlass backend to Cutlass 3.2.2 (#112762)
The inductor cutlass backend was written against Cutlass version 3.1.x,
there are some incompatible changes in Cutlass 3.2.2 which the
Inductor cutlass backend needs to adapt to.
Test plan:
If third_party/cutlass is upgraded to Cutlass tag v3.2.2,
several tests within test/inductor/test_max_autotune.py start to
fail. With this diff applied, they pass again.
Differential Revision: [D50986555](https://our.internmc.facebook.com/intern/diff/D50986555)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112762
Approved by: https://github.com/ipiszy, https://github.com/drisspg
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3f27f8c..7af3f61 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -479,6 +479,17 @@
endif()
if(MSVC)
+ # MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard
+ # because MSVC is not a completely compliant implementation. This option forces MSVC to use the
+ # appropriate value given the requested --std option. This fixes a compilation issue mismatch
+ # between GCC/Clang and MSVC.
+ #
+ # See:
+ # * https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-170
+ # * https://en.cppreference.com/w/cpp/preprocessor/replace#Predefined_macros
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus")
+
set(CMAKE_NINJA_CMCLDEPS_RC OFF)
foreach(flag_var
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
diff --git a/third_party/cutlass b/third_party/cutlass
index 6f47420..44c704e 160000
--- a/third_party/cutlass
+++ b/third_party/cutlass
@@ -1 +1 @@
-Subproject commit 6f47420213f757831fae65c686aa471749fa8d60
+Subproject commit 44c704eae85da352d277d6f092f41412772f70e4
diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py
index 953b3ed..dd1170d 100644
--- a/torch/_inductor/codegen/cuda/cutlass_utils.py
+++ b/torch/_inductor/codegen/cuda/cutlass_utils.py
@@ -17,66 +17,37 @@
log = logging.getLogger(__name__)
-def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str:
- for cutlass_module in cutlass_modules:
- content = content.replace(
- f"from {cutlass_module} import ", f"from cutlass_{cutlass_module} import "
- )
- return content
-
-
-def _gen_cutlass_file(
- file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str
-) -> None:
- orig_full_path = os.path.abspath(os.path.join(src_dir, file_name))
- text = ""
- with open(orig_full_path) as f:
- text = f.read()
- text = _rename_cutlass_import(text, cutlass_modules)
- dst_full_path = os.path.abspath(
- os.path.join(
- dst_dir,
- f"cutlass_{file_name}" if file_name != "__init__.py" else file_name,
- )
- )
- with open(dst_full_path, "w") as f:
- f.write(text)
-
-
@functools.lru_cache(None)
def try_import_cutlass() -> bool:
# Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
# This is a temporary hack to avoid CUTLASS module naming conflicts.
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
- cutlass_py_full_path = os.path.join(
- inductor_cuda_config.cutlass_dir, "tools/library/scripts"
+ cutlass_py_full_path = os.path.abspath(
+ os.path.join(inductor_cuda_config.cutlass_dir, "python/cutlass_library")
)
tmp_cutlass_py_full_path = os.path.abspath(
- os.path.join(cache_dir(), "torch_cutlass_script")
+ os.path.join(cache_dir(), "torch_cutlass_library")
)
+ dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
if os.path.isdir(cutlass_py_full_path):
- cutlass_file_names = [
- file_name
- for file_name in os.listdir(cutlass_py_full_path)
- if file_name.endswith(".py")
- ]
- cutlass_module_names = [file_name[:-3] for file_name in cutlass_file_names]
- if not os.path.isdir(tmp_cutlass_py_full_path):
- os.mkdir(tmp_cutlass_py_full_path)
- for file_name in cutlass_file_names:
- _gen_cutlass_file(
- file_name,
- cutlass_module_names,
- cutlass_py_full_path,
- tmp_cutlass_py_full_path,
- )
- sys.path.append(tmp_cutlass_py_full_path)
+ if tmp_cutlass_py_full_path not in sys.path:
+ if os.path.exists(dst_link):
+ assert os.path.islink(
+ dst_link
+ ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
+ assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
+ cutlass_py_full_path
+ ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
+ else:
+ os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
+ os.symlink(cutlass_py_full_path, dst_link)
+ sys.path.append(tmp_cutlass_py_full_path)
try:
- import cutlass_generator # type: ignore[import] # noqa: F401
- import cutlass_library # type: ignore[import] # noqa: F401
- import cutlass_manifest # type: ignore[import] # noqa: F401
+ import cutlass_library.generator # type: ignore[import] # noqa: F401
+ import cutlass_library.library # type: ignore[import] # noqa: F401
+ import cutlass_library.manifest # type: ignore[import] # noqa: F401
return True
@@ -136,18 +107,14 @@
@functools.lru_cache(None)
-def gen_ops() -> List[Any]:
- """
- Generates all supported CUTLASS operations.
- """
+def _gen_ops_cached(arch, version) -> List[Any]:
+ # Note: Cache needs to be specific for cuda architecture and version
# Import cutlass python scripts.
assert try_import_cutlass()
- import cutlass_generator # type: ignore[import]
- import cutlass_manifest # type: ignore[import]
+ import cutlass_library.generator as cutlass_generator # type: ignore[import]
+ import cutlass_library.manifest as cutlass_manifest # type: ignore[import]
- arch = get_cuda_arch()
- version = get_cuda_version()
if arch is None or version is None:
log.error(
"Cannot detect cuda arch %s or cuda version %s. "
@@ -172,13 +139,21 @@
raise NotImplementedError(
"Arch " + arch + " is not supported by current cutlass lib."
) from e
-
return manifest.operations
+def gen_ops() -> List[Any]:
+ """
+ Generates all supported CUTLASS operations.
+ """
+ arch = get_cuda_arch()
+ version = get_cuda_version()
+ return _gen_ops_cached(arch, version)
+
+
def dtype_match(
torch_dtype: torch.dtype,
- cutlass_dtype: "cutlass_library.DataType", # type: ignore[name-defined]
+ cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined]
) -> bool:
# Import cutlass python scripts.
assert try_import_cutlass()
@@ -186,13 +161,13 @@
if torch_dtype == torch.float:
return (
- cutlass_dtype == cutlass_library.DataType.f32
- or cutlass_dtype == cutlass_library.DataType.tf32
+ cutlass_dtype == cutlass_library.library.DataType.f32
+ or cutlass_dtype == cutlass_library.library.DataType.tf32
)
elif torch_dtype == torch.half:
- return cutlass_dtype == cutlass_library.DataType.f16
+ return cutlass_dtype == cutlass_library.library.DataType.f16
elif torch_dtype == torch.bfloat16:
- return cutlass_dtype == cutlass_library.DataType.bf16
+ return cutlass_dtype == cutlass_library.library.DataType.bf16
else:
return False
diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py
index 367d368..8a51c73 100644
--- a/torch/_inductor/codegen/cuda/gemm_template.py
+++ b/torch/_inductor/codegen/cuda/gemm_template.py
@@ -131,7 +131,6 @@
};
"""
-
GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
{
{ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename ThreadEpilogueOp::Params thread
@@ -183,7 +182,7 @@
@staticmethod
def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
- import cutlass_library as cutlass_lib # type: ignore[import]
+ import cutlass_library.library as cutlass_lib # type: ignore[import]
if torch_layout.stride[-1] == 1:
return cutlass_lib.LayoutType.RowMajor
@@ -197,7 +196,7 @@
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined]
) -> "cutlass_lib.LayoutType": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
- import cutlass_library as cutlass_lib # type: ignore[import]
+ import cutlass_library.library as cutlass_lib # type: ignore[import]
if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
return cutlass_lib.LayoutType.ColumnMajor
@@ -220,7 +219,7 @@
@staticmethod
def has_tma_epilogue(op) -> bool:
assert cutlass_utils.try_import_cutlass()
- import cutlass_library as cutlass_lib # type: ignore[import]
+ import cutlass_library.library as cutlass_lib # type: ignore[import]
result = False
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
@@ -233,8 +232,8 @@
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
) -> Tuple[str, str]:
assert cutlass_utils.try_import_cutlass()
- import cutlass_gemm_operation as cutlass_gemm_op # type: ignore[import]
- import cutlass_library as cutlass_lib # type: ignore[import]
+ import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
+ import cutlass_library.library as cutlass_lib # type: ignore[import]
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
@@ -291,7 +290,7 @@
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
) -> "cutlass_gemm_op.GemmOperation": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
- import cutlass_library as cutlass_lib # type: ignore[import]
+ import cutlass_library.library as cutlass_lib # type: ignore[import]
# Skip simt kernels
if (
@@ -306,7 +305,6 @@
cutlass_lib.GemmKind.Universal3x,
}:
return None
-
# Filter ops by dtypes.
X = self.input_nodes[0]
W = self.input_nodes[1]
@@ -372,21 +370,22 @@
def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
- import cutlass_gemm_operation as cutlass_gemm_op # type: ignore[import]
- import cutlass_library as cutlass_lib # type: ignore[import]
+ import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
+ import cutlass_library.library as cutlass_lib # type: ignore[import]
ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
res: Dict[str, cutlass_gemm_op.GemmOperation] = dict()
num_3x_ops = 0
num_2x_ops = 0
- for op_list in ops.values():
- for op in op_list:
- filter_res = self.filter_op(op)
- if (
- filter_res is not None
- and res.get(filter_res.configuration_name(), None) is None
- ):
- res[filter_res.configuration_name()] = filter_res
+ for op_dict in ops.values():
+ for op_list in op_dict.values():
+ for op in op_list:
+ filter_res = self.filter_op(op)
+ if (
+ filter_res is not None
+ and res.get(filter_res.configuration_name(), None) is None
+ ):
+ res[filter_res.configuration_name()] = filter_res
for op in res.values():
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
num_3x_ops += 1
@@ -481,7 +480,7 @@
output_node: Optional[Buffer] = None,
) -> str:
assert cutlass_utils.try_import_cutlass()
- import cutlass_library as cutlass_lib # type: ignore[import]
+ import cutlass_library.library as cutlass_lib # type: ignore[import]
if output_node is not None:
self.output_node = output_node
@@ -523,6 +522,5 @@
instance_type=instance_type,
input_reorder=self.input_reorder,
)
-
res = self._template_from_string(GEMM_TEMPLATE).render(**options)
return res