[Inductor][fx pass] Add sub and div pointwise ops to the post grad fusion (#115389)
Summary: As titled. This adds `aten.sub.Tensor` and `aten.div.Tensor` to the post-grad batch pointwise fusions, alongside the existing `batch_aten_add` and `batch_aten_mul` passes. Because `aten.div.Tensor` takes an optional `rounding_mode` kwarg, the fusion group key now includes it so that div nodes with different rounding modes are never batched together. The unit test is extended to cover all four post-grad pointwise fusions.
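With these fusions enabled, parallel `torch.sub`/`torch.div` calls over same-shaped tensors become batching candidates in the post-grad graph. A minimal usage sketch (not part of this PR's test plan; it assumes the `torch._inductor.config.post_grad_fusion_options` knob that the unit test below exercises via `config.patch`):
```
# Hypothetical example: enabling the new post-grad batch fusions for a
# compiled function. The option names match those used in the unit test.
import torch
import torch._inductor.config as inductor_config

inductor_config.post_grad_fusion_options = {
    "batch_aten_add": {},
    "batch_aten_mul": {},
    "batch_aten_sub": {},
    "batch_aten_div": {},
}

@torch.compile
def f(x):
    # Parallel pointwise ops over same-shaped splits are batching candidates.
    xs = torch.split(x, 50, dim=1)
    sub = [torch.sub(t, t) for t in xs]
    div = [torch.div(t, t) for t in xs]
    return torch.cat(sub + div, dim=1)

out = f(torch.randn(50, 1000, device="cuda"))
```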
Test Plan:
# unit test
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:group_batch_fusion
```
Buck UI: https://www.internalfb.com/buck2/792c58db-c369-487d-9a42-b5da471657c0
Test UI: https://www.internalfb.com/intern/testinfra/testrun/2814749981661407
Network: Up: 74KiB Down: 29KiB (reSessionID-b47c266b-12d6-4e88-8dc3-4af1dd7ecbb4)
Jobs completed: 20. Time elapsed: 2:09.6s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0
# local reproduce
OC: P899142918
MAI: P899175452
# e2e (oc)
Differential Revision: D51957242
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115389
Approved by: https://github.com/dshi7, https://github.com/jackiexu1992, https://github.com/xuzhao9
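For context on the `rounding_mode` handling in the diff below: `aten.div.Tensor` accepts a `rounding_mode` kwarg (`None`, `"trunc"`, or `"floor"`), and only div nodes that agree on it can be fused, so it is folded into the group key. An illustrative sketch of the keying idea (`group_key_for` is a made-up name; the real logic lives in `BatchPointwiseOpsPostGradFusion`):
```
# Sketch only: mirrors the group-key construction changed in the diff below.
def group_key_for(node):
    # Nodes sharing a key are candidates for batching into one fused op.
    alpha = node.kwargs.get("alpha", 1.0)
    # New in this PR: div nodes must also agree on rounding_mode to fuse.
    rounding_mode = node.kwargs.get("rounding_mode", None)
    input, other = node.args
    return (
        str(node.target),  # e.g. "aten.div.Tensor"
        str(list(input.meta["tensor_meta"].shape)),
        str(input.meta["tensor_meta"].dtype),
        str(other.meta["tensor_meta"].dtype),
        str(alpha),
        str(rounding_mode),
    )
```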
diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py
index 9880a5c..db58761 100644
--- a/test/inductor/test_group_batch_fusion.py
+++ b/test/inductor/test_group_batch_fusion.py
@@ -205,15 +205,19 @@
def forward(self, x):
inputs = torch.split(x.to(self.device), 500, dim=1)
- x_split = torch.split(inputs[0].to(self.device), 100, dim=1)
- y_split = torch.split(inputs[1].to(self.device), 100, dim=1)
+ x_split = torch.split(inputs[0].to(self.device), 50, dim=1)
+ y_split = torch.split(inputs[1].to(self.device), 50, dim=1)
tanh_1 = [torch.tanh(x_split[i]) for i in range(len(x_split))]
tanh_2 = [torch.tanh(y_split[i]) for i in range(len(y_split))]
sigmoid_1 = [torch.sigmoid(tanh_1[i]) for i in range(len(tanh_1))]
sigmoid_2 = [torch.sigmoid(tanh_2[i]) for i in range(len(tanh_2))]
relu_1 = [torch.nn.functional.relu(sigmoid_1[i]) for i in range(len(sigmoid_1))]
relu_2 = [torch.nn.functional.relu(sigmoid_2[i]) for i in range(len(sigmoid_2))]
- return torch.cat(relu_1, dim=1) + torch.cat(relu_2, dim=1)
+ add = [torch.add(relu_1[i], relu_2[i]) for i in range(len(relu_1))]
+ mul = [torch.mul(add[i], add[i]) for i in range(len(add))]
+ sub = [torch.sub(mul[i], mul[i]) for i in range(len(mul))]
+ div = [torch.div(sub[i], sub[i]) for i in range(len(sub))]
+ return torch.cat(div, dim=1)
@requires_cuda()
@@ -226,7 +230,13 @@
"batch_relu": {},
"batch_sigmoid": {},
},
- post_grad_fusion_options={"group_linear": {"require_fbgemm": True}},
+ post_grad_fusion_options={
+ "batch_aten_add": {},
+ "batch_aten_mul": {},
+ "batch_aten_sub": {},
+ "batch_aten_div": {},
+ "group_linear": {"require_fbgemm": True},
+ },
)
class TestGroupBatchFusion(TestCase):
def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
@@ -285,7 +295,7 @@
)
self.assertEqual(
counters["inductor"]["batch_fusion"],
- 0,
+ 3,
)
counters.clear()
@@ -316,7 +326,7 @@
)
self.assertEqual(
counters["inductor"]["batch_fusion"],
- 0,
+ 1,
)
counters.clear()
@@ -399,7 +409,7 @@
self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8)
counters.clear()
- def test_pointwise_op_pre_grad_fusion(self):
+ def test_pointwise_op_fusion(self):
counters.clear()
module = TestPoitwiseOps("cuda")
input = [torch.randn(50, 1000, requires_grad=True, device="cuda")]
@@ -407,7 +417,7 @@
ref = module(*input)
res = traced(*input)
self.compare_pred(module, traced, input)
- self.assertEqual(counters["inductor"]["batch_fusion"], 3)
+ self.assertEqual(counters["inductor"]["batch_fusion"], 7)
self.assertEqual(
counters["inductor"]["scmerge_split_removed"],
0,
diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py
index 9a83d31..683e1ac 100644
--- a/torch/_inductor/fx_passes/group_batch_fusion.py
+++ b/torch/_inductor/fx_passes/group_batch_fusion.py
@@ -307,6 +307,7 @@
node
) and self._pointwise_node_can_be_fused(node):
alpha = node.kwargs.get("alpha", 1.0)
+ rounding_mode = node.kwargs.get("rounding_mode", None)
input, other = node.args
shape = list(input.meta["tensor_meta"].shape)
group_key = (
@@ -315,6 +316,7 @@
str(input.meta["tensor_meta"].dtype),
str(other.meta["tensor_meta"].dtype),
str(alpha),
+ str(rounding_mode),
)
else:
group_key = None
@@ -716,6 +718,18 @@
super().__init__(aten.add.Tensor, **kwargs)
+@register_fusion("batch_aten_sub", pre_grad=False)
+class BatchSubPostGradFusion(BatchPointwiseOpsPostGradFusion):
+ def __init__(self, **kwargs):
+ super().__init__(aten.sub.Tensor, **kwargs)
+
+
+@register_fusion("batch_aten_div", pre_grad=False)
+class BatchDivPostGradFusion(BatchPointwiseOpsPostGradFusion):
+ def __init__(self, **kwargs):
+ super().__init__(aten.div.Tensor, **kwargs)
+
+
@register_fusion("batch_aten_mul", pre_grad=False)
class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion):
def __init__(self, **kwargs):