[SR] Fuse clamp/nan_to_num

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77094

Fuse `clamp` and `nan_to_num` into a single NNC kernel. This gives a large speedup on many models. Since the fusion only fires when `clamp` has constant, non-`None` `min` and `max` bounds, the clamp removes every `inf` from the input tensor, so the fused kernel can skip the `posinf`/`neginf` comparisons and only needs to replace `NaN`s.
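
For intuition, here is a minimal scalar sketch of what the fused op computes when both clamp bounds are constant. The `clamp_nan_to_num_ref` helper below is purely illustrative; the real kernel is the NNC expression built in `createClampNanToNum`:

```cpp
#include <algorithm>
#include <cmath>

// Scalar sketch of static_runtime::clamp_nan_to_num with constant min/max.
// Clamping maps +inf -> max and -inf -> min, so no infinities survive it,
// while NaN propagates through the comparisons and is replaced afterwards.
float clamp_nan_to_num_ref(float x, float min, float max, float nan_replacement) {
  // min > max yields max, matching aten::clamp's behavior
  float clamped = std::min(std::max(x, min), max);
  return std::isnan(clamped) ? nan_replacement : clamped;
}
```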

Differential Revision: [D36220967](https://our.internmc.facebook.com/intern/diff/D36220967/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36220967/)!

Approved by: https://github.com/navahgar
diff --git a/benchmarks/static_runtime/test_static_module.cc b/benchmarks/static_runtime/test_static_module.cc
index 1871a52..41758ec 100644
--- a/benchmarks/static_runtime/test_static_module.cc
+++ b/benchmarks/static_runtime/test_static_module.cc
@@ -1698,3 +1698,69 @@
   EliminateNoOpSlice(graph);
   EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice"));
 }
+
+#ifdef FBCODE_CAFFE2
+// FuseClampNaNToNum pass is disabled externally to avoid MSVC errors in CI
+TEST(FuseClampNaNToNum, FusionHappens) {
+  const auto src = R"JIT(
+    def forward(self, x):
+        y = torch.clamp(x, min=0.0, max=1.0)
+        z = y.nan_to_num()
+        return z.clone()
+  )JIT";
+  torch::jit::Module mod("m");
+  mod.define(src);
+  auto graph = mod.get_method("forward").graph();
+  FuseClampNaNToNum(graph);
+  EXPECT_FALSE(hasNodeWithKind(graph, "aten::clamp"));
+  EXPECT_FALSE(hasNodeWithKind(graph, "aten::nan_to_num"));
+  EXPECT_TRUE(hasNodeWithKind(graph, "static_runtime::clamp_nan_to_num"));
+  // Correctness of the op is exercised in StaticRuntime.ClampNaNToNum
+}
+
+TEST(FuseClampNaNToNum, NoFusion) {
+  const auto src1 = R"JIT(
+    def forward(self, x, a: float, b: float):
+        y = torch.clamp(x, a, b)
+        z = y.nan_to_num()
+        return z.clone()
+  )JIT";
+
+  const auto src2 = R"JIT(
+    def forward(self, x):
+        y = torch.clamp(x, min=0.0)
+        z = y.nan_to_num()
+        return z.clone()
+  )JIT";
+
+  const auto src3 = R"JIT(
+    def forward(self, x):
+        y = torch.clamp(x, max=0.0)
+        z = y.nan_to_num()
+        return z.clone()
+  )JIT";
+
+  const auto src4 = R"JIT(
+    def forward(self, x):
+        y = torch.clamp(x)
+        z = y.nan_to_num()
+        return z.clone()
+  )JIT";
+
+
+  auto checkScript = [](const char* src) {
+    torch::jit::Module mod("m");
+    mod.define(src);
+    auto graph = mod.get_method("forward").graph();
+    FuseClampNaNToNum(graph);
+    EXPECT_TRUE(hasNodeWithKind(graph, "aten::clamp"));
+    EXPECT_TRUE(hasNodeWithKind(graph, "aten::nan_to_num"));
+    EXPECT_FALSE(hasNodeWithKind(graph, "static_runtime::clamp_nan_to_num"));
+  };
+
+  checkScript(src1);
+  checkScript(src2);
+  checkScript(src3);
+  checkScript(src4);
+}
+#endif
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 9a2d801..f6d4b0e 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -3336,3 +3336,44 @@
   EXPECT_FALSE(hasNodeWithKind(graph, "quantized::linear_dynamic_fp16"));
   EXPECT_TRUE(hasNodeWithKind(graph, "quantized::linear_relu_dynamic_fp16"));
 }
+
+TEST(StaticRuntime, ClampNaNToNum) {
+  const auto src1 = R"JIT(
+    def forward(self, a):
+        return torch.clamp(a, min=1.0, max=2.0).nan_to_num().clone()
+  )JIT";
+
+  const auto src2 = R"JIT(
+    def forward(self, a, nan: float):
+        return torch.clamp(a, min=-1.0, max=2.0).nan_to_num(nan=nan).clone()
+  )JIT";
+
+  const auto src3 = R"JIT(
+    def forward(self, a):
+        return torch.clamp(a, min=1.0, max=-1.0).nan_to_num().clone()
+  )JIT";
+
+  auto a = at::tensor({
+      std::numeric_limits<float>::quiet_NaN(),
+      std::numeric_limits<float>::infinity(),
+      -std::numeric_limits<float>::infinity(),
+      0.0f,
+      3.0f
+    });
+  auto b = a.repeat({10, 5});
+
+  // Have to use use_allclose/use_equalnan even though all NaNs in the output are replaced -
+  // testStaticRuntime also checks the inputs at the end to make sure they're not changed
+  testStaticRuntime(src1, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
+  testStaticRuntime(src1, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true);
+
+  testStaticRuntime(src2, {a, 42.0}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
+  testStaticRuntime(src2, {a, 2.0}, {b, 1.0}, /*use_allclose=*/true, /*use_equalnan=*/true);
+
+  testStaticRuntime(src3, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
+  testStaticRuntime(src3, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true);
+
+  // Non-NNC path
+  testStaticRuntime(src1, {a.to(at::kDouble)}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
+  testStaticRuntime(src1, {a.to(at::kDouble)}, {b.to(at::kDouble)}, /*use_allclose=*/true, /*use_equalnan=*/true);
+}
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 3083639..b33082b 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -153,10 +153,11 @@
         graph,
         fromQualString("fb::sigrid_transforms_torch_bind"),
         fromQualString("fb::variadic_sigrid_transforms_torch_bind"));
+    // These fused ops only have out variants - we can't do the fusion when
+    // out variants are disabled.
     FuseSignLog1P(graph);
+    FuseClampNaNToNum(graph);
 
-    // TODO: we can avoid this guard by moving operations
-    // to exposed folders.
 #ifdef FBCODE_CAFFE2
     if (opts.use_copy_variants && !opts.enable_tensorexpr_fusion) {
       ReplaceWithCopy(graph);
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index ce5545b..2430e9f 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -604,6 +604,65 @@
   };
 });
 
+#ifdef FBCODE_CAFFE2
+// Disabled in the open-source build to avoid MSVC errors in CI
+
+REGISTER_OPERATOR_FUNCTOR(
+    static_runtime::clamp_nan_to_num,
+    static_runtime_clamp_nan_to_num,
+    [](Node* n) -> SROperator {
+      auto clamp_min_ival_opt = toIValue(n->input(1));
+      auto clamp_max_ival_opt = toIValue(n->input(2));
+      TORCH_CHECK(
+          clamp_min_ival_opt.has_value() && clamp_max_ival_opt.has_value());
+
+      auto clamp_min_opt = clamp_min_ival_opt->toOptional<at::Scalar>();
+      auto clamp_max_opt = clamp_max_ival_opt->toOptional<at::Scalar>();
+      TORCH_CHECK(clamp_min_opt.has_value() && clamp_max_opt.has_value());
+
+      return [te = createClampNanToNum(),
+              clamp_min = clamp_min_opt->to<float>(),
+              clamp_max =
+                  clamp_max_opt->to<float>()](ProcessedNode* p_node) mutable {
+        const auto& in0_t = p_node->Input(0).toTensor();
+        if (p_node->Output(0).isNone()) {
+          p_node->Output(0) = create_empty_from(in0_t);
+        }
+        auto& out_t = p_node->Output(0).toTensor();
+        fastResizeToZero(out_t);
+        auto in3_s = p_node->Input(3).toOptional<double>();
+
+        if (!te || !te->checkInput<float>(in0_t)) {
+          at::cpu::nan_to_num_out(
+              out_t,
+              at::cpu::clamp(in0_t, clamp_min, clamp_max),
+              in3_s,
+              c10::nullopt,
+              c10::nullopt);
+          return;
+        }
+        at::native::resize_(out_t, in0_t.sizes(), c10::nullopt);
+
+        auto output_size = in0_t.numel();
+
+        // This might be UB if in3_s is absurdly large, but most implementations
+        // just turn it into `inf` in that case. The PyTorch core nan_to_num
+        // kernel just static_cast()s the limits to the destination type, so
+        // we'll ignore overflow issues here as well.
+        auto nan = in3_s.has_value() ? static_cast<float>(*in3_s) : 0.f;
+
+        te->call(
+            {out_t.data_ptr(),
+             in0_t.data_ptr(),
+             &clamp_min,
+             &clamp_max,
+             &nan,
+             &output_size});
+      };
+    });
+
+#endif
+
 REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
   if (n->matches(torch::schema(
           "aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"))) {
@@ -623,9 +682,9 @@
       at::native::resize_(out_t, in0_t.sizes(), c10::nullopt);
       auto output_size = in0_t.numel();
       auto min = in1_s.has_value() ? in1_s->toFloat()
-                                   : std::numeric_limits<float>::lowest();
+                                   : -std::numeric_limits<float>::infinity();
       auto max = in2_s.has_value() ? in2_s->toFloat()
-                                   : std::numeric_limits<float>::max();
+                                   : std::numeric_limits<float>::infinity();
       te->call({out_t.data_ptr(), in0_t.data_ptr(), &min, &max, &output_size});
     };
   }
diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp
index 38529e5..ca2449b 100644
--- a/torch/csrc/jit/runtime/static/passes.cpp
+++ b/torch/csrc/jit/runtime/static/passes.cpp
@@ -455,6 +455,9 @@
   m.def(torch::schema(
       "static_runtime::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor)",
       c10::AliasAnalysisKind::PURE_FUNCTION));
+  m.def(torch::schema(
+      "static_runtime::clamp_nan_to_num(Tensor input, Scalar? min, Scalar? max, float? nan, float? posinf, float? posinf) -> Tensor",
+      c10::AliasAnalysisKind::PURE_FUNCTION));
 }
 
 void FuseSignLog1P(std::shared_ptr<torch::jit::Graph>& graph) {
@@ -1325,5 +1328,45 @@
   fuse.runOnGraph(graph);
 }
 
+void FuseClampNaNToNum(std::shared_ptr<Graph>& graph) {
+#ifdef FBCODE_CAFFE2
+  std::string pattern = R"IR(
+    graph(%input, %clamp_min: Scalar?, %clamp_max: Scalar?, %nan, %posinf, %neginf):
+        %x : Tensor = aten::clamp(%input, %clamp_min, %clamp_max)
+        %y : Tensor = aten::nan_to_num(%x, %nan, %posinf, %neginf)
+        return (%y))IR";
+
+  std::string fused_pattern = R"IR(
+    graph(%input, %clamp_min: Scalar?, %clamp_max: Scalar?, %nan, %posinf, %neginf):
+        %x : Tensor = static_runtime::clamp_nan_to_num(%input, %clamp_min, %clamp_max, %nan, %posinf, %neginf)
+        return (%x))IR";
+
+  auto isConstantAndNotNone = [](Value* value) {
+    auto ival_opt = toIValue(value);
+    if (!ival_opt.has_value()) {
+      return false;
+    }
+    auto scalar_opt = ival_opt->toOptional<at::Scalar>();
+    return scalar_opt.has_value();
+  };
+
+  auto clampValuesAreConstant =
+      [&isConstantAndNotNone](
+          const Match& match,
+          const std::unordered_map<std::string, Value*>& vmap) {
+        // Get the nodes in the real graph from the nodes in the template
+        // pattern graph
+        const auto& node_map = match.nodes_map;
+        auto* clamp_node = node_map.at(vmap.at("x")->node());
+        return isConstantAndNotNone(clamp_node->input(1)) &&
+            isConstantAndNotNone(clamp_node->input(2));
+      };
+
+  SubgraphRewriter fuse;
+  fuse.RegisterRewritePattern(pattern, fused_pattern);
+  fuse.runOnGraph(graph, clampValuesAreConstant);
+#endif
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h
index 7b79478..270c907 100644
--- a/torch/csrc/jit/runtime/static/passes.h
+++ b/torch/csrc/jit/runtime/static/passes.h
@@ -78,5 +78,7 @@
 
 TORCH_API void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph);
 
+TORCH_API void FuseClampNaNToNum(std::shared_ptr<Graph>& graph);
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/runtime/static/te_wrapper.cpp b/torch/csrc/jit/runtime/static/te_wrapper.cpp
index 96acf99..d982e65 100644
--- a/torch/csrc/jit/runtime/static/te_wrapper.cpp
+++ b/torch/csrc/jit/runtime/static/te_wrapper.cpp
@@ -209,6 +209,34 @@
   return wrap;
 }
 
+std::shared_ptr<TEWrapper> createClampNanToNum() {
+  static auto symbol =
+      c10::Symbol::fromQualString("static_runtime::clamp_nan_to_num");
+  auto wrap = lookupNNCCache(symbol);
+  if (wrap) {
+    return wrap;
+  }
+  wrap = std::make_shared<TEWrapper>();
+  auto N = VarHandle("N", kInt);
+  auto min_handle = VarHandle("min", kFloat);
+  auto max_handle = VarHandle("max", kFloat);
+  auto nan_replace_val = VarHandle("nan_replace_val", kFloat);
+
+  BufHandle A("A", {N}, kFloat);
+  Tensor result = Compute("aten_clamp", {N}, [&](const VarHandle& i) {
+    auto a = A.load(i);
+    auto clamp = tensorexpr::clamp(min_handle, max_handle, a);
+    auto is_nan = tensorexpr::isnan(clamp);
+    auto nans_replaced =
+        tensorexpr::CompareSelect::make(is_nan, 1, nan_replace_val, clamp, kEQ);
+    return nans_replaced;
+  });
+  wrap = wrapTECompute(
+      wrap, result, {A, min_handle, max_handle, nan_replace_val, N});
+  updateNNCCache(symbol, wrap);
+  return wrap;
+}
+
 std::shared_ptr<TEWrapper> createSignedLog1p() {
   static auto signed_log1p_symbol =
       c10::Symbol::fromQualString("static_runtime::signed_log1p");
diff --git a/torch/csrc/jit/runtime/static/te_wrapper.h b/torch/csrc/jit/runtime/static/te_wrapper.h
index fcb642f..a9f2a55 100644
--- a/torch/csrc/jit/runtime/static/te_wrapper.h
+++ b/torch/csrc/jit/runtime/static/te_wrapper.h
@@ -39,6 +39,7 @@
 std::shared_ptr<TEWrapper> createSigmoid();
 std::shared_ptr<TEWrapper> createSignedLog1p();
 std::shared_ptr<TEWrapper> createClamp();
+std::shared_ptr<TEWrapper> createClampNanToNum();
 
 } // namespace jit
 } // namespace torch