[C++ API] AdaptiveLogSoftmaxWithLoss (#29076)
Summary:
Implemented AdaptiveLogSoftmaxWithLoss in the C++ API, along with module tests. Reference: https://github.com/pytorch/pytorch/issues/25883
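
For reference, a minimal usage sketch of the new module (the option values are borrowed from the added tests; the `input`/`target` tensors here are purely illustrative):
```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  using namespace torch::nn;

  // 8 input features, 10 classes, cutoffs {4, 8} (values taken from the tests below).
  AdaptiveLogSoftmaxWithLoss asfm(
      AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));

  auto input = torch::randn({64, 8});
  auto target = torch::randint(0, 10, {64}, torch::kLong);

  // forward() returns an ASMoutput: per-example target log-probabilities
  // plus the scalar negative log-likelihood loss.
  auto out = asfm->forward(input, target);
  std::cout << "loss: " << out.loss << std::endl;

  // log_prob() gives log-probabilities over all classes;
  // predict() is equivalent to log_prob(input).argmax(1) but can be faster.
  auto log_probs = asfm->log_prob(input);
  auto predictions = asfm->predict(input);
}
```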
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29076
Differential Revision: D20404588
Pulled By: yf225
fbshipit-source-id: edbadf432b8173cbcc6caf83c9c03dd92dc31a37
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 76ab3eb..2798755 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -595,6 +595,7 @@
${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/_functions.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/activation.cpp
+ ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/adaptive.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/normalization.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/instancenorm.cpp
@@ -613,6 +614,7 @@
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/functional.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/named_any.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/activation.cpp
+ ${TORCH_SRC_DIR}/csrc/api/src/nn/options/adaptive.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/embedding.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/instancenorm.cpp
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 6fcf841..f5d0999 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -2407,6 +2407,50 @@
}
}
+TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
+ {
+ // log_prob returns properly normalized log probabilities
+ AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
+ auto x = torch::randn({4, 8});
+ auto logprob_out = asfm->log_prob(x);
+ ASSERT_TRUE(torch::allclose(torch::exp(logprob_out).data().sum(1), torch::ones(4)));
+ }
+ {
+ // test predict
+ AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
+ auto x = torch::randn({64, 8});
+ auto logprob_out = asfm->log_prob(x);
+ auto predict_out = asfm->predict(x);
+ ASSERT_TRUE(torch::allclose(predict_out, logprob_out.argmax(1)));
+ }
+ {
+ // cluster sizes
+ AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
+ auto x = torch::arange(100, 132, torch::kFloat).reshape({2, 16});
+ auto y = torch::tensor({0, 17}, torch::kLong);
+ auto asm_out = asfm(x, y);
+ ASSERT_EQ(asm_out.output.sizes(), std::vector<int64_t>({2}));
+ }
+ {
+ // forward output is consistent with log_prob
+ AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
+ auto x = torch::randn({4, 8});
+ auto logprob_out = asfm->log_prob(x);
+ NLLLoss nll_loss;
+
+ for (int64_t v = 0; v < 4; ++v) {
+ auto y = torch::full({4}, v, torch::kLong);
+ auto asm_out = asfm(x, y);
+ auto out = asm_out.output;
+ auto loss = torch::tensor(asm_out.loss, torch::kFloat);
+ auto expected = nll_loss->forward(logprob_out, y);
+
+ ASSERT_TRUE(torch::allclose(loss, expected));
+ ASSERT_TRUE(torch::allclose(out, logprob_out.gather(1, y.unsqueeze(1)).squeeze()));
+ }
+ }
+}
+
TEST_F(ModulesTest, Softmax2d) {
Softmax2d m;
auto input = torch::arange(24, torch::kFloat).reshape({1, 2, 3, 4});
@@ -4545,3 +4589,38 @@
ASSERT_EQ(c10::str(MultiheadAttention(MultiheadAttentionOptions(20, 10).bias(false))),
"torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=false)\n)");
}
+
+TEST_F(ModulesTest, PrettyPrintAdaptiveLogSoftmaxWithLoss) {
+ {
+ AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
+ ASSERT_EQ(
+ c10::str(asfm),
+ "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
+ " (head): torch::nn::Linear(in_features=8, out_features=3, bias=false)\n"
+ " (tail): torch::nn::ModuleList(\n"
+ " (0): torch::nn::Sequential(\n"
+ " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
+ " (1): torch::nn::Linear(in_features=4, out_features=2, bias=false)\n"
+ " )\n"
+ " )\n"
+ ")");
+ }
+ {
+ AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
+ ASSERT_EQ(
+ c10::str(asfm),
+ "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
+ " (head): torch::nn::Linear(in_features=8, out_features=6, bias=true)\n"
+ " (tail): torch::nn::ModuleList(\n"
+ " (0): torch::nn::Sequential(\n"
+ " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
+ " (1): torch::nn::Linear(in_features=4, out_features=4, bias=false)\n"
+ " )\n"
+ " (1): torch::nn::Sequential(\n"
+ " (0): torch::nn::Linear(in_features=8, out_features=2, bias=false)\n"
+ " (1): torch::nn::Linear(in_features=2, out_features=2, bias=false)\n"
+ " )\n"
+ " )\n"
+ ")");
+ }
+}
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 57fb344..f49a5bd 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -69,7 +69,7 @@
torch::nn::Softmax|Yes|No
torch::nn::Softmax2d|Yes|No
torch::nn::LogSoftmax|Yes|No
-torch::nn::AdaptiveLogSoftmaxWithLoss|No|No
+torch::nn::AdaptiveLogSoftmaxWithLoss|Yes|No
torch::nn::BatchNorm1d|Yes|No
torch::nn::BatchNorm2d|Yes|No
torch::nn::BatchNorm3d|Yes|No
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index b65a21d..4000a7e 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -237,6 +237,7 @@
"torch/csrc/api/src/nn/module.cpp",
"torch/csrc/api/src/nn/modules/_functions.cpp",
"torch/csrc/api/src/nn/modules/activation.cpp",
+ "torch/csrc/api/src/nn/modules/adaptive.cpp",
"torch/csrc/api/src/nn/modules/batchnorm.cpp",
"torch/csrc/api/src/nn/modules/normalization.cpp",
"torch/csrc/api/src/nn/modules/instancenorm.cpp",
@@ -255,6 +256,7 @@
"torch/csrc/api/src/nn/modules/container/functional.cpp",
"torch/csrc/api/src/nn/modules/container/named_any.cpp",
"torch/csrc/api/src/nn/options/activation.cpp",
+ "torch/csrc/api/src/nn/options/adaptive.cpp",
"torch/csrc/api/src/nn/options/batchnorm.cpp",
"torch/csrc/api/src/nn/options/conv.cpp",
"torch/csrc/api/src/nn/options/dropout.cpp",
diff --git a/torch/csrc/api/include/torch/nn/modules.h b/torch/csrc/api/include/torch/nn/modules.h
index 9c60486..63c0659 100644
--- a/torch/csrc/api/include/torch/nn/modules.h
+++ b/torch/csrc/api/include/torch/nn/modules.h
@@ -11,6 +11,7 @@
#include <torch/nn/modules/container/sequential.h>
// Layers
+#include <torch/nn/modules/adaptive.h>
#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/modules/instancenorm.h>
#include <torch/nn/modules/conv.h>
diff --git a/torch/csrc/api/include/torch/nn/modules/adaptive.h b/torch/csrc/api/include/torch/nn/modules/adaptive.h
new file mode 100644
index 0000000..9722fbe
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/modules/adaptive.h
@@ -0,0 +1,94 @@
+#pragma once
+
+#include <torch/nn/cloneable.h>
+#include <torch/nn/module.h>
+#include <torch/nn/modules/linear.h>
+#include <torch/nn/modules/container/modulelist.h>
+#include <torch/nn/modules/container/sequential.h>
+#include <torch/nn/functional/activation.h>
+#include <torch/nn/options/adaptive.h>
+
+namespace torch {
+namespace nn {
+
+/// The output of a single invocation of an AdaptiveLogSoftmaxWithLoss
+/// module's `forward()` method.
+struct TORCH_API ASMoutput {
+ ASMoutput(Tensor output_, double loss_);
+
+ /// Tensor containing computed target log probabilities for each example
+ Tensor output;
+
+ /// Scalar representing the computed negative log likelihood loss
+ double loss;
+};
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveLogSoftmaxWithLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Efficient softmax approximation as described in
+/// `Efficient softmax approximation for GPUs`_ by Edouard Grave, Armand Joulin,
+/// Moustapha Cissé, David Grangier, and Hervé Jégou.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::AdaptiveLogSoftmaxWithLossOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
+/// ```
+class TORCH_API AdaptiveLogSoftmaxWithLossImpl : public Cloneable<AdaptiveLogSoftmaxWithLossImpl> {
+ public:
+ AdaptiveLogSoftmaxWithLossImpl(int64_t in_features, int64_t n_classes, std::vector<int64_t> cutoffs)
+ : AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions(in_features, n_classes, cutoffs)) {}
+
+ explicit AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_);
+
+ ASMoutput forward(const Tensor& input, const Tensor& target);
+
+ void reset() override;
+
+ void reset_parameters();
+
+ /// Pretty prints the `AdaptiveLogSoftmaxWithLoss` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
+ /// Given input tensor, and output of `head`, computes the log of the full distribution
+ Tensor _get_full_log_prob(const Tensor &input, const Tensor& head_output);
+
+ /// Computes log probabilities for all n_classes
+ Tensor log_prob(const Tensor& input);
+
+ /// This is equivalent to `log_prob(input).argmax(1)` but is more efficient in some cases
+ Tensor predict(const Tensor& input);
+
+ /// The options with which this `Module` was constructed
+ AdaptiveLogSoftmaxWithLossOptions options;
+
+ /// Cutoffs used to assign targets to their buckets. Must be an ordered sequence
+ /// of integers sorted in increasing order
+ std::vector<int64_t> cutoffs;
+
+ int64_t shortlist_size;
+
+ /// Number of clusters
+ int64_t n_clusters;
+
+ /// Output size of head classifier
+ int64_t head_size;
+
+ Linear head = nullptr;
+
+ ModuleList tail;
+};
+
+/// A `ModuleHolder` subclass for `AdaptiveLogSoftmaxWithLossImpl`.
+/// See the documentation for `AdaptiveLogSoftmaxWithLossImpl` class to learn what methods it
+/// provides, and examples of how to use `AdaptiveLogSoftmaxWithLoss` with `torch::nn::AdaptiveLogSoftmaxWithLossOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
+TORCH_MODULE(AdaptiveLogSoftmaxWithLoss);
+
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules/container/sequential.h b/torch/csrc/api/include/torch/nn/modules/container/sequential.h
index 4252ea1..b5c014d 100644
--- a/torch/csrc/api/include/torch/nn/modules/container/sequential.h
+++ b/torch/csrc/api/include/torch/nn/modules/container/sequential.h
@@ -4,6 +4,7 @@
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/any.h>
+#include <torch/nn/modules/container/named_any.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
diff --git a/torch/csrc/api/include/torch/nn/options/adaptive.h b/torch/csrc/api/include/torch/nn/options/adaptive.h
new file mode 100644
index 0000000..7f73dbb
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/options/adaptive.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <torch/arg.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/types.h>
+
+namespace torch {
+namespace nn {
+
+/// Options for the `AdaptiveLogSoftmaxWithLoss` module.
+///
+/// Example:
+/// ```
+/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
+/// ```
+struct TORCH_API AdaptiveLogSoftmaxWithLossOptions {
+ /* implicit */ AdaptiveLogSoftmaxWithLossOptions(int64_t in_features, int64_t n_classes, std::vector<int64_t> cutoffs);
+
+ /// Number of features in the input tensor
+ TORCH_ARG(int64_t, in_features);
+
+ /// Number of classes in the dataset
+ TORCH_ARG(int64_t, n_classes);
+
+ /// Cutoffs used to assign targets to their buckets
+ TORCH_ARG(std::vector<int64_t>, cutoffs);
+
+ /// Value used as an exponent to compute sizes of the clusters. Default: 4.0
+ TORCH_ARG(double, div_value) = 4.;
+
+ /// If ``true``, adds a bias term to the 'head' of
+ /// the adaptive softmax. Default: false
+ TORCH_ARG(bool, head_bias) = false;
+};
+
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/adaptive.cpp b/torch/csrc/api/src/nn/modules/adaptive.cpp
new file mode 100644
index 0000000..62e5d00
--- /dev/null
+++ b/torch/csrc/api/src/nn/modules/adaptive.cpp
@@ -0,0 +1,166 @@
+#include <torch/nn/modules/adaptive.h>
+#include <torch/nn/options/activation.h>
+#include <torch/nn/options/linear.h>
+
+namespace F = torch::nn::functional;
+
+using namespace torch::indexing;
+
+namespace torch {
+namespace nn {
+
+ASMoutput::ASMoutput(Tensor output_, double loss_): output(std::move(output_)), loss(loss_) {}
+
+AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_)
+ : options(std::move(options_)),
+ shortlist_size(0),
+ n_clusters(0),
+ head_size(0) {
+ reset();
+}
+
+void AdaptiveLogSoftmaxWithLossImpl::reset() {
+ TORCH_CHECK( std::is_sorted(options.cutoffs().begin(), options.cutoffs().end()) &&
+ *std::min_element(options.cutoffs().begin(), options.cutoffs().end()) > 0 &&
+ *std::max_element(options.cutoffs().begin(), options.cutoffs().end()) <= (options.n_classes() - 1) &&
+ std::set<int64_t>(options.cutoffs().begin(), options.cutoffs().end()).size() == options.cutoffs().size(),
+ "cutoffs should be a sequence of unique, positive integers sorted in an increasing order, ",
+ "where each value is between 1 and n_classes-1");
+
+ cutoffs = options.cutoffs();
+ cutoffs.push_back(options.n_classes());
+
+ shortlist_size = cutoffs[0];
+ n_clusters = cutoffs.size() - 1;
+ head_size = shortlist_size + n_clusters;
+
+ head = this->register_module("head", Linear(LinearOptions(options.in_features(), head_size).bias(options.head_bias())));
+ tail = this->register_module("tail", ModuleList());
+
+ for (int64_t i = 0; i < n_clusters; i++) {
+ int64_t hsz = options.in_features() / static_cast<int64_t>(std::pow(options.div_value(), (i + 1)));
+ int64_t osz = cutoffs[i + 1] - cutoffs[i];
+
+ Sequential projection(
+ Linear(LinearOptions(options.in_features(), hsz).bias(false)),
+ Linear(LinearOptions(hsz, osz).bias(false)));
+ tail->push_back(projection);
+ }
+}
+
+void AdaptiveLogSoftmaxWithLossImpl::reset_parameters() {
+ head->reset_parameters();
+ for (size_t i = 0; i < tail->size(); ++i) {
+ auto i2h = tail[i]->children()[0]->as<Linear>();
+ auto h2o = tail[i]->children()[1]->as<Linear>();
+ i2h->reset_parameters();
+ h2o->reset_parameters();
+ }
+}
+
+ASMoutput AdaptiveLogSoftmaxWithLossImpl::forward(const Tensor& input, const Tensor& target) {
+ TORCH_CHECK(input.size(0) == target.size(0),
+ "Input and target should have the same size in the batch dimension.");
+
+ int64_t used_rows = 0;
+ const int64_t batch_size = target.size(0);
+
+ Tensor output = input.new_zeros(batch_size);
+ Tensor gather_inds = target.new_empty(batch_size);
+
+ auto cutoff_values = cutoffs;
+ cutoff_values.insert(cutoff_values.begin(), 0);
+
+ for (size_t i = 0; i < cutoff_values.size() - 1; ++i) {
+ int64_t low_idx = cutoff_values[i];
+ int64_t high_idx = cutoff_values[i + 1];
+
+ const Tensor target_mask = (target >= low_idx) * (target < high_idx);
+ const Tensor row_indices = target_mask.nonzero().squeeze();
+
+ if (row_indices.numel() == 0) {
+ continue;
+ }
+
+ if (i == 0) {
+ gather_inds.index_copy_(0, row_indices, target.index({target_mask}));
+ } else {
+ Tensor relative_target = target.index({target_mask}) - low_idx;
+ Tensor input_subset = input.index_select(0, row_indices);
+
+ const Tensor cluster_output = tail[i - 1]->as<Sequential>()->forward(input_subset);
+ int64_t cluster_index = shortlist_size + i - 1;
+
+ gather_inds.index_fill_(0, row_indices, cluster_index);
+
+ const Tensor cluster_logprob = F::log_softmax(cluster_output, 1);
+ const Tensor local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1));
+ output.index_copy_(0, row_indices, local_logprob.squeeze(1));
+ }
+
+ used_rows += row_indices.numel();
+ }
+
+ TORCH_CHECK(
+ used_rows == batch_size,
+ "Target values should be in [0, ", options.n_classes() - 1, "], "
+ "but values in range [", target.min().item().toDouble(), ", ", target.max().item().toDouble(), "] "
+ "were found. ");
+
+ const Tensor head_output = head(input);
+ const Tensor head_logprob = F::log_softmax(head_output, 1);
+ output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze();
+ const double loss = (-output).mean().item().toDouble();
+
+ return ASMoutput(output, loss);
+}
+
+Tensor AdaptiveLogSoftmaxWithLossImpl::_get_full_log_prob(const Tensor& input, const Tensor& head_output) {
+ Tensor out = input.new_empty({head_output.size(0), options.n_classes()});
+ const Tensor head_logprob = F::log_softmax(head_output, 1);
+
+ out.index_put_({Slice(), Slice(None, shortlist_size)}, head_logprob.index({Slice(), Slice(None, shortlist_size)}));
+
+ for (size_t i = 0; i < cutoffs.size() - 1; ++i) {
+ int64_t start_idx = cutoffs[i];
+ int64_t stop_idx = cutoffs[i+1];
+ const Tensor cluster_output = tail[i]->as<Sequential>()->forward(input);
+ const Tensor cluster_logprob = F::log_softmax(cluster_output, 1);
+ auto output_logprob = cluster_logprob + head_logprob.index({Slice(), static_cast<int64_t>(shortlist_size + i)}).unsqueeze(1);
+
+ out.index_put_({Slice(), Slice(start_idx, stop_idx)}, output_logprob);
+ }
+ return out;
+}
+
+Tensor AdaptiveLogSoftmaxWithLossImpl::log_prob(const Tensor& input) {
+ const Tensor head_output = head(input);
+ return _get_full_log_prob(input, head_output);
+}
+
+Tensor AdaptiveLogSoftmaxWithLossImpl::predict(const Tensor& input) {
+ const Tensor head_output = head(input);
+ Tensor output = torch::argmax(head_output, 1);
+ auto not_in_shortlist = (output >= shortlist_size);
+ auto all_in_shortlist = bitwise_not(not_in_shortlist.any());
+
+ if (all_in_shortlist.item().toBool()) {
+ return output;
+ } else if (not_in_shortlist.all().item().toBool()) {
+ const Tensor log_prob = _get_full_log_prob(input, head_output);
+ return torch::argmax(log_prob, 1);
+ } else {
+ const Tensor log_prob = _get_full_log_prob(
+ input.index({not_in_shortlist}),
+ head_output.index({not_in_shortlist}));
+ output.index_put_({not_in_shortlist}, torch::argmax(log_prob, 1));
+ return output;
+ }
+}
+
+void AdaptiveLogSoftmaxWithLossImpl::pretty_print(std::ostream& stream) const {
+ stream << "torch::nn::AdaptiveLogSoftmaxWithLoss";
+}
+
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/src/nn/options/adaptive.cpp b/torch/csrc/api/src/nn/options/adaptive.cpp
new file mode 100644
index 0000000..bff2e42
--- /dev/null
+++ b/torch/csrc/api/src/nn/options/adaptive.cpp
@@ -0,0 +1,11 @@
+#include <torch/nn/options/adaptive.h>
+
+namespace torch {
+namespace nn {
+
+AdaptiveLogSoftmaxWithLossOptions::AdaptiveLogSoftmaxWithLossOptions(
+ int64_t in_features, int64_t n_classes, std::vector<int64_t> cutoffs)
+ : in_features_(in_features), n_classes_(n_classes), cutoffs_(std::move(cutoffs)) {}
+
+} // namespace nn
+} // namespace torch