Add nan_to_num plugin (#72144)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72144
As title
Test Plan: buck test mode/opt //caffe2/torch/fb/fx2trt:test_nan2num
Reviewed By: yinghai
Differential Revision: D33851283
fbshipit-source-id: 582029c57d2189d727e55d1cbd099dc02b787b75
(cherry picked from commit ee940681761d7f42f21e1e35914b1e0825eee8e9)
diff --git a/test/fx_acc/test_acc_tracer.py b/test/fx_acc/test_acc_tracer.py
index 27a4ed1..b9c5d01 100644
--- a/test/fx_acc/test_acc_tracer.py
+++ b/test/fx_acc/test_acc_tracer.py
@@ -2095,5 +2095,6 @@
acc_ops.chunk,
acc_ops.rescale_quantize_per_tensor,
acc_ops.rescale_quantize_per_channel,
+ acc_ops.nan_to_num,
},
)
diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py
index 858877d..2dbf3b5 100644
--- a/torch/fx/experimental/fx_acc/acc_ops.py
+++ b/torch/fx/experimental/fx_acc/acc_ops.py
@@ -1511,6 +1511,13 @@
return input[idx]
+@register_acc_op_mapping(op_and_target=("call_function", torch.nan_to_num))
+@register_acc_op_mapping(op_and_target=("call_method", "nan_to_num"))
+@register_acc_op
+def nan_to_num(*, input, nan=0.0, posinf=None, neginf=None):
+ return torch.nan_to_num(input, nan=nan, posinf=posinf, neginf=neginf)
+
+
@register_acc_op_properties(AccOpProperty.unary)
@register_acc_op
def slice_tensor(*, input, dim, start, stop, step):