[inductor] fix could not find as_strided with config.triton.mm=triton (#88946)

Summary: ReinterpretView doesn't seem to be handled properly with matrix multiply Triton kernels

Reviewed By: bertmaher

Differential Revision: D40836677

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88946
Approved by: https://github.com/jansel
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index a949eff..932e8c9 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -283,6 +283,8 @@
         assert name not in V.graph.removed_buffers, name
         if name in self.output_buffers:
             return self.output_buffers[name]
+        if name in self.inplace_buffers:
+            return self.inplace_buffers[name].inner_name
         if name.startswith("seed"):
             return self._lookup("seed", self.input_buffers, name)
         return self._lookup("in_ptr", self.input_buffers, name)
@@ -290,6 +292,8 @@
     def output(self, name):
         name = V.graph.scheduler.mutation_real_name.get(name, name)
         assert name not in V.graph.removed_buffers, name
+        if name in self.inplace_buffers:
+            return self.inplace_buffers[name].inner_name
         return self._lookup("out_ptr", self.output_buffers, name)
 
     def make_inplace(self, input_name, output_name):
@@ -392,6 +396,14 @@
                 if other in self.output_buffers:
                     yield self.output_buffers[other], inplaced.inner_name
 
+    def is_removed(self, name):
+        def _is_removed(name, buffers):
+            return name not in buffers or buffers[name] == "REMOVED"
+
+        return _is_removed(name, self.output_buffers) and _is_removed(
+            name, self.inplace_buffers
+        )
+
 
 class CSE:
     """Common subexpression elimination"""
diff --git a/torch/_inductor/codegen/triton_template.py b/torch/_inductor/codegen/triton_template.py
index 0de771f..cd1c2be 100644
--- a/torch/_inductor/codegen/triton_template.py
+++ b/torch/_inductor/codegen/triton_template.py
@@ -330,7 +330,7 @@
     kernel_buf_replace_name = None
     if could_remove_kernel_buf:
         for node in epilogue:
-            if kernel.args.output_buffers[node.get_name()] != "REMOVED":
+            if not kernel.args.is_removed(node.get_name()):
                 kernel_buf_replace_name = node.get_name()
                 break
         assert kernel_buf_replace_name is not None
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index f69a891..e0e41fd 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -1,6 +1,7 @@
 import logging
 import operator
 import os
+import re
 import time
 
 import sympy
@@ -90,6 +91,9 @@
             return self.name_to_buffer[buffer_name].get_dtype()
         if buffer_name in self.graph_inputs:
             return self.graph_inputs[buffer_name].get_dtype()
+        m = re.match(r"as_strided\(([a-zA-Z0-9_]+),", buffer_name)
+        if m:
+            return self.get_dtype(m.group(1))
         raise KeyError(f"could not find {buffer_name}")
 
     def random_seed_buffer(self, device: torch.device):