[TorchArrow][efficiency][3/n] variadic versions of op fused /unfused inference_wrapper_run_flat (#81133)
Summary:
Added `variadic` version (just an optimization) of the registered fused and unfused ops.
Reviewed By: tenpercent
Differential Revision: D37456033
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81133
Approved by: https://github.com/tenpercent, https://github.com/qxy11
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 602142f..38ce4de 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -167,6 +167,10 @@
graph,
fromQualString("fb::sigrid_transforms_torch_bind"),
fromQualString("fb::variadic_sigrid_transforms_torch_bind"));
+ UseVariadicOp(
+ graph,
+ fromQualString("torcharrow::inference_wrapper_run_flat"),
+ fromQualString("torcharrow::variadic_inference_wrapper_run_flat"));
// These fused ops only have out variants - we can't do the fusion when
// out variants are disabled.
FuseSignLog1P(graph);
diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp
index 1d818b2..05e3760 100644
--- a/torch/csrc/jit/runtime/static/passes.cpp
+++ b/torch/csrc/jit/runtime/static/passes.cpp
@@ -872,6 +872,9 @@
OP_PAIR(
"torcharrow::inference_wrapper_run_flat",
"static_runtime::fused_inference_wrapper_run_flat"),
+ OP_PAIR(
+ "torcharrow::variadic_inference_wrapper_run_flat",
+ "static_runtime::fused_variadic_inference_wrapper_run_flat"),
OP_PAIR("fb::equally_split", "static_runtime::fused_equally_split"),
OP_PAIR(
"fb::sigrid_transforms", "static_runtime::fused_sigrid_transforms"),