C++ API parity: PoissonNLLLoss
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28755
Test Plan: Imported from OSS
Differential Revision: D18202436
Pulled By: pbelevich
fbshipit-source-id: a7a27d5f3cdbcbbd9bbbffa02b576609d5fdc9b3
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index a3e19b2..ade6e0e 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -1654,3 +1654,20 @@
}
}
}
+
+TEST_F(FunctionalTest, PoissonNLLLoss) {
+ const auto input = torch::tensor({0.5, 1.5, 2.5});
+ const auto target = torch::tensor({1., 2., 3.});
+ const auto component_wise_loss = torch::exp(input) - target * input;
+ ASSERT_TRUE(torch::allclose(torch::mean(component_wise_loss),
+ F::poisson_nll_loss(input, target)));
+ ASSERT_TRUE(torch::allclose(component_wise_loss,
+ F::poisson_nll_loss(input, target,
+ PoissonNLLLossOptions().reduction(torch::kNone))));
+ ASSERT_TRUE(torch::allclose(torch::sum(component_wise_loss),
+ F::poisson_nll_loss(input, target,
+ PoissonNLLLossOptions().reduction(torch::kSum))));
+ ASSERT_TRUE(torch::allclose(torch::mean(component_wise_loss),
+ F::poisson_nll_loss(input, target,
+ PoissonNLLLossOptions().reduction(torch::kMean))));
+}
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index b656b9d..885c91b 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -2032,6 +2032,33 @@
-log_probs.sum(0).slice(1, 0, 1).view_as(output), output));
}
+TEST_F(ModulesTest, PoissonNLLLoss) {
+ const auto input = torch::tensor({0.5, 1.5, 2.5});
+ const auto target = torch::tensor({1., 2., 3.});
+ const auto component_wise_loss = torch::exp(input) - target * input;
+ {
+ PoissonNLLLoss loss {PoissonNLLLossOptions().reduction(torch::kNone)};
+ ASSERT_TRUE(torch::allclose(
+ component_wise_loss,
+ loss->forward(input, target)
+ ));
+ }
+ {
+ PoissonNLLLoss loss {PoissonNLLLossOptions().reduction(torch::kSum)};
+ ASSERT_TRUE(torch::allclose(
+ torch::sum(component_wise_loss),
+ loss->forward(input, target)
+ ));
+ }
+ {
+ PoissonNLLLoss loss {PoissonNLLLossOptions().reduction(torch::kMean)};
+ ASSERT_TRUE(torch::allclose(
+ torch::mean(component_wise_loss),
+ loss->forward(input, target)
+ ));
+ }
+}
+
TEST_F(ModulesTest, PrettyPrintIdentity) {
ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
}
@@ -2882,3 +2909,11 @@
CTCLossOptions().blank(42).zero_infinity(false)
.reduction(torch::kSum))), "torch::nn::CTCLoss()");
}
+
+TEST_F(ModulesTest, PrettyPrintPoissonNLLLoss) {
+ ASSERT_EQ(c10::str(PoissonNLLLoss()), "torch::nn::PoissonNLLLoss()");
+ ASSERT_EQ(c10::str(PoissonNLLLoss(
+ PoissonNLLLossOptions().log_input(false).full(true).eps(0.42)
+ .reduction(torch::kSum))),
+ "torch::nn::PoissonNLLLoss()");
+}
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 0636874..49c7b16 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -107,7 +107,7 @@
torch.nn.CrossEntropyLoss|No|No
torch.nn.CTCLoss|Yes|No
torch.nn.NLLLoss|No|No
-torch.nn.PoissonNLLLoss|No|No
+torch.nn.PoissonNLLLoss|Yes|No
torch.nn.KLDivLoss|Yes|No
torch.nn.BCELoss|Yes|No
torch.nn.BCEWithLogitsLoss|No|No
diff --git a/torch/csrc/api/include/torch/nn/functional/loss.h b/torch/csrc/api/include/torch/nn/functional/loss.h
index d39efde..c001026 100644
--- a/torch/csrc/api/include/torch/nn/functional/loss.h
+++ b/torch/csrc/api/include/torch/nn/functional/loss.h
@@ -245,6 +245,13 @@
options.zero_infinity());
}
+inline Tensor poisson_nll_loss(const Tensor& input, const Tensor& target,
+ const PoissonNLLLossOptions& options = {}) {
+ return torch::poisson_nll_loss(input, target, options.log_input(),
+ options.full(), options.eps(),
+ enumtype::reduction_get_enum(options.reduction()));
+}
+
} // namespace functional
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules/loss.h b/torch/csrc/api/include/torch/nn/modules/loss.h
index ba8e9fe..dd7ef62 100644
--- a/torch/csrc/api/include/torch/nn/modules/loss.h
+++ b/torch/csrc/api/include/torch/nn/modules/loss.h
@@ -371,5 +371,27 @@
/// PyTorch's module storage semantics.
TORCH_MODULE(CTCLoss);
+// ============================================================================
+
+struct TORCH_API PoissonNLLLossImpl : public Cloneable<PoissonNLLLossImpl> {
+ explicit PoissonNLLLossImpl(const PoissonNLLLossOptions& options_ = {});
+
+ void reset() override;
+
+ /// Pretty prints the `PoissonNLLLoss` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
+ Tensor forward(const Tensor& log_input, const Tensor& targets);
+
+ /// The options with which this `Module` was constructed.
+ PoissonNLLLossOptions options;
+};
+
+/// A `ModuleHolder` subclass for `PoissonNLLLossImpl`.
+/// See the documentation for `PoissonNLLLoss` class to learn what
+/// methods it provides, or the documentation for `ModuleHolder` to learn about
+/// PyTorch's module storage semantics.
+TORCH_MODULE(PoissonNLLLoss);
+
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/options/loss.h b/torch/csrc/api/include/torch/nn/options/loss.h
index 2b9a5a6..6989808 100644
--- a/torch/csrc/api/include/torch/nn/options/loss.h
+++ b/torch/csrc/api/include/torch/nn/options/loss.h
@@ -204,5 +204,24 @@
TORCH_ARG(torch::Reduction::Reduction, reduction);
};
+// ============================================================================
+
+/// Options for PoissonNLLLoss functional and module.
+struct TORCH_API PoissonNLLLossOptions {
+ typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum> reduction_t;
+
+ /// if true the loss is computed as `exp(input) - target * input`,
+ /// if false the loss is `input - target * log(input + eps)`.
+ TORCH_ARG(bool, log_input) = true;
+ /// whether to compute full loss, i.e. to add the Stirling approximation term
+ /// target * log(target) - target + 0.5 * log(2 * pi * target).
+ TORCH_ARG(bool, full) = false;
+ /// Small value to avoid evaluation of `log(0)` when `log_input = false`.
+ /// Default: 1e-8
+ TORCH_ARG(double, eps) = 1e-8;
+ /// Specifies the reduction to apply to the output. Default: Mean
+ TORCH_ARG(reduction_t, reduction) = torch::kMean;
+};
+
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/loss.cpp b/torch/csrc/api/src/nn/modules/loss.cpp
index e459216..c98e4c5 100644
--- a/torch/csrc/api/src/nn/modules/loss.cpp
+++ b/torch/csrc/api/src/nn/modules/loss.cpp
@@ -231,5 +231,21 @@
return F::ctc_loss(log_probs, targets, input_lengths, target_lengths, options);
}
+// ============================================================================
+
+PoissonNLLLossImpl::PoissonNLLLossImpl(const PoissonNLLLossOptions& options_)
+ : options(options_) {}
+
+void PoissonNLLLossImpl::reset() {}
+
+void PoissonNLLLossImpl::pretty_print(std::ostream& stream) const {
+ stream << "torch::nn::PoissonNLLLoss()";
+}
+
+Tensor PoissonNLLLossImpl::forward(
+ const Tensor& log_input, const Tensor& target) {
+ return F::poisson_nll_loss(log_input, target, options);
+}
+
} // namespace nn
} // namespace torch