[SR] Fuse clamp/nan_to_num
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77094
Fuse `clamp` and `nan_to_num` into a single NNC kernel. This leads to a big speedup on many models. The fused kernel can skip the infinity comparisons that `nan_to_num` would otherwise perform, since `clamp` potentially gets rid of all of the `inf`s in the input tensor.
Differential Revision: [D36220967](https://our.internmc.facebook.com/intern/diff/D36220967/)
**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36220967/)!
Approved by: https://github.com/navahgar
diff --git a/benchmarks/static_runtime/test_static_module.cc b/benchmarks/static_runtime/test_static_module.cc
index 1871a52..41758ec 100644
--- a/benchmarks/static_runtime/test_static_module.cc
+++ b/benchmarks/static_runtime/test_static_module.cc
@@ -1698,3 +1698,69 @@
EliminateNoOpSlice(graph);
EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice"));
}
+
+#ifdef FBCODE_CAFFE2
+// FuseClampNaNToNum pass is disabled externally to avoid MSVC errors in CI
+// A clamp with two constant bounds followed by nan_to_num must be collapsed
+// into a single static_runtime::clamp_nan_to_num node.
+TEST(FuseClampNaNToNum, FusionHappens) {
+  const char* const script = R"JIT(
+ def forward(self, x):
+ y = torch.clamp(x, min=0.0, max=1.0)
+ z = y.nan_to_num()
+ return z.clone()
+ )JIT";
+
+  torch::jit::Module module("m");
+  module.define(script);
+  auto fwd_graph = module.get_method("forward").graph();
+  FuseClampNaNToNum(fwd_graph);
+
+  // Both original ops are gone and only the fused op remains.
+  EXPECT_FALSE(hasNodeWithKind(fwd_graph, "aten::clamp"));
+  EXPECT_FALSE(hasNodeWithKind(fwd_graph, "aten::nan_to_num"));
+  EXPECT_TRUE(hasNodeWithKind(fwd_graph, "static_runtime::clamp_nan_to_num"));
+  // Numerical correctness of the fused op is covered by
+  // StaticRuntime.ClampNaNToNum.
+}
+
+// The fusion must be skipped whenever a clamp bound is not a constant or is
+// missing: in every script below the graph keeps its clamp and nan_to_num
+// nodes and gains no fused node.
+TEST(FuseClampNaNToNum, NoFusion) {
+  // Bounds are graph inputs rather than constants.
+  const char* const runtime_bounds = R"JIT(
+ def forward(self, x, a: float, b: float):
+ y = torch.clamp(x, a, b)
+ z = y.nan_to_num()
+ return z.clone()
+ )JIT";
+
+  // Only the lower bound is given.
+  const char* const min_only = R"JIT(
+ def forward(self, x):
+ y = torch.clamp(x, min=0.0)
+ z = y.nan_to_num()
+ return z.clone()
+ )JIT";
+
+  // Only the upper bound is given.
+  const char* const max_only = R"JIT(
+ def forward(self, x):
+ y = torch.clamp(x, max=0.0)
+ z = y.nan_to_num()
+ return z.clone()
+ )JIT";
+
+  // Neither bound is given.
+  const char* const no_bounds = R"JIT(
+ def forward(self, x):
+ y = torch.clamp(x)
+ z = y.nan_to_num()
+ return z.clone()
+ )JIT";
+
+  const char* const scripts[] = {
+      runtime_bounds, min_only, max_only, no_bounds};
+
+  for (const char* script : scripts) {
+    torch::jit::Module module("m");
+    module.define(script);
+    auto fwd_graph = module.get_method("forward").graph();
+    FuseClampNaNToNum(fwd_graph);
+    EXPECT_TRUE(hasNodeWithKind(fwd_graph, "aten::clamp"));
+    EXPECT_TRUE(hasNodeWithKind(fwd_graph, "aten::nan_to_num"));
+    EXPECT_FALSE(
+        hasNodeWithKind(fwd_graph, "static_runtime::clamp_nan_to_num"));
+  }
+}
+#endif
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 9a2d801..f6d4b0e 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -3336,3 +3336,44 @@
EXPECT_FALSE(hasNodeWithKind(graph, "quantized::linear_dynamic_fp16"));
EXPECT_TRUE(hasNodeWithKind(graph, "quantized::linear_relu_dynamic_fp16"));
}
+
+// Runs clamp(...).nan_to_num() scripts through testStaticRuntime on inputs
+// that contain NaN, +inf, -inf and ordinary values, for both float tensors
+// and double tensors.
+TEST(StaticRuntime, ClampNaNToNum) {
+  // Default NaN replacement.
+  const auto default_nan_script = R"JIT(
+ def forward(self, a):
+ return torch.clamp(a, min=1.0, max=2.0).nan_to_num().clone()
+ )JIT";
+
+  // NaN replacement value supplied as a runtime argument.
+  const auto custom_nan_script = R"JIT(
+ def forward(self, a, nan: float):
+ return torch.clamp(a, min=-1.0, max=2.0).nan_to_num(nan=nan).clone()
+ )JIT";
+
+  // Lower bound greater than the upper bound.
+  const auto inverted_bounds_script = R"JIT(
+ def forward(self, a):
+ return torch.clamp(a, min=1.0, max=-1.0).nan_to_num().clone()
+ )JIT";
+
+  auto special_values = at::tensor(
+      {std::numeric_limits<float>::quiet_NaN(),
+       std::numeric_limits<float>::infinity(),
+       -std::numeric_limits<float>::infinity(),
+       0.0f,
+       3.0f});
+  auto tiled_values = special_values.repeat({10, 5});
+  auto special_values_f64 = special_values.to(at::kDouble);
+  auto tiled_values_f64 = tiled_values.to(at::kDouble);
+
+  // use_allclose/use_equalnan are required even though every NaN in the
+  // output is replaced: testStaticRuntime also compares the inputs at the end
+  // to make sure they were not modified, and the inputs do contain NaN.
+  testStaticRuntime(
+      default_nan_script,
+      {special_values},
+      {},
+      /*use_allclose=*/true,
+      /*use_equalnan=*/true);
+  testStaticRuntime(
+      default_nan_script,
+      {special_values},
+      {tiled_values},
+      /*use_allclose=*/true,
+      /*use_equalnan=*/true);
+
+  testStaticRuntime(
+      custom_nan_script,
+      {special_values, 42.0},
+      {},
+      /*use_allclose=*/true,
+      /*use_equalnan=*/true);
+  testStaticRuntime(
+      custom_nan_script,
+      {special_values, 2.0},
+      {tiled_values, 1.0},
+      /*use_allclose=*/true,
+      /*use_equalnan=*/true);
+
+  testStaticRuntime(
+      inverted_bounds_script,
+      {special_values},
+      {},
+      /*use_allclose=*/true,
+      /*use_equalnan=*/true);
+  testStaticRuntime(
+      inverted_bounds_script,
+      {special_values},
+      {tiled_values},
+      /*use_allclose=*/true,
+      /*use_equalnan=*/true);
+
+  // Non-NNC path: double inputs
+  testStaticRuntime(
+      default_nan_script,
+      {special_values_f64},
+      {},
+      /*use_allclose=*/true,
+      /*use_equalnan=*/true);
+  testStaticRuntime(
+      default_nan_script,
+      {special_values_f64},
+      {tiled_values_f64},
+      /*use_allclose=*/true,
+      /*use_equalnan=*/true);
+}
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 3083639..b33082b 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -153,10 +153,11 @@
graph,
fromQualString("fb::sigrid_transforms_torch_bind"),
fromQualString("fb::variadic_sigrid_transforms_torch_bind"));
+ // These fused ops only have out variants - we can't do the fusion when
+ // out variants are disabled.
FuseSignLog1P(graph);
+ FuseClampNaNToNum(graph);
- // TODO: we can avoid this guard by moving operations
- // to exposed folders.
#ifdef FBCODE_CAFFE2
if (opts.use_copy_variants && !opts.enable_tensorexpr_fusion) {
ReplaceWithCopy(graph);
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index ce5545b..2430e9f 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -604,6 +604,65 @@
};
});
+#ifdef FBCODE_CAFFE2
+// Disable externally to avoid MSVC errors in open-source CI
+
+// Out-variant implementation of the fused clamp + nan_to_num op.
+//
+// Node inputs: (0) tensor, (1) clamp min, (2) clamp max, (3) nan,
+// (4) posinf, (5) neginf. The clamp bounds must be non-None constants; they
+// are read once when the operator is created.
+//
+// Float inputs with finite bounds run through the NNC kernel, which only
+// replaces NaNs. That is sufficient because clamping to finite bounds leaves
+// no infinities behind. Every other case falls back to ATen with the original
+// Scalar bounds and the caller's posinf/neginf, so the result matches the
+// unfused clamp -> nan_to_num sequence.
+REGISTER_OPERATOR_FUNCTOR(
+    static_runtime::clamp_nan_to_num,
+    static_runtime_clamp_nan_to_num,
+    [](Node* n) -> SROperator {
+      auto clamp_min_ival_opt = toIValue(n->input(1));
+      auto clamp_max_ival_opt = toIValue(n->input(2));
+      TORCH_CHECK(
+          clamp_min_ival_opt.has_value() && clamp_max_ival_opt.has_value());
+
+      auto clamp_min_opt = clamp_min_ival_opt->toOptional<at::Scalar>();
+      auto clamp_max_opt = clamp_max_ival_opt->toOptional<at::Scalar>();
+      TORCH_CHECK(clamp_min_opt.has_value() && clamp_max_opt.has_value());
+
+      // False for NaN and for +/-infinity.
+      const auto is_finite = [](float v) {
+        constexpr float inf = std::numeric_limits<float>::infinity();
+        return v == v && v != inf && v != -inf;
+      };
+
+      const float clamp_min_f = clamp_min_opt->to<float>();
+      const float clamp_max_f = clamp_max_opt->to<float>();
+      // The NNC kernel ignores posinf/neginf, so it is only valid when the
+      // clamp itself is guaranteed to remove every infinity.
+      const bool bounds_are_finite =
+          is_finite(clamp_min_f) && is_finite(clamp_max_f);
+
+      return [te = createClampNanToNum(),
+              clamp_min_scalar = *clamp_min_opt,
+              clamp_max_scalar = *clamp_max_opt,
+              clamp_min = clamp_min_f,
+              clamp_max = clamp_max_f,
+              bounds_are_finite](ProcessedNode* p_node) mutable {
+        const auto& in0_t = p_node->Input(0).toTensor();
+        if (p_node->Output(0).isNone()) {
+          p_node->Output(0) = create_empty_from(in0_t);
+        }
+        auto& out_t = p_node->Output(0).toTensor();
+        fastResizeToZero(out_t);
+        auto in3_s = p_node->Input(3).toOptional<double>();
+
+        if (!te || !bounds_are_finite || !te->checkInput<float>(in0_t)) {
+          // ATen fallback. Use the un-narrowed Scalar bounds so non-float
+          // inputs are clamped exactly as the unfused op would clamp them,
+          // and forward posinf/neginf in case infinities survive the clamp.
+          at::cpu::nan_to_num_out(
+              out_t,
+              at::cpu::clamp(in0_t, clamp_min_scalar, clamp_max_scalar),
+              in3_s,
+              p_node->Input(4).toOptional<double>(),
+              p_node->Input(5).toOptional<double>());
+          return;
+        }
+        at::native::resize_(out_t, in0_t.sizes(), c10::nullopt);
+
+        auto output_size = in0_t.numel();
+
+        // This might be UB if in3_s is absurdly large, but most
+        // implementations just turn it into `inf` in that case. The PyTorch
+        // core nan_to_num kernel just static_cast()s the limits to the
+        // destination type, so we'll ignore overflow issues here as well.
+        auto nan = in3_s.has_value() ? static_cast<float>(*in3_s) : 0.f;
+
+        // Argument order must match the buffer list passed to wrapTECompute
+        // in createClampNanToNum(), with the output buffer first.
+        te->call(
+            {out_t.data_ptr(),
+             in0_t.data_ptr(),
+             &clamp_min,
+             &clamp_max,
+             &nan,
+             &output_size});
+      };
+    });
+
+#endif
+
REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"))) {
@@ -623,9 +682,9 @@
at::native::resize_(out_t, in0_t.sizes(), c10::nullopt);
auto output_size = in0_t.numel();
auto min = in1_s.has_value() ? in1_s->toFloat()
- : std::numeric_limits<float>::lowest();
+ : -std::numeric_limits<float>::infinity();
auto max = in2_s.has_value() ? in2_s->toFloat()
- : std::numeric_limits<float>::max();
+ : std::numeric_limits<float>::infinity();
te->call({out_t.data_ptr(), in0_t.data_ptr(), &min, &max, &output_size});
};
}
diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp
index 38529e5..ca2449b 100644
--- a/torch/csrc/jit/runtime/static/passes.cpp
+++ b/torch/csrc/jit/runtime/static/passes.cpp
@@ -455,6 +455,9 @@
m.def(torch::schema(
"static_runtime::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor)",
c10::AliasAnalysisKind::PURE_FUNCTION));
+  // Fused aten::clamp + aten::nan_to_num. Arguments are clamp's
+  // (input, min, max) followed by nan_to_num's (nan, posinf, neginf).
+  m.def(torch::schema(
+      "static_runtime::clamp_nan_to_num(Tensor input, Scalar? min, Scalar? max, float? nan, float? posinf, float? neginf) -> Tensor",
+      c10::AliasAnalysisKind::PURE_FUNCTION));
}
void FuseSignLog1P(std::shared_ptr<torch::jit::Graph>& graph) {
@@ -1325,5 +1328,45 @@
fuse.runOnGraph(graph);
}
+// Rewrites an aten::clamp whose result feeds aten::nan_to_num into a single
+// static_runtime::clamp_nan_to_num node. A match is rewritten only when both
+// clamp bounds are constants that are not None. Outside of FBCODE_CAFFE2
+// builds this function does nothing.
+void FuseClampNaNToNum(std::shared_ptr<Graph>& graph) {
+#ifdef FBCODE_CAFFE2
+  const std::string clamp_then_nan_to_num = R"IR(
+ graph(%input, %clamp_min: Scalar?, %clamp_max: Scalar?, %nan, %posinf, %neginf):
+ %x : Tensor = aten::clamp(%input, %clamp_min, %clamp_max)
+ %y : Tensor = aten::nan_to_num(%x, %nan, %posinf, %neginf)
+ return (%y))IR";
+
+  const std::string fused_op = R"IR(
+ graph(%input, %clamp_min: Scalar?, %clamp_max: Scalar?, %nan, %posinf, %neginf):
+ %x : Tensor = static_runtime::clamp_nan_to_num(%input, %clamp_min, %clamp_max, %nan, %posinf, %neginf)
+ return (%x))IR";
+
+  // Match filter: accept a match only if the matched clamp node's min and max
+  // inputs both resolve to a concrete Scalar.
+  auto boundsAreConstScalars =
+      [](const Match& match,
+         const std::unordered_map<std::string, Value*>& vmap) {
+        auto holdsScalar = [](Value* v) {
+          auto ival = toIValue(v);
+          return ival.has_value() &&
+              ival->toOptional<at::Scalar>().has_value();
+        };
+        // vmap refers to values of the pattern graph; nodes_map translates
+        // a pattern node into the corresponding node of the real graph.
+        Node* pattern_clamp = vmap.at("x")->node();
+        Node* real_clamp = match.nodes_map.at(pattern_clamp);
+        return holdsScalar(real_clamp->input(1)) &&
+            holdsScalar(real_clamp->input(2));
+      };
+
+  SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(clamp_then_nan_to_num, fused_op);
+  rewriter.runOnGraph(graph, boundsAreConstScalars);
+#endif
+}
+
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h
index 7b79478..270c907 100644
--- a/torch/csrc/jit/runtime/static/passes.h
+++ b/torch/csrc/jit/runtime/static/passes.h
@@ -78,5 +78,7 @@
TORCH_API void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph);
+TORCH_API void FuseClampNaNToNum(std::shared_ptr<Graph>& graph); // Fuses aten::clamp + aten::nan_to_num into static_runtime::clamp_nan_to_num when both clamp bounds are non-None constants; no-op unless FBCODE_CAFFE2 is defined.
+
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/runtime/static/te_wrapper.cpp b/torch/csrc/jit/runtime/static/te_wrapper.cpp
index 96acf99..d982e65 100644
--- a/torch/csrc/jit/runtime/static/te_wrapper.cpp
+++ b/torch/csrc/jit/runtime/static/te_wrapper.cpp
@@ -209,6 +209,34 @@
return wrap;
}
+// Returns the NNC kernel for the fused clamp + nan_to_num op, building it and
+// storing it in the NNC cache on first use. For each element of the float
+// buffer A the kernel clamps the value using `min`/`max` and then substitutes
+// `nan_replace_val` if the clamped value is NaN. Kernel arguments are
+// registered in the order A, min, max, nan_replace_val, N.
+std::shared_ptr<TEWrapper> createClampNanToNum() {
+  static auto symbol =
+      c10::Symbol::fromQualString("static_runtime::clamp_nan_to_num");
+  if (auto cached = lookupNNCCache(symbol)) {
+    return cached;
+  }
+
+  VarHandle N("N", kInt);
+  VarHandle min_handle("min", kFloat);
+  VarHandle max_handle("max", kFloat);
+  VarHandle nan_replace_val("nan_replace_val", kFloat);
+
+  BufHandle A("A", {N}, kFloat);
+  Tensor result = Compute("aten_clamp", {N}, [&](const VarHandle& i) {
+    auto clamped = tensorexpr::clamp(min_handle, max_handle, A.load(i));
+    // Select nan_replace_val where isnan(clamped) == 1, else keep clamped.
+    return tensorexpr::CompareSelect::make(
+        tensorexpr::isnan(clamped), 1, nan_replace_val, clamped, kEQ);
+  });
+
+  auto wrap = wrapTECompute(
+      std::make_shared<TEWrapper>(),
+      result,
+      {A, min_handle, max_handle, nan_replace_val, N});
+  updateNNCCache(symbol, wrap);
+  return wrap;
+}
+
std::shared_ptr<TEWrapper> createSignedLog1p() {
static auto signed_log1p_symbol =
c10::Symbol::fromQualString("static_runtime::signed_log1p");
diff --git a/torch/csrc/jit/runtime/static/te_wrapper.h b/torch/csrc/jit/runtime/static/te_wrapper.h
index fcb642f..a9f2a55 100644
--- a/torch/csrc/jit/runtime/static/te_wrapper.h
+++ b/torch/csrc/jit/runtime/static/te_wrapper.h
@@ -39,6 +39,7 @@
std::shared_ptr<TEWrapper> createSigmoid();
std::shared_ptr<TEWrapper> createSignedLog1p();
std::shared_ptr<TEWrapper> createClamp();
+std::shared_ptr<TEWrapper> createClampNanToNum(); // NNC kernel for static_runtime::clamp_nan_to_num: clamps a float buffer, then replaces NaNs.
} // namespace jit
} // namespace torch