C++ API parity: PoissonNLLLoss

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28755

Test Plan: Imported from OSS

Differential Revision: D18202436

Pulled By: pbelevich

fbshipit-source-id: a7a27d5f3cdbcbbd9bbbffa02b576609d5fdc9b3
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index a3e19b2..ade6e0e 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -1654,3 +1654,20 @@
     }
   }
 }
+
+TEST_F(FunctionalTest, PoissonNLLLoss) {
+  const auto input = torch::tensor({0.5, 1.5, 2.5});
+  const auto target = torch::tensor({1., 2., 3.});
+  const auto component_wise_loss = torch::exp(input) - target * input;
+  ASSERT_TRUE(torch::allclose(torch::mean(component_wise_loss),
+    F::poisson_nll_loss(input, target)));
+  ASSERT_TRUE(torch::allclose(component_wise_loss,
+    F::poisson_nll_loss(input, target,
+    PoissonNLLLossOptions().reduction(torch::kNone))));
+  ASSERT_TRUE(torch::allclose(torch::sum(component_wise_loss),
+    F::poisson_nll_loss(input, target,
+    PoissonNLLLossOptions().reduction(torch::kSum))));
+  ASSERT_TRUE(torch::allclose(torch::mean(component_wise_loss),
+    F::poisson_nll_loss(input, target,
+    PoissonNLLLossOptions().reduction(torch::kMean))));
+}
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index b656b9d..885c91b 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -2032,6 +2032,33 @@
     -log_probs.sum(0).slice(1, 0, 1).view_as(output), output));
 }
 
+TEST_F(ModulesTest, PoissonNLLLoss) {
+  const auto input = torch::tensor({0.5, 1.5, 2.5});
+  const auto target = torch::tensor({1., 2., 3.});
+  const auto component_wise_loss = torch::exp(input) - target * input;
+  {
+    PoissonNLLLoss loss {PoissonNLLLossOptions().reduction(torch::kNone)};
+    ASSERT_TRUE(torch::allclose(
+      component_wise_loss,
+      loss->forward(input, target)
+    ));
+  }
+  {
+    PoissonNLLLoss loss {PoissonNLLLossOptions().reduction(torch::kSum)};
+    ASSERT_TRUE(torch::allclose(
+      torch::sum(component_wise_loss),
+      loss->forward(input, target)
+    ));
+  }
+  {
+    PoissonNLLLoss loss {PoissonNLLLossOptions().reduction(torch::kMean)};
+    ASSERT_TRUE(torch::allclose(
+      torch::mean(component_wise_loss),
+      loss->forward(input, target)
+    ));
+  }
+}
+
 TEST_F(ModulesTest, PrettyPrintIdentity) {
   ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
 }
@@ -2882,3 +2909,11 @@
     CTCLossOptions().blank(42).zero_infinity(false)
       .reduction(torch::kSum))), "torch::nn::CTCLoss()");
 }
+
+TEST_F(ModulesTest, PrettyPrintPoissonNLLLoss) {
+  ASSERT_EQ(c10::str(PoissonNLLLoss()), "torch::nn::PoissonNLLLoss()");
+  ASSERT_EQ(c10::str(PoissonNLLLoss(
+    PoissonNLLLossOptions().log_input(false).full(true).eps(0.42)
+    .reduction(torch::kSum))),
+    "torch::nn::PoissonNLLLoss()");
+}
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 0636874..49c7b16 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -107,7 +107,7 @@
 torch.nn.CrossEntropyLoss|No|No
 torch.nn.CTCLoss|Yes|No
 torch.nn.NLLLoss|No|No
-torch.nn.PoissonNLLLoss|No|No
+torch.nn.PoissonNLLLoss|Yes|No
 torch.nn.KLDivLoss|Yes|No
 torch.nn.BCELoss|Yes|No
 torch.nn.BCEWithLogitsLoss|No|No
diff --git a/torch/csrc/api/include/torch/nn/functional/loss.h b/torch/csrc/api/include/torch/nn/functional/loss.h
index d39efde..c001026 100644
--- a/torch/csrc/api/include/torch/nn/functional/loss.h
+++ b/torch/csrc/api/include/torch/nn/functional/loss.h
@@ -245,6 +245,13 @@
     options.zero_infinity());
 }
 
+inline Tensor poisson_nll_loss(const Tensor& input, const Tensor& target,
+                               const PoissonNLLLossOptions& options = {}) {
+  return torch::poisson_nll_loss(input, target, options.log_input(),
+    options.full(), options.eps(),
+    enumtype::reduction_get_enum(options.reduction()));
+}
+
 } // namespace functional
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules/loss.h b/torch/csrc/api/include/torch/nn/modules/loss.h
index ba8e9fe..dd7ef62 100644
--- a/torch/csrc/api/include/torch/nn/modules/loss.h
+++ b/torch/csrc/api/include/torch/nn/modules/loss.h
@@ -371,5 +371,27 @@
 /// PyTorch's module storage semantics.
 TORCH_MODULE(CTCLoss);
 
+// ============================================================================
+
+struct TORCH_API PoissonNLLLossImpl : public Cloneable<PoissonNLLLossImpl> {
+  explicit PoissonNLLLossImpl(const PoissonNLLLossOptions& options_ = {});
+
+  void reset() override;
+
+  /// Pretty prints the `PoissonNLLLoss` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
+  Tensor forward(const Tensor& log_input, const Tensor& targets);
+
+  /// The options with which this `Module` was constructed.
+  PoissonNLLLossOptions options;
+};
+
+/// A `ModuleHolder` subclass for `PoissonNLLLossImpl`.
+/// See the documentation for `PoissonNLLLoss` class to learn what
+/// methods it provides, or the documentation for `ModuleHolder` to learn about
+/// PyTorch's module storage semantics.
+TORCH_MODULE(PoissonNLLLoss);
+
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/options/loss.h b/torch/csrc/api/include/torch/nn/options/loss.h
index 2b9a5a6..6989808 100644
--- a/torch/csrc/api/include/torch/nn/options/loss.h
+++ b/torch/csrc/api/include/torch/nn/options/loss.h
@@ -204,5 +204,24 @@
   TORCH_ARG(torch::Reduction::Reduction, reduction);
 };
 
+// ============================================================================
+
+/// Options for PoissonNLLLoss functional and module.
+struct TORCH_API PoissonNLLLossOptions {
+  typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum> reduction_t;
+
+  /// if true the loss is computed as `exp(input) - target * input`,
+  /// if false the loss is `input - target * log(input + eps)`.
+  TORCH_ARG(bool, log_input) = true;
+  /// whether to compute full loss, i.e. to add the Stirling approximation term
+  /// target * log(target) - target + 0.5 * log(2 * pi * target).
+  TORCH_ARG(bool, full) = false;
+  /// Small value to avoid evaluation of `log(0)` when `log_input = false`.
+  /// Default: 1e-8
+  TORCH_ARG(double, eps) = 1e-8;
+  /// Specifies the reduction to apply to the output. Default: Mean
+  TORCH_ARG(reduction_t, reduction) = torch::kMean;
+};
+
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/loss.cpp b/torch/csrc/api/src/nn/modules/loss.cpp
index e459216..c98e4c5 100644
--- a/torch/csrc/api/src/nn/modules/loss.cpp
+++ b/torch/csrc/api/src/nn/modules/loss.cpp
@@ -231,5 +231,21 @@
   return F::ctc_loss(log_probs, targets, input_lengths, target_lengths, options);
 }
 
+// ============================================================================
+
+PoissonNLLLossImpl::PoissonNLLLossImpl(const PoissonNLLLossOptions& options_)
+  : options(options_) {}
+
+void PoissonNLLLossImpl::reset() {}
+
+void PoissonNLLLossImpl::pretty_print(std::ostream& stream) const {
+  stream << "torch::nn::PoissonNLLLoss()";
+}
+
+Tensor PoissonNLLLossImpl::forward(
+  const Tensor& log_input, const Tensor& target) {
+  return F::poisson_nll_loss(log_input, target, options);
+}
+
 } // namespace nn
 } // namespace torch