[Cutlass 3.2.2 submodule upgrade] Adapt Inductor cutlass backend to Cutlass 3.2.2 (#112762)

The inductor cutlass backend was written against Cutlass version 3.1.x,
there are some incompatible changes in Cutlass 3.2.2 which the
Inductor cutlass backend needs to adapt to.

Test plan:

If third_party/cutlass is upgraded to Cutlass tag v3.2.2,
several tests within test/inductor/test_max_autotune.py start to
fail. With this diff applied, they pass again.

Differential Revision: [D50986555](https://our.internmc.facebook.com/intern/diff/D50986555)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112762
Approved by: https://github.com/ipiszy, https://github.com/drisspg
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3f27f8c..7af3f61 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -479,6 +479,17 @@
 endif()
 
 if(MSVC)
+  # MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard
+  # because MSVC is not a completely compliant implementation. This option forces MSVC to use the
+  # appropriate value given the requested --std option. This fixes a compilation issue mismatch
+  # between GCC/Clang and MSVC.
+  #
+  # See:
+  # * https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-170
+  # * https://en.cppreference.com/w/cpp/preprocessor/replace#Predefined_macros
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
+  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler  /Zc:__cplusplus")
+
   set(CMAKE_NINJA_CMCLDEPS_RC OFF)
   foreach(flag_var
       CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
diff --git a/third_party/cutlass b/third_party/cutlass
index 6f47420..44c704e 160000
--- a/third_party/cutlass
+++ b/third_party/cutlass
@@ -1 +1 @@
-Subproject commit 6f47420213f757831fae65c686aa471749fa8d60
+Subproject commit 44c704eae85da352d277d6f092f41412772f70e4
diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py
index 953b3ed..dd1170d 100644
--- a/torch/_inductor/codegen/cuda/cutlass_utils.py
+++ b/torch/_inductor/codegen/cuda/cutlass_utils.py
@@ -17,66 +17,37 @@
 log = logging.getLogger(__name__)
 
 
-def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str:
-    for cutlass_module in cutlass_modules:
-        content = content.replace(
-            f"from {cutlass_module} import ", f"from cutlass_{cutlass_module} import "
-        )
-    return content
-
-
-def _gen_cutlass_file(
-    file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str
-) -> None:
-    orig_full_path = os.path.abspath(os.path.join(src_dir, file_name))
-    text = ""
-    with open(orig_full_path) as f:
-        text = f.read()
-    text = _rename_cutlass_import(text, cutlass_modules)
-    dst_full_path = os.path.abspath(
-        os.path.join(
-            dst_dir,
-            f"cutlass_{file_name}" if file_name != "__init__.py" else file_name,
-        )
-    )
-    with open(dst_full_path, "w") as f:
-        f.write(text)
-
-
 @functools.lru_cache(None)
 def try_import_cutlass() -> bool:
     # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
     # This is a temporary hack to avoid CUTLASS module naming conflicts.
     # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
 
-    cutlass_py_full_path = os.path.join(
-        inductor_cuda_config.cutlass_dir, "tools/library/scripts"
+    cutlass_py_full_path = os.path.abspath(
+        os.path.join(inductor_cuda_config.cutlass_dir, "python/cutlass_library")
     )
     tmp_cutlass_py_full_path = os.path.abspath(
-        os.path.join(cache_dir(), "torch_cutlass_script")
+        os.path.join(cache_dir(), "torch_cutlass_library")
     )
+    dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
 
     if os.path.isdir(cutlass_py_full_path):
-        cutlass_file_names = [
-            file_name
-            for file_name in os.listdir(cutlass_py_full_path)
-            if file_name.endswith(".py")
-        ]
-        cutlass_module_names = [file_name[:-3] for file_name in cutlass_file_names]
-        if not os.path.isdir(tmp_cutlass_py_full_path):
-            os.mkdir(tmp_cutlass_py_full_path)
-        for file_name in cutlass_file_names:
-            _gen_cutlass_file(
-                file_name,
-                cutlass_module_names,
-                cutlass_py_full_path,
-                tmp_cutlass_py_full_path,
-            )
-        sys.path.append(tmp_cutlass_py_full_path)
+        if tmp_cutlass_py_full_path not in sys.path:
+            if os.path.exists(dst_link):
+                assert os.path.islink(
+                    dst_link
+                ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
+                assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
+                    cutlass_py_full_path
+                ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
+            else:
+                os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
+                os.symlink(cutlass_py_full_path, dst_link)
+            sys.path.append(tmp_cutlass_py_full_path)
         try:
-            import cutlass_generator  # type: ignore[import]  # noqa: F401
-            import cutlass_library  # type: ignore[import]  # noqa: F401
-            import cutlass_manifest  # type: ignore[import]  # noqa: F401
+            import cutlass_library.generator  # type: ignore[import]  # noqa: F401
+            import cutlass_library.library  # type: ignore[import]  # noqa: F401
+            import cutlass_library.manifest  # type: ignore[import]  # noqa: F401
 
             return True
 
@@ -136,18 +107,14 @@
 
 
 @functools.lru_cache(None)
-def gen_ops() -> List[Any]:
-    """
-    Generates all supported CUTLASS operations.
-    """
+def _gen_ops_cached(arch, version) -> List[Any]:
+    # Note: Cache needs to be specific for cuda architecture and version
 
     # Import cutlass python scripts.
     assert try_import_cutlass()
-    import cutlass_generator  # type: ignore[import]
-    import cutlass_manifest  # type: ignore[import]
+    import cutlass_library.generator as cutlass_generator  # type: ignore[import]
+    import cutlass_library.manifest as cutlass_manifest  # type: ignore[import]
 
-    arch = get_cuda_arch()
-    version = get_cuda_version()
     if arch is None or version is None:
         log.error(
             "Cannot detect cuda arch %s or cuda version %s. "
@@ -172,13 +139,21 @@
             raise NotImplementedError(
                 "Arch " + arch + " is not supported by current cutlass lib."
             ) from e
-
     return manifest.operations
 
 
+def gen_ops() -> List[Any]:
+    """
+    Generates all supported CUTLASS operations.
+    """
+    arch = get_cuda_arch()
+    version = get_cuda_version()
+    return _gen_ops_cached(arch, version)
+
+
 def dtype_match(
     torch_dtype: torch.dtype,
-    cutlass_dtype: "cutlass_library.DataType",  # type: ignore[name-defined]
+    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined]
 ) -> bool:
     # Import cutlass python scripts.
     assert try_import_cutlass()
@@ -186,13 +161,13 @@
 
     if torch_dtype == torch.float:
         return (
-            cutlass_dtype == cutlass_library.DataType.f32
-            or cutlass_dtype == cutlass_library.DataType.tf32
+            cutlass_dtype == cutlass_library.library.DataType.f32
+            or cutlass_dtype == cutlass_library.library.DataType.tf32
         )
     elif torch_dtype == torch.half:
-        return cutlass_dtype == cutlass_library.DataType.f16
+        return cutlass_dtype == cutlass_library.library.DataType.f16
     elif torch_dtype == torch.bfloat16:
-        return cutlass_dtype == cutlass_library.DataType.bf16
+        return cutlass_dtype == cutlass_library.library.DataType.bf16
     else:
         return False
 
diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py
index 367d368..8a51c73 100644
--- a/torch/_inductor/codegen/cuda/gemm_template.py
+++ b/torch/_inductor/codegen/cuda/gemm_template.py
@@ -131,7 +131,6 @@
   };
 """
 
-
 GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
     {
       {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})},  // typename ThreadEpilogueOp::Params thread
@@ -183,7 +182,7 @@
     @staticmethod
     def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]":  # type: ignore[name-defined]
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib  # type: ignore[import]
 
         if torch_layout.stride[-1] == 1:
             return cutlass_lib.LayoutType.RowMajor
@@ -197,7 +196,7 @@
         cutlass_layout: "cutlass_lib.LayoutType",  # type: ignore[name-defined]
     ) -> "cutlass_lib.LayoutType":  # type: ignore[name-defined]
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib  # type: ignore[import]
 
         if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
             return cutlass_lib.LayoutType.ColumnMajor
@@ -220,7 +219,7 @@
     @staticmethod
     def has_tma_epilogue(op) -> bool:
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib  # type: ignore[import]
 
         result = False
         if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
@@ -233,8 +232,8 @@
         op: "cutlass_gemm_op.GemmOperation",  # type: ignore[name-defined]
     ) -> Tuple[str, str]:
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_gemm_operation as cutlass_gemm_op  # type: ignore[import]
-        import cutlass_library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.gemm_operation as cutlass_gemm_op  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib  # type: ignore[import]
 
         if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
             emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
@@ -291,7 +290,7 @@
         op: "cutlass_gemm_op.GemmOperation",  # type: ignore[name-defined]
     ) -> "cutlass_gemm_op.GemmOperation":  # type: ignore[name-defined]
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib  # type: ignore[import]
 
         # Skip simt kernels
         if (
@@ -306,7 +305,6 @@
             cutlass_lib.GemmKind.Universal3x,
         }:
             return None
-
         # Filter ops by dtypes.
         X = self.input_nodes[0]
         W = self.input_nodes[1]
@@ -372,21 +370,22 @@
 
     def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]":  # type: ignore[name-defined]
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_gemm_operation as cutlass_gemm_op  # type: ignore[import]
-        import cutlass_library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.gemm_operation as cutlass_gemm_op  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib  # type: ignore[import]
 
         ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
         res: Dict[str, cutlass_gemm_op.GemmOperation] = dict()
         num_3x_ops = 0
         num_2x_ops = 0
-        for op_list in ops.values():
-            for op in op_list:
-                filter_res = self.filter_op(op)
-                if (
-                    filter_res is not None
-                    and res.get(filter_res.configuration_name(), None) is None
-                ):
-                    res[filter_res.configuration_name()] = filter_res
+        for op_dict in ops.values():
+            for op_list in op_dict.values():
+                for op in op_list:
+                    filter_res = self.filter_op(op)
+                    if (
+                        filter_res is not None
+                        and res.get(filter_res.configuration_name(), None) is None
+                    ):
+                        res[filter_res.configuration_name()] = filter_res
         for op in res.values():
             if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
                 num_3x_ops += 1
@@ -481,7 +480,7 @@
         output_node: Optional[Buffer] = None,
     ) -> str:
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib  # type: ignore[import]
 
         if output_node is not None:
             self.output_node = output_node
@@ -523,6 +522,5 @@
             instance_type=instance_type,
             input_reorder=self.input_reorder,
         )
-
         res = self._template_from_string(GEMM_TEMPLATE).render(**options)
         return res