Add operator div (#68528)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68528

Add operator converter div, torch.floor_div is announce to be deprecated by pytorch, consider remove after full deprecation done by pytorch.

Reviewed By: 842974287

Differential Revision: D32497573

fbshipit-source-id: d06c864077f745c295c33fb25639b7116f85ca20
diff --git a/test/fx2trt/converters/acc_op/test_binary_ops.py b/test/fx2trt/converters/acc_op/test_binary_ops.py
index bc1a958..aeeeaf9 100644
--- a/test/fx2trt/converters/acc_op/test_binary_ops.py
+++ b/test/fx2trt/converters/acc_op/test_binary_ops.py
@@ -13,6 +13,11 @@
     ((lambda x, y: x - y), acc_ops.sub),
     # Avoid dividing by 0.
     ((lambda x, y: x / (y + 1.0)), acc_ops.div),
+    ((lambda x, y: x // (y + 1.0)), acc_ops.div),
+    ((lambda x, y: torch.div(x, y + 1.0, rounding_mode="trunc")), acc_ops.div),
+    ((lambda x, y: torch.div(x, y + 1.0, rounding_mode="floor")), acc_ops.div),
+    ((lambda x, y: torch.div(x, y + 1.0)), acc_ops.div),
+    ((lambda x, y: torch.floor_divide(x, y + 1.0)), acc_ops.div),
     ((lambda x, y: x * y), acc_ops.mul),
     (torch.pow, acc_ops.pow),
 ]
@@ -43,7 +48,7 @@
 
             def forward(self, x):
                 x = self.orig_op(x, self.constant)
-                return self.orig_op(x, 1)
+                return self.orig_op(x, -2)
 
         m = TestModule(orig_op)
         inputs = [torch.randn(2, 2)]
diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index da84d75..a060c56 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -28,6 +28,7 @@
     add_activation_layer,
     extend_attr_to_tuple,
     get_positive_dim,
+    trunc_div,
 )
 
 
@@ -873,9 +874,21 @@
 
 @tensorrt_converter(acc_ops.div)
 def acc_ops_div(network, target, args, kwargs, name):
-    return add_binary_elementwise_layer(
-        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.DIV, target, name
-    )
+    if kwargs["rounding_mode"] == "trunc":
+        inputs = kwargs["input"]
+        other = kwargs["other"]
+        return trunc_div(inputs, other, network, target, name)
+    elif kwargs["rounding_mode"] == "floor":
+        return add_binary_elementwise_layer(
+            network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.FLOOR_DIV, target, name
+        )
+    elif kwargs["rounding_mode"] is None:
+        return add_binary_elementwise_layer(
+            network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.DIV, target, name
+        )
+    else :
+        mode = kwargs["rounding_mode"]
+        raise RuntimeError(f"Div received mode {mode} that is not supported!")
 
 
 @tensorrt_converter(acc_ops.mul)
diff --git a/torch/fx/experimental/fx2trt/converters/converter_utils.py b/torch/fx/experimental/fx2trt/converters/converter_utils.py
index acdb9e4..0ad417b 100644
--- a/torch/fx/experimental/fx2trt/converters/converter_utils.py
+++ b/torch/fx/experimental/fx2trt/converters/converter_utils.py
@@ -487,3 +487,57 @@
         else:
             inputs.append(kwargs[key])
     return inputs
+
+
+def trunc_div(
+    input: trt.tensorrt.ITensor,
+    other: trt.tensorrt.ITensor,
+    network: trt.INetworkDefinition,
+    target: Target,
+    name: str
+) -> trt.tensorrt.ITensor:
+    """
+    Perform trunc divide on Tensor, result of divide will be round toward zero.
+    This means for positive number, it will be floor round; for negative number,
+    it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3].
+    Args:
+        input: divisor.
+        other: dividend.
+        network: INetworkDefinition.
+        target: node target.
+        name: namespace for the op
+    Returns:
+        A TensorRT tensor represent the result of trunc divide.
+    """
+    prod_output = add_binary_elementwise_layer(network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod")
+
+    # get sign:
+    # x = input * other
+    # sign = (exp(x) // exp(abs(x))) * 2 - 1
+    # For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0.
+    # With multiply 2, the value become 2(for pos and 0) and 0(for neg).
+    # Finally minus 1, the value become 1(for pos and 0) and -1(for neg).
+    prod_exp_output = add_unary_layer(network, prod_output, trt.UnaryOperation.EXP, target, f"{name}_prod_exp")
+    prod_abs_output = add_unary_layer(network, prod_output, trt.UnaryOperation.ABS, target, f"{name}_prod_abs")
+    prod_abs_exp_output = add_unary_layer(network, prod_abs_output, trt.UnaryOperation.EXP, target, f"{name}_prod_abs_exp")
+    floor_div_output = add_binary_elementwise_layer(network, prod_abs_output, prod_abs_exp_output,
+                                                    trt.ElementWiseOperation.FLOOR_DIV, target, f"{name}_exp_floor_div")
+    double_floor_div_output = add_binary_elementwise_layer(network, floor_div_output, 2,
+                                                           trt.ElementWiseOperation.PROD, target, f"{name}_double_floor_div")
+    sign_output = add_binary_elementwise_layer(network, double_floor_div_output, 1,
+                                               trt.ElementWiseOperation.SUB, target, f"{name}_binary_sign")
+
+    # Convert constant input into ITensor for UnaryOperation
+    if not isinstance(input, trt.tensorrt.ITensor):
+        input = get_trt_tensor(network, input, f"{name}_input")
+    if not isinstance(other, trt.tensorrt.ITensor):
+        other = get_trt_tensor(network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype))
+
+    abs_a_output = add_unary_layer(network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_a")
+    abs_b_output = add_unary_layer(network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_b")
+    abs_floor_output = add_binary_elementwise_layer(network, abs_a_output, abs_b_output,
+                                                    trt.ElementWiseOperation.FLOOR_DIV, target, f"{name}_floor_div")
+    output = add_binary_elementwise_layer(network, abs_floor_output, sign_output,
+                                          trt.ElementWiseOperation.PROD, target, f"{name}_output")
+
+    return output
diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py
index 151c39d..1466272 100644
--- a/torch/fx/experimental/fx_acc/acc_ops.py
+++ b/torch/fx/experimental/fx_acc/acc_ops.py
@@ -617,11 +617,58 @@
     return input * other
 
 
-@register_acc_op_properties(AccOpProperty.pointwise)
-@register_acc_op_mapping(op_and_target=("call_function", operator.truediv))
+# Torch.floor_divide is announced to be deprecated, consider using torch.div() with 'trunc' or 'floor'
+# mode instead.
+# This implementation matches torch.floor_div's behavior, which for negative number the divide result
+# is round toward zero, rather than -Inf.
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.floor_divide),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("other", "other"),
+    ],
+)
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", operator.floordiv),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("other", "other"),
+    ],
+)
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.div),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("other", "other"),
+        ("rounding_mode", "rounding_mode", this_arg_is_optional),
+    ],
+)
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", operator.truediv),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("other", "other"),
+    ],
+)
+def div_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.Node:
+    with node.graph.inserting_before(node):
+        div_kwargs = dict(node.kwargs)
+        if "rounding_mode" not in div_kwargs and node.op == "call_function":
+            div_kwargs["rounding_mode"] = None
+            if node.target is torch.floor_divide:
+                div_kwargs["rounding_mode"] = "trunc"
+            elif node.target is operator.floordiv:
+                div_kwargs["rounding_mode"] = "floor"
+            elif node.target is operator.truediv:
+                div_kwargs["rounding_mode"] = None
+        div_node = node.graph.call_function(div, kwargs=div_kwargs)
+        div_node.meta = node.meta.copy()
+        return div_node
+
+
 @register_acc_op
-def div(*, input, other):
-    return input / other
+def div(input, other, *, rounding_mode=None):
+    return torch.div(input, other, rounding_mode=rounding_mode)
 
 
 @register_acc_op_properties(AccOpProperty.pointwise)