inductor: add fallback test case for hardtanh and leakyrelu fusion pattern (#99859)
Fixes https://github.com/pytorch/pytorch/issues/99841: the hardtanh fusion pattern matched clamp_min/clamp_max even when the scalar bounds had min_value > max_value. The pattern now matches only when min_value <= max_value, and fallback tests are added for the hardtanh and leaky_relu fusion patterns.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99859
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/desertfire
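
A minimal repro sketch of the previously failing case (illustrative: the module name, shapes, and bounds mirror the new test below rather than the original issue report):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_transpose = torch.nn.ConvTranspose2d(3, 32, kernel_size=3, padding=1)

    def forward(self, x, min_value, max_value):
        # clamp_min followed by clamp_max matches the hardtanh fusion pattern,
        # but hardtanh semantics only hold when min_value <= max_value.
        out = self.conv_transpose(x)
        return torch.clamp_max(torch.clamp_min(out, min_value), max_value)

with torch.no_grad():
    mod = M().eval()
    v = torch.randn(1, 3, 28, 28)
    # min_value > max_value: the matcher must fall back instead of fusing to hardtanh.
    torch.testing.assert_close(torch.compile(mod)(v, 3, 0), mod(v, 3, 0))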
diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index 0313549..e6f2d4f 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -287,6 +287,66 @@
             )
             counters.clear()
 
+    # https://github.com/pytorch/pytorch/issues/99841.
+    def test_hardtanh_pattern_fallback(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv_transpose = torch.nn.ConvTranspose2d(
+                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
+                )
+
+            def forward(self, x, min_value, max_value):
+                conv_transpose_output = self.conv_transpose(x)
+                clamp_min_output = torch.clamp_min(conv_transpose_output, min_value)
+                clamp_max_output = torch.clamp_max(clamp_min_output, max_value)
+                return clamp_max_output
+
+        # Check fallback when min_value > max_value or the bounds are Tensors.
+        min_values = [3, torch.randn(1, 32, 28, 28)]
+        max_values = [0, torch.randn(1, 32, 28, 28)]
+        with torch.no_grad():
+            mod = Model().eval()
+            v = torch.randn(1, 3, 28, 28)
+            for min_value, max_value in zip(min_values, max_values):
+                expected = mod(v, min_value, max_value)
+                actual = torch.compile(mod)(v, min_value, max_value)
+                torch.testing.assert_close(actual, expected)
+                self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
+                self.assertEqual(
+                    counters["inductor"]["pattern_matcher_nodes"],
+                    3,
+                )
+                counters.clear()
+
+    def test_leaky_relu_pattern_fallback(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(
+                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
+                )
+
+            def forward(self, x, negative_slope):
+                conv_out = self.conv(x)
+                return torch.where(conv_out > 0, conv_out, conv_out * negative_slope)
+
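+        # Check fallback when negative_slope is a Tensor rather than a Number.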
+        negative_slopes = [0.1, torch.randn(1, 32, 28, 28)]
+        with torch.no_grad():
+            mod = Model().eval()
+            v = torch.randn(1, 3, 28, 28)
+            for negative_slope in negative_slopes:
+                expected = mod(v, negative_slope)
+                actual = torch.compile(mod)(v, negative_slope)
+                torch.testing.assert_close(actual, expected)
+                self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
+                self.assertEqual(
+                    counters["inductor"]["pattern_matcher_nodes"],
+                    4,
+                )
+                counters.clear()
+
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_CPU and torch._C.has_mkldnn:
diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py
index 496c895..55d6d80 100644
--- a/torch/_inductor/fx_passes/mkldnn_fusion.py
+++ b/torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -238,7 +238,7 @@
             ):
                 matched = False
             else:  # inp is a Number
-                matched = True
+                matched = min_value <= max_value
             computation_args = list(args)
             if matched:
                 computation_args = computation_args[:-3] + [