Update dper3 to use torch.nan_to_num and nan_to_num_ (#46873)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46873
OSS:
Add op benchmark for torch.nan_to_num and torch.nan_to_num_
Test Plan:
OSS:
`buck run mode/opt caffe2/benchmarks/operator_benchmark/pt:nan_to_num_test`
Reviewed By: qizzzh, houseroad
Differential Revision: D24521835
fbshipit-source-id: 1fd50a99e5329ffec2d470525ce6976d39424958
diff --git a/benchmarks/operator_benchmark/pt/nan_to_num_test.py b/benchmarks/operator_benchmark/pt/nan_to_num_test.py
new file mode 100644
index 0000000..72f5daf
--- /dev/null
+++ b/benchmarks/operator_benchmark/pt/nan_to_num_test.py
@@ -0,0 +1,59 @@
+import operator_benchmark as op_bench
+import torch
+import math
+
+
+"""Microbenchmarks for torch.nan_to_num / nan_to_num_ operators"""
+
+# Configs for PT torch.nan_to_num / nan_to_num_ operators
+nan_to_num_long_configs = op_bench.cross_product_configs(
+ M=[32, 64, 128],
+ N=range(32, 128, 32),
+ dtype=[torch.float, torch.double],
+ op=["nan_to_num", "nan_to_num_"],
+ replace_inf=[True, False],
+ tags=["long"],
+)
+
+
+nan_to_num_short_configs = op_bench.cross_product_configs(
+ M=[16, 64],
+ N=[64, 64],
+ dtype=[torch.float, torch.double],
+ op=["nan_to_num", "nan_to_num_"],
+ replace_inf=[True, False],
+ tags=["short"],
+)
+
+
+class ReplaceNaNBenchmark(op_bench.TorchBenchmarkBase):
+ def init(self, M, N, dtype, op, replace_inf):
+ self.input = torch.randn(M, N, dtype=dtype)
+ self.input[0][0] = float("nan")
+ self.op = op
+ self.replace_inf = replace_inf
+ self.set_module_name("nan_to_num")
+
+ def forward(self):
+ # compare inplace
+ if self.op == "nan_to_num":
+ if self.replace_inf:
+ output = torch.nan_to_num(self.input, nan=1.0)
+ else:
+ output = torch.nan_to_num(self.input, nan=1.0, posinf=math.inf, neginf=-math.inf)
+ else:
+ if self.replace_inf:
+ output = torch.nan_to_num_(self.input, nan=1.0)
+ else:
+ output = torch.nan_to_num_(self.input, nan=1.0, posinf=math.inf, neginf=-math.inf)
+ return output
+
+
+op_bench.generate_pt_test(
+ nan_to_num_long_configs + nan_to_num_short_configs,
+ ReplaceNaNBenchmark,
+)
+
+
+if __name__ == "__main__":
+ op_bench.benchmark_runner.main()