[inductor] fix "could not find as_strided" with config.triton.mm=triton (#88946)
Summary: ReinterpretView is not handled properly by the matrix-multiply Triton template kernels: the reinterpreted buffers show up under names of the form as_strided(buf..., ...), which the existing buffer lookups do not recognize, leading to the "could not find as_strided ..." error.
Reviewed By: bertmaher
Differential Revision: D40836677
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88946
Approved by: https://github.com/jansel
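The fix touches three spots. In torch/_inductor/codegen/common.py, KernelArgs.input() and KernelArgs.output() now fall through to inplace_buffers, and a new is_removed() helper reports a buffer as removed only when it is absent (or marked "REMOVED") in both the output map and the in-place map. The matmul template in triton_template.py switches its epilogue check to that helper, so a buffer that was rewritten into an in-place argument is no longer treated as removed. Finally, get_dtype() in graph.py gains a regex fallback that maps a name like as_strided(buf0, ...) back to the dtype of the underlying buffer instead of raising KeyError.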
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index a949eff..932e8c9 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -283,6 +283,8 @@
assert name not in V.graph.removed_buffers, name
if name in self.output_buffers:
return self.output_buffers[name]
+ if name in self.inplace_buffers:
+ return self.inplace_buffers[name].inner_name
if name.startswith("seed"):
return self._lookup("seed", self.input_buffers, name)
return self._lookup("in_ptr", self.input_buffers, name)
@@ -290,6 +292,8 @@
def output(self, name):
name = V.graph.scheduler.mutation_real_name.get(name, name)
assert name not in V.graph.removed_buffers, name
+ if name in self.inplace_buffers:
+ return self.inplace_buffers[name].inner_name
return self._lookup("out_ptr", self.output_buffers, name)
def make_inplace(self, input_name, output_name):
@@ -392,6 +396,14 @@
if other in self.output_buffers:
yield self.output_buffers[other], inplaced.inner_name
+ def is_removed(self, name):
+ def _is_removed(name, buffers):
+ return name not in buffers or buffers[name] == "REMOVED"
+
+ return _is_removed(name, self.output_buffers) and _is_removed(
+ name, self.inplace_buffers
+ )
+
class CSE:
"""Common subexpression elimination"""
diff --git a/torch/_inductor/codegen/triton_template.py b/torch/_inductor/codegen/triton_template.py
index 0de771f..cd1c2be 100644
--- a/torch/_inductor/codegen/triton_template.py
+++ b/torch/_inductor/codegen/triton_template.py
@@ -330,7 +330,7 @@
kernel_buf_replace_name = None
if could_remove_kernel_buf:
for node in epilogue:
- if kernel.args.output_buffers[node.get_name()] != "REMOVED":
+ if not kernel.args.is_removed(node.get_name()):
kernel_buf_replace_name = node.get_name()
break
assert kernel_buf_replace_name is not None
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index f69a891..e0e41fd 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -1,6 +1,7 @@
import logging
import operator
import os
+import re
import time
import sympy
@@ -90,6 +91,9 @@
return self.name_to_buffer[buffer_name].get_dtype()
if buffer_name in self.graph_inputs:
return self.graph_inputs[buffer_name].get_dtype()
+ m = re.match(r"as_strided\(([a-zA-Z0-9_]+),", buffer_name)
+ if m:
+ return self.get_dtype(m.group(1))
raise KeyError(f"could not find {buffer_name}")
def random_seed_buffer(self, device: torch.device):
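A minimal sketch of the get_dtype() fallback added above, with the graph's dtype lookups collapsed into a plain dict (known_dtypes and the example buffer names are hypothetical):

import re

def get_dtype_sketch(buffer_name, known_dtypes):
    if buffer_name in known_dtypes:
        return known_dtypes[buffer_name]
    # Names rendered from ReinterpretView reach the dtype lookup as strings
    # such as "as_strided(buf0, (16, 16), (16, 1))"; recover the dtype of
    # the underlying buffer instead of raising KeyError.
    m = re.match(r"as_strided\(([a-zA-Z0-9_]+),", buffer_name)
    if m:
        return get_dtype_sketch(m.group(1), known_dtypes)
    raise KeyError(f"could not find {buffer_name}")

known_dtypes = {"buf0": "torch.float16"}
print(get_dtype_sketch("as_strided(buf0, (16, 16), (16, 1))", known_dtypes))  # torch.float16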