Add operator div (#68528)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68528
Add an operator converter for div. torch.floor_divide is announced to be deprecated by PyTorch; consider removing the floor_divide mapping once PyTorch completes the deprecation.
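For reference, the rounding-mode semantics being converted (a minimal sketch; the printed values assume standard torch.div behavior on float tensors, and torch.floor_divide's truncating behavior at the time of this change):

    import torch
    x = torch.tensor([7.0, -7.0])
    y = torch.tensor([2.0, 2.0])
    torch.div(x, y)                         # true division     -> [3.5, -3.5]
    torch.div(x, y, rounding_mode="trunc")  # round toward zero -> [3., -3.]
    torch.div(x, y, rounding_mode="floor")  # round toward -Inf -> [3., -4.]
    torch.floor_divide(x, y)                # deprecated; truncates -> [3., -3.]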
Reviewed By: 842974287
Differential Revision: D32497573
fbshipit-source-id: d06c864077f745c295c33fb25639b7116f85ca20
diff --git a/test/fx2trt/converters/acc_op/test_binary_ops.py b/test/fx2trt/converters/acc_op/test_binary_ops.py
index bc1a958..aeeeaf9 100644
--- a/test/fx2trt/converters/acc_op/test_binary_ops.py
+++ b/test/fx2trt/converters/acc_op/test_binary_ops.py
@@ -13,6 +13,11 @@
((lambda x, y: x - y), acc_ops.sub),
# Avoid dividing by 0.
((lambda x, y: x / (y + 1.0)), acc_ops.div),
+ ((lambda x, y: x // (y + 1.0)), acc_ops.div),
+ ((lambda x, y: torch.div(x, y + 1.0, rounding_mode="trunc")), acc_ops.div),
+ ((lambda x, y: torch.div(x, y + 1.0, rounding_mode="floor")), acc_ops.div),
+ ((lambda x, y: torch.div(x, y + 1.0)), acc_ops.div),
+ ((lambda x, y: torch.floor_divide(x, y + 1.0)), acc_ops.div),
((lambda x, y: x * y), acc_ops.mul),
(torch.pow, acc_ops.pow),
]
@@ -43,7 +48,7 @@
def forward(self, x):
x = self.orig_op(x, self.constant)
- return self.orig_op(x, 1)
+ return self.orig_op(x, -2)
m = TestModule(orig_op)
inputs = [torch.randn(2, 2)]
diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index da84d75..a060c56 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -28,6 +28,7 @@
add_activation_layer,
extend_attr_to_tuple,
get_positive_dim,
+ trunc_div,
)
@@ -873,9 +874,21 @@
@tensorrt_converter(acc_ops.div)
def acc_ops_div(network, target, args, kwargs, name):
- return add_binary_elementwise_layer(
- network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.DIV, target, name
- )
+ if kwargs["rounding_mode"] == "trunc":
+ return trunc_div(kwargs["input"], kwargs["other"], network, target, name)
+ elif kwargs["rounding_mode"] == "floor":
+ return add_binary_elementwise_layer(
+ network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.FLOOR_DIV, target, name
+ )
+ elif kwargs["rounding_mode"] is None:
+ return add_binary_elementwise_layer(
+ network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.DIV, target, name
+ )
+ else:
+ mode = kwargs["rounding_mode"]
+ raise RuntimeError(f"Div received rounding_mode {mode} that is not supported!")
@tensorrt_converter(acc_ops.mul)
diff --git a/torch/fx/experimental/fx2trt/converters/converter_utils.py b/torch/fx/experimental/fx2trt/converters/converter_utils.py
index acdb9e4..0ad417b 100644
--- a/torch/fx/experimental/fx2trt/converters/converter_utils.py
+++ b/torch/fx/experimental/fx2trt/converters/converter_utils.py
@@ -487,3 +487,57 @@
else:
inputs.append(kwargs[key])
return inputs
+
+
+def trunc_div(
+ input: trt.tensorrt.ITensor,
+ other: trt.tensorrt.ITensor,
+ network: trt.INetworkDefinition,
+ target: Target,
+ name: str
+) -> trt.tensorrt.ITensor:
+ """
+ Perform trunc divide on Tensor, result of divide will be round toward zero.
+ This means for positive number, it will be floor round; for negative number,
+ it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3].
+ Args:
+ input: divisor.
+ other: dividend.
+ network: INetworkDefinition.
+ target: node target.
+ name: namespace for the op
+ Returns:
+ A TensorRT tensor represent the result of trunc divide.
+ """
+ prod_output = add_binary_elementwise_layer(network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod")
+
+ # get sign:
+ # x = input * other
+ # sign = (exp(x) // exp(abs(x))) * 2 - 1
+ # For positive numbers and 0, (exp(x) // exp(abs(x))) yields 1; for negative numbers it yields 0.
+ # Multiplying by 2 turns this into 2 (for pos and 0) and 0 (for neg).
+ # Finally, subtracting 1 gives 1 (for pos and 0) and -1 (for neg).
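+ # Worked example (assumed values): for x = -6.0, exp(-6) // exp(6) = 0 and 0 * 2 - 1 = -1;
+ # for x = 6.0, exp(6) // exp(6) = 1 and 1 * 2 - 1 = 1. So sign is -1 iff x is negative.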
+ prod_exp_output = add_unary_layer(network, prod_output, trt.UnaryOperation.EXP, target, f"{name}_prod_exp")
+ prod_abs_output = add_unary_layer(network, prod_output, trt.UnaryOperation.ABS, target, f"{name}_prod_abs")
+ prod_abs_exp_output = add_unary_layer(network, prod_abs_output, trt.UnaryOperation.EXP, target, f"{name}_prod_abs_exp")
+ floor_div_output = add_binary_elementwise_layer(network, prod_exp_output, prod_abs_exp_output,
+ trt.ElementWiseOperation.FLOOR_DIV, target, f"{name}_exp_floor_div")
+ double_floor_div_output = add_binary_elementwise_layer(network, floor_div_output, 2,
+ trt.ElementWiseOperation.PROD, target, f"{name}_double_floor_div")
+ sign_output = add_binary_elementwise_layer(network, double_floor_div_output, 1,
+ trt.ElementWiseOperation.SUB, target, f"{name}_binary_sign")
+
+ # Convert constant input into ITensor for UnaryOperation
+ if not isinstance(input, trt.tensorrt.ITensor):
+ input = get_trt_tensor(network, input, f"{name}_input")
+ if not isinstance(other, trt.tensorrt.ITensor):
+ other = get_trt_tensor(network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype))
+
+ abs_a_output = add_unary_layer(network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_a")
+ abs_b_output = add_unary_layer(network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_b")
+ abs_floor_output = add_binary_elementwise_layer(network, abs_a_output, abs_b_output,
+ trt.ElementWiseOperation.FLOOR_DIV, target, f"{name}_floor_div")
+ output = add_binary_elementwise_layer(network, abs_floor_output, sign_output,
+ trt.ElementWiseOperation.PROD, target, f"{name}_output")
+
+ return output
diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py
index 151c39d..1466272 100644
--- a/torch/fx/experimental/fx_acc/acc_ops.py
+++ b/torch/fx/experimental/fx_acc/acc_ops.py
@@ -617,11 +617,58 @@
return input * other
-@register_acc_op_properties(AccOpProperty.pointwise)
-@register_acc_op_mapping(op_and_target=("call_function", operator.truediv))
+# torch.floor_divide is announced to be deprecated; consider using torch.div() with the "trunc" or
+# "floor" rounding mode instead.
+# This implementation matches torch.floor_divide's current behavior: the division result is rounded
+# toward zero rather than toward -Inf, even for negative numbers.
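+# For example (values assumed): torch.floor_divide(torch.tensor(-7.0), torch.tensor(2.0)) gives -3.0
+# (truncation), while Python's -7.0 // 2.0 gives -4.0 (floor). This is why torch.floor_divide maps to
+# rounding_mode="trunc" and operator.floordiv maps to rounding_mode="floor" below.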
+@register_custom_acc_mapper_fn(
+ op_and_target=("call_function", torch.floor_divide),
+ arg_replacement_tuples=[
+ ("input", "input"),
+ ("other", "other"),
+ ],
+)
+@register_custom_acc_mapper_fn(
+ op_and_target=("call_function", operator.floordiv),
+ arg_replacement_tuples=[
+ ("input", "input"),
+ ("other", "other"),
+ ],
+)
+@register_custom_acc_mapper_fn(
+ op_and_target=("call_function", torch.div),
+ arg_replacement_tuples=[
+ ("input", "input"),
+ ("other", "other"),
+ ("rounding_mode", "rounding_mode", this_arg_is_optional),
+ ],
+)
+@register_custom_acc_mapper_fn(
+ op_and_target=("call_function", operator.truediv),
+ arg_replacement_tuples=[
+ ("input", "input"),
+ ("other", "other"),
+ ],
+)
+def div_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.Node:
+ with node.graph.inserting_before(node):
+ div_kwargs = dict(node.kwargs)
+ if "rounding_mode" not in div_kwargs and node.op == "call_function":
+ div_kwargs["rounding_mode"] = None
+ if node.target is torch.floor_divide:
+ div_kwargs["rounding_mode"] = "trunc"
+ elif node.target is operator.floordiv:
+ div_kwargs["rounding_mode"] = "floor"
+ elif node.target is operator.truediv:
+ div_kwargs["rounding_mode"] = None
+ div_node = node.graph.call_function(div, kwargs=div_kwargs)
+ div_node.meta = node.meta.copy()
+ return div_node
+
+
@register_acc_op
-def div(*, input, other):
- return input / other
+def div(*, input, other, rounding_mode=None):
+ return torch.div(input, other, rounding_mode=rounding_mode)
@register_acc_op_properties(AccOpProperty.pointwise)