[C++ API] AdaptiveLogSoftmaxWithLoss (#29076)

Summary:
Implemented the AdaptiveLogSoftmaxWithLoss module in the C++ API, along with tests in the module test suite. Reference: https://github.com/pytorch/pytorch/issues/25883
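
A rough usage sketch of the new module (option names and methods as added in this PR; the concrete sizes and the {20, 60} cutoffs are only an illustrative choice):

    #include <torch/torch.h>
    using namespace torch::nn;

    // 100 classes: the 20 most frequent form the head's shortlist,
    // the rest fall into two tail clusters, [20, 60) and [60, 100).
    AdaptiveLogSoftmaxWithLoss asfm(
        AdaptiveLogSoftmaxWithLossOptions(/*in_features=*/64, /*n_classes=*/100, /*cutoffs=*/{20, 60})
            .div_value(4.)
            .head_bias(false));

    auto input = torch::randn({128, 64});
    auto target = torch::randint(0, 100, {128}, torch::kLong);

    auto result = asfm(input, target);       // result.output: per-sample target log-probs, result.loss: mean NLL
    auto log_probs = asfm->log_prob(input);  // full [128, 100] log-probability matrix
    auto predictions = asfm->predict(input); // predicted class indices, shape [128]
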
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29076

Differential Revision: D20404588

Pulled By: yf225

fbshipit-source-id: edbadf432b8173cbcc6caf83c9c03dd92dc31a37
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 76ab3eb..2798755 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -595,6 +595,7 @@
       ${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/_functions.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/activation.cpp
+      ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/adaptive.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/normalization.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/instancenorm.cpp
@@ -613,6 +614,7 @@
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/functional.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/named_any.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/activation.cpp
+      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/adaptive.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/embedding.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/instancenorm.cpp
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 6fcf841..f5d0999 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -2407,6 +2407,50 @@
   }
 }
 
+TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
+  {
+    // log_prob returns genuine log-probabilities: exponentiating each row should sum to 1
+    AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
+    auto x = torch::randn({4, 8});
+    auto logprob_out = asfm->log_prob(x);
+    ASSERT_TRUE(torch::allclose(torch::exp(logprob_out).data().sum(1), torch::ones(4)));
+  }
+  {
+    // test predict
+    AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
+    auto x = torch::randn({64, 8});
+    auto logprob_out = asfm->log_prob(x);
+    auto predict_out = asfm->predict(x);
+    ASSERT_TRUE(torch::allclose(predict_out, logprob_out.argmax(1)));
+  }
+  {
+    // cluster sizes
+    AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
+    auto x = torch::arange(100, 132, torch::kFloat).reshape({2, 16});
+    auto y = torch::tensor({0, 17}, torch::kLong);
+    auto asm_out = asfm(x, y);
+    ASSERT_EQ(asm_out.output.sizes(), std::vector<int64_t>({2}));
+  }
+  {
+    // forward's output and loss are consistent with log_prob combined with NLLLoss
+    AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
+    auto x = torch::randn({4, 8});
+    auto logprob_out = asfm->log_prob(x);
+    NLLLoss nll_loss;
+
+    for (int64_t v = 0; v < 4; ++v) {
+      auto y = torch::full({4}, v, torch::kLong);
+      auto asm_out = asfm(x, y);
+      auto out = asm_out.output;
+      auto loss = torch::tensor(asm_out.loss, torch::kFloat);
+      auto expected = nll_loss->forward(logprob_out, y);
+
+      ASSERT_TRUE(torch::allclose(loss, expected));
+      ASSERT_TRUE(torch::allclose(out, logprob_out.gather(1, y.unsqueeze(1)).squeeze()));
+    }
+  }
+}
+
 TEST_F(ModulesTest, Softmax2d) {
   Softmax2d m;
   auto input = torch::arange(24, torch::kFloat).reshape({1, 2, 3, 4});
@@ -4545,3 +4589,38 @@
   ASSERT_EQ(c10::str(MultiheadAttention(MultiheadAttentionOptions(20, 10).bias(false))),
     "torch::nn::MultiheadAttention(\n  (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=false)\n)");
 }
+
+TEST_F(ModulesTest, PrettyPrintAdaptiveLogSoftmaxWithLoss) {
+  {
+    AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
+    ASSERT_EQ(
+      c10::str(asfm),
+      "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
+      "  (head): torch::nn::Linear(in_features=8, out_features=3, bias=false)\n"
+      "  (tail): torch::nn::ModuleList(\n"
+      "    (0): torch::nn::Sequential(\n"
+      "      (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
+      "      (1): torch::nn::Linear(in_features=4, out_features=2, bias=false)\n"
+      "    )\n"
+      "  )\n"
+      ")");
+  }
+  {
+    AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
+    ASSERT_EQ(
+      c10::str(asfm),
+      "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
+      "  (head): torch::nn::Linear(in_features=8, out_features=6, bias=true)\n"
+      "  (tail): torch::nn::ModuleList(\n"
+      "    (0): torch::nn::Sequential(\n"
+      "      (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
+      "      (1): torch::nn::Linear(in_features=4, out_features=4, bias=false)\n"
+      "    )\n"
+      "    (1): torch::nn::Sequential(\n"
+      "      (0): torch::nn::Linear(in_features=8, out_features=2, bias=false)\n"
+      "      (1): torch::nn::Linear(in_features=2, out_features=2, bias=false)\n"
+      "    )\n"
+      "  )\n"
+      ")");
+  }
+}
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 57fb344..f49a5bd 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -69,7 +69,7 @@
 torch::nn::Softmax|Yes|No
 torch::nn::Softmax2d|Yes|No
 torch::nn::LogSoftmax|Yes|No
-torch::nn::AdaptiveLogSoftmaxWithLoss|No|No
+torch::nn::AdaptiveLogSoftmaxWithLoss|Yes|No
 torch::nn::BatchNorm1d|Yes|No
 torch::nn::BatchNorm2d|Yes|No
 torch::nn::BatchNorm3d|Yes|No
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index b65a21d..4000a7e 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -237,6 +237,7 @@
     "torch/csrc/api/src/nn/module.cpp",
     "torch/csrc/api/src/nn/modules/_functions.cpp",
     "torch/csrc/api/src/nn/modules/activation.cpp",
+    "torch/csrc/api/src/nn/modules/adaptive.cpp",
     "torch/csrc/api/src/nn/modules/batchnorm.cpp",
     "torch/csrc/api/src/nn/modules/normalization.cpp",
     "torch/csrc/api/src/nn/modules/instancenorm.cpp",
@@ -255,6 +256,7 @@
     "torch/csrc/api/src/nn/modules/container/functional.cpp",
     "torch/csrc/api/src/nn/modules/container/named_any.cpp",
     "torch/csrc/api/src/nn/options/activation.cpp",
+    "torch/csrc/api/src/nn/options/adaptive.cpp",
     "torch/csrc/api/src/nn/options/batchnorm.cpp",
     "torch/csrc/api/src/nn/options/conv.cpp",
     "torch/csrc/api/src/nn/options/dropout.cpp",
diff --git a/torch/csrc/api/include/torch/nn/modules.h b/torch/csrc/api/include/torch/nn/modules.h
index 9c60486..63c0659 100644
--- a/torch/csrc/api/include/torch/nn/modules.h
+++ b/torch/csrc/api/include/torch/nn/modules.h
@@ -11,6 +11,7 @@
 #include <torch/nn/modules/container/sequential.h>
 
 // Layers
+#include <torch/nn/modules/adaptive.h>
 #include <torch/nn/modules/batchnorm.h>
 #include <torch/nn/modules/instancenorm.h>
 #include <torch/nn/modules/conv.h>
diff --git a/torch/csrc/api/include/torch/nn/modules/adaptive.h b/torch/csrc/api/include/torch/nn/modules/adaptive.h
new file mode 100644
index 0000000..9722fbe
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/modules/adaptive.h
@@ -0,0 +1,94 @@
+#pragma once
+
+#include <torch/nn/cloneable.h>
+#include <torch/nn/module.h>
+#include <torch/nn/modules/linear.h>
+#include <torch/nn/modules/container/modulelist.h>
+#include <torch/nn/modules/container/sequential.h>
+#include <torch/nn/functional/activation.h> 
+#include <torch/nn/options/adaptive.h>
+
+namespace torch {
+namespace nn {
+
+/// The output of a single invocation of an AdaptiveLogSoftmaxWithLoss
+/// module's `forward()` method.
+struct TORCH_API ASMoutput {
+  ASMoutput(Tensor output_, double loss_);
+
+  /// Tensor containing computed target log probabilities for each example
+  Tensor output;
+
+  /// Scalar representing the computed negative log likelihood loss
+  double loss;
+};
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveLogSoftmaxWithLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Efficient softmax approximation as described in
+/// `Efficient softmax approximation for GPUs`_ by Edouard Grave, Armand Joulin,
+/// Moustapha Cissé, David Grangier, and Hervé Jégou.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::AdaptiveLogSoftmaxWithLossOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
+/// ```
+class TORCH_API AdaptiveLogSoftmaxWithLossImpl : public Cloneable<AdaptiveLogSoftmaxWithLossImpl> {
+ public:
+  AdaptiveLogSoftmaxWithLossImpl(int64_t in_features, int64_t n_classes, std::vector<int64_t> cutoffs)
+      : AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions(in_features, n_classes, cutoffs)) {}
+
+  explicit AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_);
+
+  ASMoutput forward(const Tensor& input, const Tensor& target);
+
+  void reset() override;
+
+  void reset_parameters();
+
+  /// Pretty prints the `AdaptiveLogSoftmaxWithLoss` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
+  /// Given input tensor, and output of `head`, computes the log of the full distribution
+  Tensor _get_full_log_prob(const Tensor& input, const Tensor& head_output);
+
+  /// Computes log probabilities for all n_classes
+  Tensor log_prob(const Tensor& input);
+
+  /// This is equivalent to `log_prob(input).argmax(1)` but is more efficient in some cases
+  Tensor predict(const Tensor& input);
+
+  /// The options with which this `Module` was constructed
+  AdaptiveLogSoftmaxWithLossOptions options;
+
+  /// Cutoffs used to assign targets to their buckets. An ordered sequence of unique,
+  /// positive integers sorted in increasing order, with `n_classes` appended as the last element
+  std::vector<int64_t> cutoffs;
+
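+  /// Number of classes handled directly by the `head` classifier (i.e. `cutoffs[0]`)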
+  int64_t shortlist_size;
+  
+  /// Number of clusters
+  int64_t n_clusters;
+
+  /// Output size of head classifier
+  int64_t head_size;
+
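+  /// Head classifier producing logits for the shortlist classes plus one logit per tail cluster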
+  Linear head = nullptr;
+
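+  /// One two-layer projection (down-projection followed by a classifier) per tail cluster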
+  ModuleList tail;
+};
+
+/// A `ModuleHolder` subclass for `AdaptiveLogSoftmaxWithLossImpl`.
+/// See the documentation for `AdaptiveLogSoftmaxWithLossImpl` class to learn what methods it
+/// provides, and examples of how to use `AdaptiveLogSoftmaxWithLoss` with `torch::nn::AdaptiveLogSoftmaxWithLossOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
+TORCH_MODULE(AdaptiveLogSoftmaxWithLoss);
+
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules/container/sequential.h b/torch/csrc/api/include/torch/nn/modules/container/sequential.h
index 4252ea1..b5c014d 100644
--- a/torch/csrc/api/include/torch/nn/modules/container/sequential.h
+++ b/torch/csrc/api/include/torch/nn/modules/container/sequential.h
@@ -4,6 +4,7 @@
 #include <torch/nn/cloneable.h>
 #include <torch/nn/module.h>
 #include <torch/nn/modules/container/any.h>
+#include <torch/nn/modules/container/named_any.h>
 #include <torch/nn/pimpl.h>
 #include <torch/types.h>
 
diff --git a/torch/csrc/api/include/torch/nn/options/adaptive.h b/torch/csrc/api/include/torch/nn/options/adaptive.h
new file mode 100644
index 0000000..7f73dbb
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/options/adaptive.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <torch/arg.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/types.h>
+
+namespace torch {
+namespace nn {
+
+/// Options for the `AdaptiveLogSoftmaxWithLoss` module.
+///
+/// Example:
+/// ```
+/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
+/// ```
+struct TORCH_API AdaptiveLogSoftmaxWithLossOptions {
+  /* implicit */ AdaptiveLogSoftmaxWithLossOptions(int64_t in_features, int64_t n_classes, std::vector<int64_t> cutoffs);
+
+  /// Number of features in the input tensor
+  TORCH_ARG(int64_t, in_features);
+
+  /// Number of classes in the dataset
+  TORCH_ARG(int64_t, n_classes);
+
+  /// Cutoffs used to assign targets to their buckets
+  TORCH_ARG(std::vector<int64_t>, cutoffs);
+
+  /// Value used as an exponent to compute sizes of the clusters. Default: 4.0
+  TORCH_ARG(double, div_value) = 4.;
+
+  /// If ``true``, adds a bias term to the 'head' of
+  /// the adaptive softmax. Default: false
+  TORCH_ARG(bool, head_bias) = false;
+};
+
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/adaptive.cpp b/torch/csrc/api/src/nn/modules/adaptive.cpp
new file mode 100644
index 0000000..62e5d00
--- /dev/null
+++ b/torch/csrc/api/src/nn/modules/adaptive.cpp
@@ -0,0 +1,166 @@
+#include <torch/nn/modules/adaptive.h>
+#include <torch/nn/options/activation.h>
+#include <torch/nn/options/linear.h>
+
+#include <algorithm>
+#include <cmath>
+#include <set>
+
+namespace F = torch::nn::functional;
+
+using namespace torch::indexing;
+
+namespace torch {
+namespace nn {
+
+ASMoutput::ASMoutput(Tensor output_, double loss_): output(std::move(output_)), loss(loss_) {}
+
+AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_)
+    : options(std::move(options_)),
+      shortlist_size(0),
+      n_clusters(0),
+      head_size(0) {
+  reset();
+}
+
+void AdaptiveLogSoftmaxWithLossImpl::reset() {
+  TORCH_CHECK(std::is_sorted(options.cutoffs().begin(), options.cutoffs().end()) &&
+          *std::min_element(options.cutoffs().begin(), options.cutoffs().end()) > 0 &&
+          *std::max_element(options.cutoffs().begin(), options.cutoffs().end()) <= (options.n_classes() - 1) &&
+          std::set<int64_t>(options.cutoffs().begin(), options.cutoffs().end()).size() == options.cutoffs().size(),
+          "cutoffs should be a sequence of unique, positive integers sorted in increasing order, ",
+          "where each value is between 1 and n_classes-1");
+
+  cutoffs = options.cutoffs();
+  cutoffs.push_back(options.n_classes());
+
+  shortlist_size = cutoffs[0];
+  n_clusters = cutoffs.size() - 1;
+  head_size = shortlist_size + n_clusters;
+
+  head = this->register_module("head", Linear(LinearOptions(options.in_features(), head_size).bias(options.head_bias())));
+  tail = this->register_module("tail", ModuleList());
+
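+  // Each tail cluster i classifies the classes in [cutoffs[i], cutoffs[i + 1]) through
+  // a bottleneck of in_features / div_value^(i + 1) hidden units.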
+  for (int64_t i = 0; i < n_clusters; i++) {
+    int64_t hsz = options.in_features() / static_cast<int64_t>(std::pow(options.div_value(), (i + 1)));
+    int64_t osz = cutoffs[i + 1] - cutoffs[i];
+
+    Sequential projection(
+        Linear(LinearOptions(options.in_features(), hsz).bias(false)),
+        Linear(LinearOptions(hsz, osz).bias(false)));
+    tail->push_back(projection);
+  }
+}
+
+void AdaptiveLogSoftmaxWithLossImpl::reset_parameters() {
+  head->reset_parameters();
+  for (size_t i = 0; i < tail->size(); ++i) {
+    auto i2h = tail[i]->children()[0]->as<Linear>();
+    auto h2o = tail[i]->children()[1]->as<Linear>();
+    i2h->reset_parameters();
+    h2o->reset_parameters();
+  }
+}
+
+ASMoutput AdaptiveLogSoftmaxWithLossImpl::forward(const Tensor& input, const Tensor& target) {
+  TORCH_CHECK(input.size(0) == target.size(0),
+      "Input and target should have the same size in the batch dimension.");
+
+  int64_t used_rows = 0;
+  const int64_t batch_size = target.size(0);
+
+  Tensor output = input.new_zeros(batch_size);
+  Tensor gather_inds = target.new_empty(batch_size);
+
+  auto cutoff_values = cutoffs;
+  cutoff_values.insert(cutoff_values.begin(), 0);
+
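+  // Process targets bucket by bucket: bucket 0 is the head's shortlist; every later
+  // bucket is routed through its tail cluster to get a within-cluster log-probability.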
+  for (size_t i = 0; i < cutoff_values.size() - 1; ++i) {
+    int64_t low_idx = cutoff_values[i];
+    int64_t high_idx = cutoff_values[i + 1];
+
+    const Tensor target_mask = (target >= low_idx) * (target < high_idx);
+    const Tensor row_indices = target_mask.nonzero().squeeze();
+
+    if (row_indices.numel() == 0) {
+      continue;
+    }
+
+    if (i == 0) {
+      gather_inds.index_copy_(0, row_indices, target.index({target_mask}));
+    } else {
+      Tensor relative_target = target.index({target_mask}) - low_idx;
+      Tensor input_subset = input.index_select(0, row_indices);
+
+      const Tensor cluster_output = tail[i - 1]->as<Sequential>()->forward(input_subset);
+      int64_t cluster_index = shortlist_size + i - 1;
+
+      gather_inds.index_fill_(0, row_indices, cluster_index);
+
+      const Tensor cluster_logprob = F::log_softmax(cluster_output, 1);
+      const Tensor local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1));
+      output.index_copy_(0, row_indices, local_logprob.squeeze(1));
+    }
+
+    used_rows += row_indices.numel();
+  }
+
+  TORCH_CHECK(
+    used_rows == batch_size,
+    "Target values should be in [0, ", options.n_classes() - 1, "], "
+    "but values in range [", target.min().item().toDouble(), ", ", target.max().item().toDouble(), "] "
+    "were found. ");
+
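+  // Add the head term: the gathered head log-prob is the target's own log-prob for
+  // shortlist targets, or the log-prob of the target's cluster for tail targets.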
+  const Tensor head_output = head(input);
+  const Tensor head_logprob = F::log_softmax(head_output, 1);
+  output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze();
+  const double loss = (-output).mean().item().toDouble();
+
+  return ASMoutput(output, loss);
+}
+
+Tensor AdaptiveLogSoftmaxWithLossImpl::_get_full_log_prob(const Tensor& input, const Tensor& head_output) {
+  Tensor out = input.new_empty({head_output.size(0), options.n_classes()});
+  const Tensor head_logprob = F::log_softmax(head_output, 1);
+
+  out.index_put_({Slice(), Slice(None, shortlist_size)}, head_logprob.index({Slice(), Slice(None, shortlist_size)}));
+
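+  // Tail classes combine the head's log-prob for their cluster with the
+  // within-cluster log-prob from the corresponding tail projection.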
+  for (size_t i = 0; i < cutoffs.size() - 1; ++i) {
+    int64_t start_idx = cutoffs[i];
+    int64_t stop_idx = cutoffs[i+1];
+    const Tensor cluster_output = tail[i]->as<Sequential>()->forward(input);
+    const Tensor cluster_logprob = F::log_softmax(cluster_output, 1);
+    auto output_logprob = cluster_logprob + head_logprob.index({Slice(), static_cast<int64_t>(shortlist_size + i)}).unsqueeze(1);
+
+    out.index_put_({Slice(), Slice(start_idx, stop_idx)}, output_logprob);
+  }
+  return out;
+}
+
+Tensor AdaptiveLogSoftmaxWithLossImpl::log_prob(const Tensor& input) {
+  const Tensor head_output = head(input);
+  return _get_full_log_prob(input, head_output);
+}
+
+Tensor AdaptiveLogSoftmaxWithLossImpl::predict(const Tensor& input) {
+  const Tensor head_output = head(input);
+  Tensor output = torch::argmax(head_output, 1);
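+  // Samples whose head argmax lands on a cluster slot (index >= shortlist_size)
+  // need the full distribution to resolve the actual class.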
+  auto not_in_shortlist = (output >= shortlist_size);
+  auto all_in_shortlist = bitwise_not(not_in_shortlist.any());
+
+  if (all_in_shortlist.item().toBool()) {
+    return output;
+  } else if (not_in_shortlist.all().item().toBool()) {
+    const Tensor log_prob = _get_full_log_prob(input, head_output);
+    return torch::argmax(log_prob, 1);
+  } else {
+    const Tensor log_prob = _get_full_log_prob(
+      input.index({not_in_shortlist}),
+      head_output.index({not_in_shortlist}));
+    output.index_put_({not_in_shortlist}, torch::argmax(log_prob, 1));
+    return output;
+  }
+}
+
+void AdaptiveLogSoftmaxWithLossImpl::pretty_print(std::ostream& stream) const {
+  stream << "torch::nn::AdaptiveLogSoftmaxWithLoss"; 
+}
+
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/src/nn/options/adaptive.cpp b/torch/csrc/api/src/nn/options/adaptive.cpp
new file mode 100644
index 0000000..bff2e42
--- /dev/null
+++ b/torch/csrc/api/src/nn/options/adaptive.cpp
@@ -0,0 +1,11 @@
+#include <torch/nn/options/adaptive.h>
+
+namespace torch {
+namespace nn {
+
+AdaptiveLogSoftmaxWithLossOptions::AdaptiveLogSoftmaxWithLossOptions(
+    int64_t in_features, int64_t n_classes, std::vector<int64_t> cutoffs)
+  : in_features_(in_features), n_classes_(n_classes), cutoffs_(std::move(cutoffs)) {}
+
+} // namespace nn
+} // namespace torch