[Inductor][fx pass] Add sub and div pointwise ops to the post grad fusion (#115389)

Summary: As titled: extend the Inductor post-grad batch fusion pass with the `aten.sub.Tensor` and `aten.div.Tensor` pointwise ops, and add `rounding_mode` to the pointwise group key so that div nodes with different rounding modes are never batched together.
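
A minimal sketch of how the new options can be enabled, mirroring the test below. The config-patch pattern is taken from the test; the function `f` and its shapes are illustrative only, and whether the fusion actually fires depends on the pass's size thresholds and the build:

```
import torch
import torch._inductor.config as inductor_config

def f(x, y):
    # Lists of same-shape aten.sub / aten.div calls are candidates for
    # batching into a single fused op by the post-grad pass.
    xs = torch.split(x, 50, dim=1)
    ys = torch.split(y, 50, dim=1)
    subs = [torch.sub(a, b) for a, b in zip(xs, ys)]
    divs = [torch.div(s, s) for s in subs]
    return torch.cat(divs, dim=1)

# Option names match the register_fusion calls added in this diff.
with inductor_config.patch(
    post_grad_fusion_options={"batch_aten_sub": {}, "batch_aten_div": {}}
):
    compiled = torch.compile(f)
    out = compiled(torch.randn(8, 500), torch.randn(8, 500))
```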

Test Plan:
# unit test
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:group_batch_fusion
```
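For an OSS checkout, a rough equivalent (assuming the standard PyTorch test entry point) is:
```
python test/inductor/test_group_batch_fusion.py -k test_pointwise_op_fusion
```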
Buck UI: https://www.internalfb.com/buck2/792c58db-c369-487d-9a42-b5da471657c0
Test UI: https://www.internalfb.com/intern/testinfra/testrun/2814749981661407
Network: Up: 74KiB  Down: 29KiB  (reSessionID-b47c266b-12d6-4e88-8dc3-4af1dd7ecbb4)
Jobs completed: 20. Time elapsed: 2:09.6s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0

# local reproduce
OC: P899142918
MAI: P899175452
# e2e (oc)

Differential Revision: D51957242

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115389
Approved by: https://github.com/dshi7, https://github.com/jackiexu1992, https://github.com/xuzhao9
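
For context, the post-grad pointwise pass only batches nodes that share a group key built from the op, input shape, dtypes, and kwargs. Because `aten.div.Tensor` takes a `rounding_mode` kwarg, the key is extended below so that, for example, floor division and true division are never batched together. A hypothetical illustration of the keying idea (`div_group_key` is not the pass's actual helper):

```
# Hypothetical sketch: two div nodes are batchable only if their keys match.
def div_group_key(shape, input_dtype, other_dtype, alpha, rounding_mode):
    return (
        "batch_aten_div",
        str(shape),
        str(input_dtype),
        str(other_dtype),
        str(alpha),
        str(rounding_mode),
    )

a = div_group_key([10, 50], "torch.float32", "torch.float32", 1.0, None)
b = div_group_key([10, 50], "torch.float32", "torch.float32", 1.0, "floor")
assert a != b  # different rounding modes must not be fused
```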
diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py
index 9880a5c..db58761 100644
--- a/test/inductor/test_group_batch_fusion.py
+++ b/test/inductor/test_group_batch_fusion.py
@@ -205,15 +205,19 @@
 
     def forward(self, x):
         inputs = torch.split(x.to(self.device), 500, dim=1)
-        x_split = torch.split(inputs[0].to(self.device), 100, dim=1)
-        y_split = torch.split(inputs[1].to(self.device), 100, dim=1)
+        x_split = torch.split(inputs[0].to(self.device), 50, dim=1)
+        y_split = torch.split(inputs[1].to(self.device), 50, dim=1)
         tanh_1 = [torch.tanh(x_split[i]) for i in range(len(x_split))]
         tanh_2 = [torch.tanh(y_split[i]) for i in range(len(y_split))]
         sigmoid_1 = [torch.sigmoid(tanh_1[i]) for i in range(len(tanh_1))]
         sigmoid_2 = [torch.sigmoid(tanh_2[i]) for i in range(len(tanh_2))]
         relu_1 = [torch.nn.functional.relu(sigmoid_1[i]) for i in range(len(sigmoid_1))]
         relu_2 = [torch.nn.functional.relu(sigmoid_2[i]) for i in range(len(sigmoid_2))]
-        return torch.cat(relu_1, dim=1) + torch.cat(relu_2, dim=1)
+        add = [torch.add(relu_1[i], relu_2[i]) for i in range(len(relu_1))]
+        mul = [torch.mul(add[i], add[i]) for i in range(len(add))]
+        sub = [torch.sub(mul[i], mul[i]) for i in range(len(mul))]
+        div = [torch.div(sub[i], sub[i]) for i in range(len(sub))]
+        return torch.cat(div, dim=1)
 
 
 @requires_cuda()
@@ -226,7 +230,13 @@
         "batch_relu": {},
         "batch_sigmoid": {},
     },
-    post_grad_fusion_options={"group_linear": {"require_fbgemm": True}},
+    post_grad_fusion_options={
+        "batch_aten_add": {},
+        "batch_aten_mul": {},
+        "batch_aten_sub": {},
+        "batch_aten_div": {},
+        "group_linear": {"require_fbgemm": True},
+    },
 )
 class TestGroupBatchFusion(TestCase):
     def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
@@ -285,7 +295,7 @@
             )
             self.assertEqual(
                 counters["inductor"]["batch_fusion"],
-                0,
+                3,
             )
             counters.clear()
 
@@ -316,7 +326,7 @@
         )
         self.assertEqual(
             counters["inductor"]["batch_fusion"],
-            0,
+            1,
         )
         counters.clear()
 
@@ -399,7 +409,7 @@
             self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8)
             counters.clear()
 
-    def test_pointwise_op_pre_grad_fusion(self):
+    def test_pointwise_op_fusion(self):
         counters.clear()
         module = TestPoitwiseOps("cuda")
         input = [torch.randn(50, 1000, requires_grad=True, device="cuda")]
@@ -407,7 +417,7 @@
         ref = module(*input)
         res = traced(*input)
         self.compare_pred(module, traced, input)
-        self.assertEqual(counters["inductor"]["batch_fusion"], 3)
+        self.assertEqual(counters["inductor"]["batch_fusion"], 7)
         self.assertEqual(
             counters["inductor"]["scmerge_split_removed"],
             0,
diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py
index 9a83d31..683e1ac 100644
--- a/torch/_inductor/fx_passes/group_batch_fusion.py
+++ b/torch/_inductor/fx_passes/group_batch_fusion.py
@@ -307,6 +307,7 @@
             node
         ) and self._pointwise_node_can_be_fused(node):
             alpha = node.kwargs.get("alpha", 1.0)
+            rounding_mode = node.kwargs.get("rounding_mode", None)
             input, other = node.args
             shape = list(input.meta["tensor_meta"].shape)
             group_key = (
@@ -315,6 +316,7 @@
                 str(input.meta["tensor_meta"].dtype),
                 str(other.meta["tensor_meta"].dtype),
                 str(alpha),
+                str(rounding_mode),
             )
         else:
             group_key = None
@@ -716,6 +718,18 @@
         super().__init__(aten.add.Tensor, **kwargs)
 
 
+@register_fusion("batch_aten_sub", pre_grad=False)
+class BatchSubPostGradFusion(BatchPointwiseOpsPostGradFusion):
+    def __init__(self, **kwargs):
+        super().__init__(aten.sub.Tensor, **kwargs)
+
+
+@register_fusion("batch_aten_div", pre_grad=False)
+class BatchDivPostGradFusion(BatchPointwiseOpsPostGradFusion):
+    def __init__(self, **kwargs):
+        super().__init__(aten.div.Tensor, **kwargs)
+
+
 @register_fusion("batch_aten_mul", pre_grad=False)
 class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion):
     def __init__(self, **kwargs):