Fixed minor issues for bmm/mm decomposition (#109836)
Summary:
* Fixed minor issues for bmm/mm decomposition
* enabled addmm for inductor
Test Plan: ci
Reviewed By: mikekgfb
Differential Revision: D49522332
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109836
Approved by: https://github.com/jansel, https://github.com/mikekgfb
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 4d36e5d..73bbf75 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -179,8 +179,9 @@
@register_decomposition([aten.bmm])
+@pw_cast_for_opmath
def bmm(self, batch2):
- if self.device == "cpu":
+ if self.device.type == "cpu":
if self.size(1) == 1 and batch2.size(-1) == 1:
return torch.sum(
self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
@@ -201,15 +202,17 @@
@register_decomposition([aten.mm])
+@pw_cast_for_opmath
def mm(self, input2):
# Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
# todo: Look into why and fix it (hopefully)
if config.coordinate_descent_tuning:
if self.shape[0] == 1 or input2.shape[1] == 1:
return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
- if self.device == "cpu":
+ if self.device.type == "cpu":
if (
self.size(-1) == 1
+ and self.size(0) > 0
and input2.size(0) == 1
and (self.dtype == input2.dtype)
and ((torch.numel(self) + torch.numel(input2)) <= 32)