Implement histogramdd on CPU (#65318)

Summary:
Implements `torch.histogramdd` analogous to `numpy.histogramdd`.

Builds on https://github.com/pytorch/pytorch/pull/58780, generalizing the existing `torch.histogram` kernel to handle D-dimensional inputs.
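
As a quick illustration (mirroring the docstring example added to torch/functional.py in this change), each row of the input is treated as a point and binned independently in each dimension:

    >>> points = torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]])
    >>> hist, bin_edges = torch.histogramdd(points, bins=[3, 3],
    ...                                     weight=torch.tensor([1., 2., 4., 8.]))
    >>> hist
    tensor([[0., 1., 0.],
            [2., 0., 0.],
            [4., 0., 8.]])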

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65318

Reviewed By: soulitzer

Differential Revision: D31654555

Pulled By: saketh-are

fbshipit-source-id: 14b781fac0fd3698b052dbd6f0fda46e50d4c5f1
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index ad67144..55f1f40 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -379,6 +379,7 @@
 _(aten, hinge_embedding_loss) \
 _(aten, histc) \
 _(aten, histogram) \
+_(aten, histogramdd) \
 _(aten, hspmm) \
 _(aten, hsplit) \
 _(aten, hstack) \
diff --git a/aten/src/ATen/native/Histogram.cpp b/aten/src/ATen/native/Histogram.cpp
index 4d678a8..3caca73 100644
--- a/aten/src/ATen/native/Histogram.cpp
+++ b/aten/src/ATen/native/Histogram.cpp
@@ -5,118 +5,198 @@
 #include <ATen/native/Histogram.h>
 #include <ATen/native/Resize.h>
 
+#include <numeric>
 #include <tuple>
+#include <vector>
+#include <functional>
+#include <c10/util/ArrayRef.h>
 #include <c10/core/ScalarType.h>
 #include <c10/core/DefaultDtype.h>
 
-/* Implements a numpy-like histogram function running on cpu
- * https://numpy.org/doc/stable/reference/generated/numpy.histogram.html
+/* Implements a numpy-like histogramdd function running on cpu
+ * https://numpy.org/doc/stable/reference/generated/numpy.histogramdd.html
  *
- * - torch.histogram(input, bins, range=None, weight=None, density=False)
- *   input     - tensor containing the input values. The histogram is computed over the flattened values.
- *   bins      - int or 1D tensor. If int, defines the number of equal-width bins. If tensor, defines the
- *               sequence of bin edges including the rightmost edge.
- *   range     - (float, float), optional. Defines the range of the bins.
- *   weight    - tensor, optional. If provided, weight should have the same shape as input. Each value
- *               in input contributes its associated weight towards its bin's result (instead of 1).
- *   density   - bool, optional. If False, the result will contain the number of samples (or total weight)
- *               in each bin. If True, the result is the value of the probability density function at the
- *               bin, normalized such that the integral over the range is 1.
+ * See the docstring for torch.histogramdd in torch/functional.py for further explanation.
+ *
+ * - torch.histogramdd(input, bins, range=None, weight=None, density=False)
+ *   input     - tensor with shape (M, N). input is interpreted as M coordinates in N-dimensional space.
+ *               If a tensor with more than 2 dimensions is passed, all but the last dimension will be flattened.
+ *   bins      - int[] of length N or tensor list of length N. If int[], defines the number of equal-width bins
+ *               in each dimension. If tensor list, defines the sequences of bin edges, including rightmost edges,
+ *               for each dimension.
+ *   range     - float[] of length 2 * N, optional. If specified, defines the leftmost and rightmost bin edges
+ *               for each dimension.
+ *   weight    - tensor, optional. If provided, weight should have the same shape as input excluding its last dimension.
+ *               Each N-dimensional value in input contributes its associated weight towards its bin's result.
+ *               If weight is not specified, each value has weight 1 by default.
+ *   density   - bool, optional. If False (default), the result will contain the total count (weight) in each bin.
+ *               If True, each count (weight) is divided by the total count (total weight), then divided by the
+ *               volume of its associated bin.
  *
  * Returns:
- *   hist      - 1D tensor containing the values of the histogram.
- *   bin_edges - 1D tensor containing the edges of the histogram bins. Contains hist.numel() + 1 elements.
+ *   hist      - N-dimensional tensor containing the values of the histogram.
+ *   bin_edges - tensor list of length N containing the edges of the histogram bins in each dimension.
  *               Bins include their left edge and exclude their right edge, with the exception of the
- *               rightmost bin which includes both of its edges.
+ *               rightmost bin in each dimension which includes both of its edges.
  *
  * Restrictions are defined in histogram_check_inputs() and in select_outer_bin_edges().
  */
 
 namespace at { namespace native {
 
-DEFINE_DISPATCH(histogram_stub);
-
-DEFINE_DISPATCH(histogram_linear_stub);
+DEFINE_DISPATCH(histogramdd_stub);
+DEFINE_DISPATCH(histogramdd_linear_stub);
 
 namespace {
 
 /* Checks properties of input tensors input, bins, and weight.
  */
-void histogram_check_inputs(const Tensor& input, const Tensor& bins, const c10::optional<Tensor>& weight) {
-    TORCH_CHECK(input.dtype() == bins.dtype(), "torch.histogram: input tensor and bins tensor should",
-            " have the same dtype, but got input ", input.dtype(), " and bins ", bins.dtype());
+void histogramdd_check_inputs(const Tensor& input, const TensorList& bins, const c10::optional<Tensor>& weight) {
+    TORCH_CHECK(input.dim() >= 2, "torch.histogramdd: input tensor should have at least 2 dimensions, but got ",
+                input.dim());
 
-    TORCH_CHECK(bins.dim() == 1, "torch.histogram: bins tensor should have dimension 1,",
-            " but got ", bins.dim(), " dimension");
+    const int64_t N = input.size(-1);
 
-    TORCH_CHECK(bins.numel() > 0, "torch.histogram: bins tensor should have at least 1 element,",
-            " but got ", bins.numel(), " elements");
+    TORCH_CHECK(bins.size() == N, "torch.histogramdd: expected ", N, " sequences of bin edges for a ", N,
+                "-dimensional histogram but got ", bins.size());
+
+    auto input_dtype = input.dtype();
+    for (int64_t dim = 0; dim < N; dim++) {
+        const Tensor& dim_bins = bins[dim];
+
+        auto bins_dtype = dim_bins.dtype();
+        TORCH_CHECK(input_dtype == bins_dtype, "torch.histogramdd: input tensor and bins tensors should",
+                " have the same dtype, but got input with dtype ", input_dtype,
+                " and bins for dimension ", dim, " with dtype ", bins_dtype);
+
+        const int64_t dim_bins_dim = dim_bins.dim();
+        TORCH_CHECK(dim_bins_dim == 1, "torch.histogramdd: bins tensor should have one dimension,",
+                " but got ", dim_bins_dim, " dimensions in the bins tensor for dimension ", dim);
+
+        const int64_t numel = dim_bins.numel();
+        TORCH_CHECK(numel > 0, "torch.histogramdd: bins tensor should have at least 1 element,",
+                " but got ", numel, " elements in the bins tensor for dimension ", dim);
+    }
 
     if (weight.has_value()) {
-        TORCH_CHECK(input.dtype() == weight.value().dtype(), "torch.histogram: if weight tensor is provided,"
+        TORCH_CHECK(input.dtype() == weight.value().dtype(), "torch.histogramdd: if weight tensor is provided,"
                 " input tensor and weight tensor should have the same dtype, but got input(", input.dtype(), ")",
                 ", and weight(", weight.value().dtype(), ")");
 
-        TORCH_CHECK(input.sizes() == weight.value().sizes(), "torch.histogram: if weight tensor is provided,"
-                " input tensor and weight tensor should have the same shape, but got input(", input.sizes(), ")",
-                ", and weight(", weight.value().sizes(), ")");
+        /* If a weight tensor is provided, we expect its shape to match that of
+         * the input tensor excluding its innermost dimension N.
+         */
+        auto input_sizes = input.sizes().vec();
+        input_sizes.pop_back();
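+        // e.g. for an input of shape (2, 3, 4), input_sizes becomes (2, 3),
+        // which is the shape expected of the weight tensor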
+
+        auto weight_sizes = weight.value().sizes().vec();
+        if (weight_sizes.empty()) {
+            // correctly handle scalars
+            weight_sizes = {1};
+        }
+
+        TORCH_CHECK(input_sizes == weight_sizes, "torch.histogramdd: if weight tensor is provided it should have"
+                " the same shape as the input tensor excluding its innermost dimension, but got input with shape ",
+                input.sizes(), " and weight with shape ", weight.value().sizes());
     }
 }
 
 /* Checks properties of output tensors hist and bin_edges, then resizes them.
  */
-void histogram_prepare_out(const Tensor& input, int64_t bin_ct,
-        const Tensor& hist, const Tensor& bin_edges) {
+void histogramdd_prepare_out(const Tensor& input, const std::vector<int64_t>& bin_ct,
+        const Tensor& hist, const TensorList& bin_edges) {
+    const int64_t N = input.size(-1);
+
+    TORCH_INTERNAL_ASSERT((int64_t)bin_ct.size() == N);
+    TORCH_INTERNAL_ASSERT((int64_t)bin_edges.size() == N);
+
     TORCH_CHECK(input.dtype() == hist.dtype(), "torch.histogram: input tensor and hist tensor should",
             " have the same dtype, but got input ", input.dtype(), " and hist ", hist.dtype());
 
-    TORCH_CHECK(input.dtype() == bin_edges.dtype(), "torch.histogram: input tensor and bin_edges tensor should",
-            " have the same dtype, but got input ", input.dtype(), " and bin_edges ", bin_edges.dtype());
+    for (int64_t dim = 0; dim < N; dim++) {
+        TORCH_CHECK(input.dtype() == bin_edges[dim].dtype(), "torch.histogram: input tensor and bin_edges tensor should",
+                " have the same dtype, but got input ", input.dtype(), " and bin_edges ", bin_edges[dim].dtype(),
+                " for dimension ", dim);
 
-    TORCH_CHECK(bin_ct > 0,
-            "torch.histogram(): bins must be > 0, but got ", bin_ct);
+        TORCH_CHECK(bin_ct[dim] > 0,
+                "torch.histogram(): bins must be > 0, but got ", bin_ct[dim], " for dimension ", dim);
+
+        at::native::resize_output(bin_edges[dim], bin_ct[dim] + 1);
+    }
 
     at::native::resize_output(hist, bin_ct);
-
-    at::native::resize_output(bin_edges, bin_ct + 1);
-
-    TORCH_CHECK(hist.is_contiguous(), "torch.histogram: hist tensor must be contiguous");
 }
 
-/* Determines the outermost bin edges.
- */
-std::pair<double, double> select_outer_bin_edges(const Tensor& input, c10::optional<c10::ArrayRef<double>> range) {
-    TORCH_CHECK(!range.has_value() || range.value().size() == 2, "torch.histogram: range should have 2 elements",
-            " if specified, but got ", range.value().size());
+void histogramdd_prepare_out(const Tensor& input, TensorList bins,
+        const Tensor& hist, const TensorList& bin_edges) {
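+    // A sequence of K bin edges defines K - 1 bins, hence the numel() - 1 below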
+    std::vector<int64_t> bin_ct(bins.size());
+    std::transform(bins.begin(), bins.end(), bin_ct.begin(), [](Tensor t) { return t.numel() - 1; });
+    histogramdd_prepare_out(input, bin_ct, hist, bin_edges);
+}
 
-    // Default range for empty input matching numpy.histogram's default
-    double leftmost_edge = 0., rightmost_edge = 1.;
+template<typename scalar_t>
+void infer_bin_edges_from_input(const Tensor& input, const int64_t N,
+        std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
+    // Calls aminmax on input with dim=0, reducing all but the innermost dimension of input.
+    Tensor min, max;
+    std::tie(min, max) = aminmax(input, 0);
+
+    TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());
+
+    const scalar_t *min_data = min.data_ptr<scalar_t>();
+    std::copy(min_data, min_data + N, leftmost_edges.begin());
+
+    const scalar_t *max_data = max.data_ptr<scalar_t>();
+    std::copy(max_data, max_data + N, rightmost_edges.begin());
+}
+
+/* Determines the outermost bin edges. For simplicity when calling into aminmax,
+ * assumes that input has already been reshaped to (M, N).
+ */
+std::pair<std::vector<double>, std::vector<double>>
+select_outer_bin_edges(const Tensor& input, c10::optional<c10::ArrayRef<double>> range) {
+    TORCH_INTERNAL_ASSERT(input.dim() == 2, "expected input to have shape (M, N)");
+    const int64_t N = input.size(-1);
+
+    // Default ranges for empty input matching numpy.histogram's default
+    std::vector<double> leftmost_edges(N, 0.);
+    std::vector<double> rightmost_edges(N, 1.);
 
     if (range.has_value()) {
         // range is specified
-        leftmost_edge = range.value()[0];
-        rightmost_edge = range.value()[1];
+        TORCH_CHECK((int64_t)range.value().size() == 2 * N, "torch.histogramdd: for a ", N, "-dimensional histogram",
+                " range should have ", 2 * N, " elements, but got ", range.value().size());
+
+        for (int64_t dim = 0; dim < N; dim++) {
+            leftmost_edges[dim] = range.value()[2 * dim];
+            rightmost_edges[dim] = range.value()[2 * dim + 1];
+        }
     } else if (input.numel() > 0) {
         // non-empty input
-        auto extrema = _aminmax(input);
-        leftmost_edge = std::get<0>(extrema).item<double>();
-        rightmost_edge = std::get<1>(extrema).item<double>();
+        AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
+            infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
+        });
     }
 
-    TORCH_CHECK(!(std::isinf(leftmost_edge) || std::isinf(rightmost_edge) ||
-            std::isnan(leftmost_edge) || std::isnan(rightmost_edge)),
-            "torch.histogram: range of [", leftmost_edge, ", ", rightmost_edge, "] is not finite");
+    for (int64_t dim = 0; dim < N; dim++) {
+        double leftmost_edge = leftmost_edges[dim];
+        double rightmost_edge = rightmost_edges[dim];
 
-    TORCH_CHECK(leftmost_edge <= rightmost_edge, "torch.histogram: min should not exceed max, but got",
-            " min ", leftmost_edge, " max ", rightmost_edge);
+        TORCH_CHECK(std::isfinite(leftmost_edge) && std::isfinite(rightmost_edge),
+                "torch.histogramdd: dimension ", dim, "'s range [",
+                leftmost_edge, ", ", rightmost_edge, "] is not finite");
 
-    // Expand empty range to match numpy behavior and avoid division by 0 in normalization
-    if (leftmost_edge == rightmost_edge) {
-        leftmost_edge -= 0.5;
-        rightmost_edge += 0.5;
+        TORCH_CHECK(leftmost_edge <= rightmost_edge, "torch.histogramdd: min should not exceed max, but got",
+                " min ", leftmost_edge, " max ", rightmost_edge, " for dimension ", dim);
+
+        // Expand empty range to match numpy behavior and avoid division by 0 in normalization
+        if (leftmost_edge == rightmost_edge) {
+            leftmost_edges[dim] -= 0.5;
+            rightmost_edges[dim] += 0.5;
+        }
     }
 
-    return std::make_pair(leftmost_edge, rightmost_edge);
+    return std::make_pair(leftmost_edges, rightmost_edges);
 }
 
 /* histc's version of the logic for outermost bin edges.
@@ -148,19 +228,118 @@
 
 } // namespace
 
+std::vector<Tensor> allocate_bin_edges_tensors(const Tensor& self) {
+    TORCH_CHECK(self.dim() >= 2, "torch.histogramdd: input tensor should have at least 2 dimensions");
+    const int64_t N = self.size(-1);
+    std::vector<Tensor> bin_edges_out(N);
+    for (int64_t dim = 0; dim < N; dim++) {
+        bin_edges_out[dim] = at::empty({0}, self.options(), MemoryFormat::Contiguous);
+    }
+    return bin_edges_out;
+}
+
+/* Versions of histogramdd in which bins is a Tensor[] defining the sequences of bin edges.
+ */
+Tensor& histogramdd_out_cpu(const Tensor& self, TensorList bins,
+        const c10::optional<Tensor>& weight, bool density,
+        Tensor& hist, TensorList& bin_edges) {
+    histogramdd_check_inputs(self, bins, weight);
+    histogramdd_prepare_out(self, bins, hist, bin_edges);
+
+    for (size_t dim = 0; dim < bins.size(); dim++) {
+        bin_edges[dim].copy_(bins[dim]);
+    }
+
+    histogramdd_stub(self.device().type(), self, weight, density, hist, bin_edges);
+    return hist;
+}
+
+Tensor histogramdd_cpu(const Tensor& self, TensorList bins,
+        const c10::optional<Tensor>& weight, bool density) {
+    Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
+    std::vector<Tensor> bin_edges_out = allocate_bin_edges_tensors(self);
+    TensorList bin_edges_out_tl(bin_edges_out);
+
+    histogramdd_out_cpu(self, bins, weight, density, hist, bin_edges_out_tl);
+    return hist;
+}
+
+/* Versions of histogramdd in which bins is an int[]
+ * defining the number of bins in each dimension.
+ */
+std::vector<Tensor>& histogramdd_bin_edges_out_cpu(const Tensor& self, IntArrayRef bin_ct,
+        c10::optional<c10::ArrayRef<double>> range,
+        const c10::optional<Tensor>& weight, bool density,
+        std::vector<Tensor>& bin_edges_out) {
+    const int64_t N = self.size(-1);
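+    // M is the number of N-dimensional points in self: the product of all leading dimensions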
+    const int64_t M = std::accumulate(self.sizes().begin(), self.sizes().end() - 1,
+            (int64_t)1, std::multiplies<int64_t>());
+    Tensor reshaped_self = self.reshape({ M, N });
+
+    auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range);
+
+    for (int64_t dim = 0; dim < N; dim++) {
+        linspace_cpu_out(outer_bin_edges.first[dim], outer_bin_edges.second[dim],
+                bin_ct[dim] + 1, bin_edges_out[dim]);
+    }
+
+    return bin_edges_out;
+}
+
+std::vector<Tensor> histogramdd_bin_edges_cpu(const Tensor& self, IntArrayRef bin_ct,
+        c10::optional<c10::ArrayRef<double>> range,
+        const c10::optional<Tensor>& weight, bool density) {
+    std::vector<Tensor> bin_edges_out = allocate_bin_edges_tensors(self);
+    return histogramdd_bin_edges_out_cpu(self, bin_ct, range, weight, density, bin_edges_out);
+}
+
+Tensor& histogramdd_out_cpu(const Tensor& self, IntArrayRef bin_ct,
+        c10::optional<c10::ArrayRef<double>> range,
+        const c10::optional<Tensor>& weight, bool density,
+        Tensor& hist, TensorList& bin_edges) {
+    std::vector<Tensor> bins = histogramdd_bin_edges_cpu(self, bin_ct, range, weight, density);
+
+    histogramdd_check_inputs(self, bins, weight);
+    histogramdd_prepare_out(self, bins, hist, bin_edges);
+
+    for (size_t dim = 0; dim < bins.size(); dim++) {
+        bin_edges[dim].copy_(bins[dim]);
+    }
+
+    histogramdd_linear_stub(self.device().type(), self, weight, density, hist, bin_edges, true);
+    return hist;
+}
+
+Tensor histogramdd_cpu(const Tensor& self, IntArrayRef bin_ct,
+        c10::optional<c10::ArrayRef<double>> range,
+        const c10::optional<Tensor>& weight, bool density) {
+    Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
+    std::vector<Tensor> bin_edges_out = allocate_bin_edges_tensors(self);
+    TensorList bin_edges_out_tl(bin_edges_out);
+
+    histogramdd_out_cpu(self, bin_ct, range, weight, density, hist, bin_edges_out_tl);
+    return hist;
+}
+
 /* Versions of histogram in which bins is a Tensor defining the sequence of bin edges.
  */
 std::tuple<Tensor&, Tensor&>
 histogram_out_cpu(const Tensor& self, const Tensor& bins,
         const c10::optional<Tensor>& weight, bool density,
         Tensor& hist, Tensor& bin_edges) {
-    histogram_check_inputs(self, bins, weight);
-    histogram_prepare_out(self, bins.numel() - 1, hist, bin_edges);
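+    // Reshapes self into (numel, 1) so the histogramdd implementation can treat it
+    // as numel points in 1-dimensional space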
+    Tensor reshaped_self = self.reshape({ self.numel(), 1 });
+    c10::optional<Tensor> reshaped_weight = weight.has_value()
+        ? weight.value().reshape({ weight.value().numel() }) : weight;
+    TensorList bins_in = bins;
+    TensorList bins_out = bin_edges;
 
-    bin_edges.copy_(bins);
-    histogram_stub(self.device().type(), self, weight, density, hist, bin_edges);
+    histogramdd_out_cpu(reshaped_self, bins_in, reshaped_weight, density, hist, bins_out);
+
     return std::forward_as_tuple(hist, bin_edges);
 }
+
 std::tuple<Tensor, Tensor>
 histogram_cpu(const Tensor& self, const Tensor& bins,
         const c10::optional<Tensor>& weight, bool density) {
@@ -175,14 +354,22 @@
 histogram_out_cpu(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRef<double>> range,
         const c10::optional<Tensor>& weight, bool density,
         Tensor& hist, Tensor& bin_edges) {
-    histogram_prepare_out(self, bin_ct, hist, bin_edges);
-    auto outer_bin_edges = select_outer_bin_edges(self, range);
-    linspace_cpu_out(outer_bin_edges.first, outer_bin_edges.second, bin_ct + 1, bin_edges);
-    histogram_check_inputs(self, bin_edges, weight);
+    Tensor reshaped_self = self.reshape({ self.numel(), 1 });
+    c10::optional<Tensor> reshaped_weight = weight.has_value()
+        ? weight.value().reshape({ weight.value().numel() }) : weight;
+    TensorList bins_in = bin_edges;
+    TensorList bins_out = bin_edges;
 
-    histogram_linear_stub(self.device().type(), self, weight, density, hist, bin_edges, true);
+    histogramdd_prepare_out(reshaped_self, std::vector<int64_t>{bin_ct}, hist, bins_out);
+    auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range);
+    linspace_cpu_out(outer_bin_edges.first[0], outer_bin_edges.second[0], bin_ct + 1, bin_edges);
+
+    histogramdd_check_inputs(reshaped_self, bins_in, reshaped_weight);
+
+    histogramdd_linear_stub(reshaped_self.device().type(), reshaped_self, reshaped_weight, density, hist, bin_edges, true);
     return std::forward_as_tuple(hist, bin_edges);
 }
+
 std::tuple<Tensor, Tensor>
 histogram_cpu(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRef<double>> range,
         const c10::optional<Tensor>& weight, bool density) {
@@ -196,15 +383,23 @@
 Tensor& histogram_histc_cpu_out(const Tensor& self, int64_t bin_ct,
         const Scalar& min, const Scalar& max, Tensor& hist) {
     Tensor bin_edges = at::empty({0}, self.options());
-    histogram_prepare_out(self, bin_ct, hist, bin_edges);
+
+    Tensor reshaped = self.reshape({ self.numel(), 1 });
+    TensorList bins_in = bin_edges;
+    TensorList bins_out = bin_edges;
+
+    histogramdd_prepare_out(reshaped, std::vector<int64_t>{bin_ct}, hist, bins_out);
+
     auto outer_bin_edges = histc_select_outer_bin_edges(self, min, max);
     linspace_cpu_out(outer_bin_edges.first, outer_bin_edges.second, bin_ct + 1, bin_edges);
-    histogram_check_inputs(self, bin_edges, {});
 
-    histogram_linear_stub(self.device().type(), self,
+    histogramdd_check_inputs(reshaped, bins_in, {});
+
+    histogramdd_linear_stub(reshaped.device().type(), reshaped,
             c10::optional<Tensor>(), false, hist, bin_edges, false);
     return hist;
 }
+
 Tensor histogram_histc_cpu(const Tensor& self, int64_t bin_ct,
         const Scalar& min, const Scalar& max) {
     Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
diff --git a/aten/src/ATen/native/Histogram.h b/aten/src/ATen/native/Histogram.h
index 58d1929..02dbe47 100644
--- a/aten/src/ATen/native/Histogram.h
+++ b/aten/src/ATen/native/Histogram.h
@@ -7,11 +7,10 @@
 
 namespace at { namespace native {
 
-using histogram_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const Tensor&);
-using histogram_linear_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const Tensor&, bool);
+using histogramdd_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&);
+using histogramdd_linear_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&, bool);
 
-DECLARE_DISPATCH(histogram_fn, histogram_stub);
-
-DECLARE_DISPATCH(histogram_linear_fn, histogram_linear_stub);
+DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub);
+DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/HistogramKernel.cpp b/aten/src/ATen/native/cpu/HistogramKernel.cpp
index 672aa0d..d6b99da 100644
--- a/aten/src/ATen/native/cpu/HistogramKernel.cpp
+++ b/aten/src/ATen/native/cpu/HistogramKernel.cpp
@@ -7,7 +7,10 @@
 
 #include <algorithm>
 #include <mutex>
+#include <numeric>
 #include <tuple>
+#include <functional>
+#include <ATen/TensorIndexing.h>
 
 namespace at { namespace native {
 
@@ -15,29 +18,34 @@
 
 constexpr int64_t HISTOGRAM_GRAIN_SIZE = 200;
 
-/* The main algorithm. Maps the elements of input into the bins defined by bin_edges.
- * Expects that the elements of bin_edges are increasing; behavior is otherwise undefined.
+/* The main algorithm. Expects that the input tensor has shape (N, D).
+ * Expects that bin_edges contains D one-dimensional tensors, each specifying
+ * an increasing sequence of bin edges.
  *
- * Accepts a template argument of type BIN_SELECTION_ALGORITHM specifying how the
- * elements of input should be mapped into the bins:
+ * Interprets the input as N different D-dimensional coordinates and maps them
+ * into the D-dimensional bins defined by bin_edges, accumulating a D-dimensional
+ * histogram in the hist tensor.
  *
- *     - LINEAR_INTERPOLATION: bin_edges must contain a linear progression.
- *       Elements of input are mapped to bins by computing
+ * Accepts a template argument of type BIN_SELECTION_ALGORITHM specifying how
+ * the scalars in each dimension should be mapped into the dimension's bins:
+ *
+ *     - LINEAR_INTERPOLATION: each bin edge sequence must form a linear progression.
+ *       Scalars are mapped to bins by computing
  *           (element - leftmost_edge)/(rightmost_edge - leftmost_edge) * bin_ct
  *       and truncating the result to an integer.
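+ *
+ *       For example, with bin edges forming the linear progression [0, 1, 2, 3]
+ *       (bin_ct = 3), an element 1.7 maps to trunc((1.7 - 0) / (3 - 0) * 3) = bin 1.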
  *
- *       Results may not be perfectly consistent with the boundaries specified in bin_edges
- *       due to precision issues.
+ *       This is the fastest option, but its results may not be perfectly consistent
+ *       with the boundaries specified in bin_edges due to precision issues.
  *
  *       Used by torch.histc, which doesn't need consistency with bin_edges as it does not
  *       return bin_edges. Additionally, this implementation is identical to the legacy histc
  *       implementation, which was replaced when histogram was implemented.
  *
- *     - LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH: Also expects that bin_edges contains a
- *       linear progression. For each element, if 'pos' is the bin selected by the
+ *     - LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH: Also expects that each bin edge sequence
+ *       forms a linear progression. For each scalar, if 'pos' is the bin selected by the
  *       LINEAR_INTERPOLATION approach, this approach inspects the boundaries in bin_edges to
- *       place the element into pos - 1, pos, or pos + 1. The "local search" over neighboring
- *       bins allows for correction of misclassifications due to precision issues (an element
+ *       place the scalar into pos - 1, pos, or pos + 1. The "local search" over neighboring
+ *       bins allows for correction of misclassifications due to precision issues (a scalar
  *       very close to a bin_edge may be misclassified by LINEAR_INTERPOLATION).
  *
  *       Should produce the same output as the general case BINARY_SEARCH, but run about
@@ -54,8 +62,6 @@
  *
  * See discussion at https://github.com/pytorch/pytorch/pull/58780#discussion_r648604866
  * for further details on relative performance of the bin selection algorithms.
- *
- * Accumulates the total weight in each bin into the hist tensor.
  */
 enum BIN_SELECTION_ALGORITHM {
     LINEAR_INTERPOLATION,
@@ -63,17 +69,28 @@
     BINARY_SEARCH,
 };
 template<typename input_t, BIN_SELECTION_ALGORITHM algorithm>
-void histogram_cpu_contiguous(Tensor& hist, const Tensor& bin_edges,
+void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges,
         const Tensor& input, const c10::optional<Tensor>& weight) {
-    TORCH_INTERNAL_ASSERT(hist.is_contiguous());
-    TORCH_INTERNAL_ASSERT(bin_edges.is_contiguous());
-    TORCH_INTERNAL_ASSERT(hist.numel() + 1 == bin_edges.numel());
-    TORCH_INTERNAL_ASSERT(input.dim() == 1);
-    TORCH_INTERNAL_ASSERT(!weight.has_value() || weight.value().dim() == 1);
+    TORCH_INTERNAL_ASSERT(input.dim() == 2);
 
-    const int64_t numel_in = input.numel();
+    const int64_t N = input.size(0);
+    if (weight.has_value()) {
+        TORCH_INTERNAL_ASSERT(weight.value().dim() == 1 && weight.value().numel() == N);
+    }
 
-    TensorAccessor<input_t, 1> accessor_in = input.accessor<input_t, 1>();
+    const int64_t D = input.size(1);
+    TORCH_INTERNAL_ASSERT(int64_t(bin_edges.size()) == D);
+    for (int64_t dim = 0; dim < D; dim++) {
+        TORCH_INTERNAL_ASSERT(bin_edges[dim].is_contiguous());
+        TORCH_INTERNAL_ASSERT(hist.size(dim) + 1 == bin_edges[dim].numel());
+    }
+
+    if (D == 0) {
+        // hist is an empty tensor in this case; nothing to do here
+        return;
+    }
+
+    TensorAccessor<input_t, 2> accessor_in = input.accessor<input_t, 2>();
 
     /* Constructs a c10::optional<TensorAccessor> containing an accessor iff
      * the optional weight tensor has a value.
@@ -82,73 +99,91 @@
             ? c10::optional<TensorAccessor<input_t, 1>>(weight.value().accessor<input_t, 1>())
             : c10::optional<TensorAccessor<input_t, 1>>();
 
-    const int64_t numel_be = bin_edges.numel();
-    const input_t *data_be = bin_edges.data_ptr<input_t>();
+    std::vector<input_t*> bin_seq(D);
+    std::vector<int64_t> num_bin_edges(D);
+    std::vector<input_t> leftmost_edge(D), rightmost_edge(D);
 
-    const input_t leftmost_bin_edge = data_be[0];
-    const input_t rightmost_bin_edge = data_be[numel_be - 1];
+    for (int64_t dim = 0; dim < D; dim++) {
+        bin_seq[dim] = bin_edges[dim].data_ptr<input_t>();
+        num_bin_edges[dim] = bin_edges[dim].numel();
+        leftmost_edge[dim] = bin_seq[dim][0];
+        rightmost_edge[dim] = bin_seq[dim][num_bin_edges[dim] - 1];
+    }
 
-    input_t *data_out = hist.data_ptr<input_t>();
+    int64_t GRAIN_SIZE = std::max(int64_t(1), HISTOGRAM_GRAIN_SIZE / D);
 
     /* Parallelizes processing of input using at::parallel_for.
-     * Each thread accumulates a local result for some range of the input in data_out_local
-     * before locking data_out_mutex and adding its accumulated results to data_out.
+     * Each thread accumulates a local result for some range of the input in hist_local
+     * before locking hist_mutex and adding its accumulated results to the hist tensor.
      */
-    std::mutex data_out_mutex;
-    at::parallel_for(0, numel_in, HISTOGRAM_GRAIN_SIZE, [&](int64_t start, int64_t end) {
-        // Allocates a buffer for the thread's local results
-        std::vector<input_t> data_out_local(numel_be - 1, input_t(0));
+    std::mutex hist_mutex;
+    at::parallel_for(0, N, GRAIN_SIZE, [&](int64_t start, int64_t end) {
+        // Allocates a tensor for the thread's local results
+        Tensor hist_local = at::zeros(hist.sizes(), hist.dtype());
 
+        std::vector<at::indexing::TensorIndex> indices(D, 0);
         for (const auto i : c10::irange(start, end)) {
-            const input_t elt = accessor_in[i];
+            bool skip_elt = false;
 
-            // Skips elements which fall outside the specified bins
-            if (elt < leftmost_bin_edge || rightmost_bin_edge < elt) {
-                continue;
-            }
+            for (int64_t dim = 0; dim < D; dim++) {
+                const input_t elt = accessor_in[i][dim];
 
-            int64_t pos = -1;
-
-            if (algorithm == BINARY_SEARCH) {
-                // Handles the general case via binary search on the bin edges.
-                pos = std::upper_bound(data_be, data_be + numel_be, elt) - data_be - 1;
-            } else if (algorithm == LINEAR_INTERPOLATION
-                    || algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
-                /* When bin_edges is known to be a linear progression, maps elt to
-                 * the appropriate bin via simple division.
-                 */
-                pos = static_cast<int64_t>((elt - leftmost_bin_edge)
-                        / (rightmost_bin_edge - leftmost_bin_edge)
-                        * (numel_be - 1));
-
-                /* Ensures consistency with bin_edges by checking the bins to the left and right
-                 * of the selected position. Necessary for cases in which an element very close
-                 * to a bin edge may be misclassified by simple division.
-                 */
-                if (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
-                    int64_t pos_min = std::max(static_cast<int64_t>(0), pos - 1);
-                    int64_t pos_max = std::min(pos + 2, numel_be);
-                    pos = std::upper_bound(data_be + pos_min, data_be + pos_max, elt) - data_be - 1;
+                // Skips elements which fall outside the specified bins
+                if (elt < leftmost_edge[dim] || rightmost_edge[dim] < elt) {
+                    skip_elt = true;
+                    break;
                 }
-            } else {
-                TORCH_INTERNAL_ASSERT(false);
+
+                int64_t pos = -1;
+
+                if (algorithm == BINARY_SEARCH) {
+                    // Handles the general case via binary search on the bin edges.
+                    pos = std::upper_bound(bin_seq[dim], bin_seq[dim] + num_bin_edges[dim], elt)
+                            - bin_seq[dim] - 1;
+                } else if (algorithm == LINEAR_INTERPOLATION
+                        || algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
+                    /* When bin_edges is known to be a linear progression, maps elt to
+                     * the appropriate bin via simple division.
+                     */
+                    pos = static_cast<int64_t>((elt - leftmost_edge[dim])
+                            / (rightmost_edge[dim] - leftmost_edge[dim])
+                            * (num_bin_edges[dim] - 1));
+
+                    /* Ensures consistency with bin_edges by checking the bins to the left and right
+                     * of the selected position. Necessary for cases in which an element very close
+                     * to a bin edge may be misclassified by simple division.
+                     */
+                    if (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
+                        int64_t pos_min = std::max(static_cast<int64_t>(0), pos - 1);
+                        int64_t pos_max = std::min(pos + 2, num_bin_edges[dim]);
+                        pos = std::upper_bound(bin_seq[dim] + pos_min, bin_seq[dim] + pos_max, elt)
+                                - bin_seq[dim] - 1;
+                    }
+                } else {
+                    TORCH_INTERNAL_ASSERT(false);
+                }
+
+                // Unlike other bins, the rightmost bin includes its right boundary
+                if (pos == (num_bin_edges[dim] - 1)) {
+                    pos -= 1;
+                }
+
+                indices[dim] = pos;
             }
 
-            // Unlike other bins, the rightmost bin includes its right boundary
-            if (pos == (numel_be - 1)) {
-                pos -= 1;
-            }
+            if (!skip_elt) {
+                // In the unweighted case, the default weight is 1
+                input_t wt = accessor_wt.has_value() ? accessor_wt.value()[i] : static_cast<input_t>(1);
 
-            // In the unweighted case, the default weight is 1
-            input_t wt = accessor_wt.has_value() ? accessor_wt.value()[i] : static_cast<input_t>(1);
-            data_out_local[pos] += wt;
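+                // Accumulates wt into the D-dimensional bin addressed by indices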
+                input_t cur = hist_local.index(indices).item<input_t>();
+                hist_local.index_put_(indices, cur + wt);
+            }
         }
 
         // Locks and updates the common output
-        const std::lock_guard<std::mutex> lock(data_out_mutex);
-        for (int64_t i = 0; i < numel_be - 1; i++) {
-            data_out[i] += data_out_local[i];
-        }
+        const std::lock_guard<std::mutex> lock(hist_mutex);
+        hist.add_(hist_local);
     });
 }
 
@@ -156,64 +191,85 @@
  * Initializes hist to 0, calls into the main algorithm, and normalizes output if necessary.
  */
 template<BIN_SELECTION_ALGORITHM bin_algorithm>
-void histogram_out_cpu_template(const Tensor& self, const c10::optional<Tensor>& weight, bool density,
-        Tensor& hist, const Tensor& bin_edges) {
+void histogramdd_out_cpu_template(const Tensor& self, const c10::optional<Tensor>& weight, bool density,
+        Tensor& hist, const TensorList& bin_edges) {
     hist.fill_(0);
 
-    const int64_t numel_in = self.numel();
-    const Tensor reshaped_input = self.reshape({numel_in});
+    const int64_t N = self.size(-1);
+    const int64_t M = std::accumulate(self.sizes().begin(), self.sizes().end() - 1,
+            (int64_t)1, std::multiplies<int64_t>());
+
+    const Tensor reshaped_input = self.reshape({M, N});
 
     const auto reshaped_weight = weight.has_value()
-            ? c10::optional<Tensor>(weight.value().reshape({numel_in}))
+            ? c10::optional<Tensor>(weight.value().reshape({M}))
             : c10::optional<Tensor>();
 
+    std::vector<Tensor> bin_edges_contig(bin_edges.size());
+    for (size_t dim = 0; dim < bin_edges_contig.size(); dim++) {
+        bin_edges_contig[dim] = bin_edges[dim].contiguous();
+    }
+
     AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "histogram_cpu", [&]() {
-        histogram_cpu_contiguous<scalar_t, bin_algorithm>(
-                hist, bin_edges.contiguous(), reshaped_input, reshaped_weight);
+        histogramdd_cpu_contiguous<scalar_t, bin_algorithm>(
+                hist, bin_edges_contig, reshaped_input, reshaped_weight);
     });
 
-    // Converts the bin totals to a probability density function
+    /* Divides each bin's value by the total count/weight in all bins,
+     * and by the bin's volume.
+     */
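+    // e.g. a bin holding 1 of 2 total points, with volume 0.5 * 0.5 = 0.25,
+    // ends up with density (1 / 2) / 0.25 = 2 (cf. the docstring example)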
     if (density) {
-        auto bin_widths = bin_edges.diff();
-        auto hist_sum = hist.sum().item();
-        hist.div_(bin_widths);
+        const auto hist_sum = hist.sum().item();
         hist.div_(hist_sum);
+
+        /* For each dimension, divides each bin's value
+         * by the bin's length in that dimension.
+         */
+        for (int64_t dim = 0; dim < N; dim++) {
+            const auto bin_lengths = bin_edges[dim].diff();
+
+            // Used to reshape bin_lengths to align with the corresponding dimension of hist.
+            std::vector<int64_t> shape(N, 1);
+            shape[dim] = bin_lengths.numel();
+
+            hist.div_(bin_lengths.reshape(shape));
+        }
     }
 }
 
 /* The general implementation of the histogram kernel. Maps each element of the input tensor
  * to its corresponding bin by performing a binary search over the elements of bin_edges.
  *
- * Refer to histogram_out_cpu_template for more details.
+ * Refer to histogramdd_out_cpu_template for more details.
  */
-static void histogram_kernel_impl(const Tensor& self, const c10::optional<Tensor>& weight, bool density,
-        Tensor& hist, const Tensor& bin_edges) {
-    histogram_out_cpu_template<BINARY_SEARCH>(self, weight, density, hist, bin_edges);
+static void histogramdd_kernel_impl(const Tensor& self, const c10::optional<Tensor>& weight, bool density,
+        Tensor& hist, const TensorList& bin_edges) {
+    histogramdd_out_cpu_template<BINARY_SEARCH>(self, weight, density, hist, bin_edges);
 }
 
 /* A faster version of the histogram kernel for cases in which bin_edges are known
  * to form a linear progression.
  *
- * Refer to histogram_out_cpu_template for more details.
+ * Refer to histogramdd_out_cpu_template for more details.
  */
-static void histogram_linear_kernel_impl(const Tensor& self, const c10::optional<Tensor>& weight,
-        bool density, Tensor& hist, const Tensor& bin_edges, bool local_search) {
+static void histogramdd_linear_kernel_impl(const Tensor& self, const c10::optional<Tensor>& weight,
+        bool density, Tensor& hist, const TensorList& bin_edges, bool local_search) {
     if (local_search) {
-        // histogram codepath: both hist and bin_edges are eventually returned as output,
+        // histogramdd codepath: both hist and bin_edges are eventually returned as output,
         // so we'll keep them consistent
-        histogram_out_cpu_template<LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH>(
+        histogramdd_out_cpu_template<LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH>(
               self, weight, density, hist, bin_edges);
     } else {
         // histc codepath: bin_edges are not returned to the caller
-        histogram_out_cpu_template<LINEAR_INTERPOLATION>(
+        histogramdd_out_cpu_template<LINEAR_INTERPOLATION>(
               self, weight, density, hist, bin_edges);
     }
 }
 
 } // namespace
 
-REGISTER_DISPATCH(histogram_stub, &histogram_kernel_impl);
+REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel_impl);
 
-REGISTER_DISPATCH(histogram_linear_stub, &histogram_linear_kernel_impl);
+REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel_impl);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2caa8c7..7935e99 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -7095,6 +7095,18 @@
   dispatch:
     CPU: histogram_cpu
 
+- func: _histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[]
+  dispatch:
+    CPU: histogramdd_bin_edges_cpu
+
+- func: _histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor
+  dispatch:
+    CPU: histogramdd_cpu
+
+- func: _histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor
+  dispatch:
+    CPU: histogramdd_cpu
+
 - func: fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
   dispatch:
diff --git a/test/test_fx.py b/test/test_fx.py
index 6a36b7d..416f7a5 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -3188,6 +3188,7 @@
                            'expand_as',
                            'fill_',
+                           'histogramdd',
                            'hstack',
                            'igamma',
                            'igammac',
                            'linalg.multi_dot',
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index f4e7751..289c665 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -1475,6 +1475,7 @@
             "mT",  # Implemented with a lambda
             "mH",  # Implemented with a lambda
             "gradient",
+            "histogramdd",
             "igamma",
             "igammac",
             "index_put",
diff --git a/test/test_reductions.py b/test/test_reductions.py
index a02bc06..3b0e981 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -2813,8 +2813,23 @@
         self.assertEqual(actual_hist, expected_hist)
         self.assertEqual(actual_bin_edges, expected_bin_edges)
 
+        # Test passing non-contiguous output tensors
+        hist_out = make_tensor(expected_hist.shape, device=expected_hist.device, dtype=expected_hist.dtype,
+                               noncontiguous=True)
+        bin_edges_out = make_tensor(expected_bin_edges.shape, device=expected_bin_edges.device, dtype=expected_bin_edges.dtype,
+                                    noncontiguous=True)
+
+        # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one
+        if bin_range:
+            torch.histogram(t, bins, range=bin_range, weight=weights, density=density, out=(hist_out, bin_edges_out))
+        else:
+            torch.histogram(t, bins, weight=weights, density=density, out=(hist_out, bin_edges_out))
+
+        self.assertEqual(hist_out, expected_hist)
+        self.assertEqual(bin_edges_out, expected_bin_edges)
+
     @onlyCPU
-    @dtypes(torch.float32, torch.float64)
+    @dtypes(torch.float32)
     def test_histogram(self, device, dtype):
         shapes = (
             (),
@@ -2870,16 +2885,126 @@
             self.assertEqual(actual_hist, expected_hist)
             self.assertEqual(actual_bin_edges, expected_bin_edges)
 
+    """
+    Runs torch.histogramdd and numpy.histogramdd on the specified input parameters
+    and asserts that their output is equal.
+    """
+    def _test_histogramdd_numpy(self, t, bins, bin_range, weights, density):
+        def to_np(t):
+            if type(t) == list:
+                return list(map(to_np, t))
+            if not torch.is_tensor(t):
+                return t
+            return t.cpu().numpy()
+
+        # Wrapper around numpy.histogramdd performing conversions between torch tensors and numpy arrays.
+        def reference_histogramdd(t, bins, bin_range, weights, density, dtype):
+            (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])
+
+            # numpy.histogramdd accepts only (N, D) shapes
+            D = np_t.shape[-1]
+            N = np.prod(np_t.shape[:-1])
+            reshaped_t = np.reshape(np_t, (N, D))
+            reshaped_wt = np.reshape(np_weights, (N,)) if np_weights is not None else None
+
+            # numpy.histogramdd throws an error for D=0
+            if D == 0:
+                return (torch.tensor(float('nan') if density else 0.), [])
+
+            # numpy.histogramdd expects range to be specified as a sequence of D (lower, upper) tuples
+            reshaped_range = None if not bin_range else [(bin_range[2 * i], bin_range[2 * i + 1]) for i in range(D)]
+
+            (np_hist, np_bin_edges) = np.histogramdd(reshaped_t, np_bins,
+                                                     range=reshaped_range, weights=reshaped_wt, density=density)
+
+            return (torch.from_numpy(np_hist).to(dtype), [torch.from_numpy(t).to(dtype) for t in np_bin_edges])
+
+        (actual_hist, actual_bin_edges) = torch.histogramdd(t, bins, range=bin_range, weight=weights, density=density)
+        (expected_hist, expected_bin_edges) = reference_histogramdd(t, bins, bin_range, weights, density, actual_hist.dtype)
+
+        D = len(actual_bin_edges)
+        self.assertEqual(D, len(expected_bin_edges))
+
+        """
+        Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
+        When bin edges are not explicitly defined, histogram uses the linspace operator internally
+        to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly
+        from numpy.linspace output.
+        Issue: https://github.com/pytorch/pytorch/issues/58758
+        """
+        if not torch.is_tensor(bins):
+            for dim in range(D):
+                self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim], atol=1e-5, rtol=1e-5)
+            # Calls numpy.histogramdd again, passing torch's actual_bin_edges as the bins argument
+            (expected_hist, expected_bin_edges) = reference_histogramdd(
+                t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)
+            self.assertEqual(D, len(expected_bin_edges))
+
+        self.assertEqual(actual_hist, expected_hist)
+        for dim in range(D):
+            self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim])
+
     @onlyCPU
-    @dtypes(torch.float32, torch.float64)
+    @dtypes(torch.float32)
+    def test_histogramdd(self, device, dtype):
+        shapes = (
+            (1, 5),
+            (3, 5),
+            (1, 5, 1),
+            (2, 3, 5),
+            (7, 7, 7, 7),
+            (16, 8, 4, 2),
+            (10, 10, 10),
+            (7, 0, 3),
+            (5, 0),)
+
+        for contig, bins_contig, weighted, density, shape in \
+                product([True, False], [True, False], [True, False], [True, False], shapes):
+            D = shape[-1]
+
+            values = make_tensor(shape, device, dtype, low=-9, high=9, noncontiguous=not contig)
+            weights = make_tensor(shape[:-1], device, dtype, low=0, high=9, noncontiguous=not contig) if weighted else None
+
+            # Tests passing a single bin count
+            bin_ct = random.randint(1, 5)
+            self._test_histogramdd_numpy(values, bin_ct, None, weights, density)
+
+            # Tests passing a bin count for each dimension
+            bin_ct = [random.randint(1, 5) for dim in range(D)]
+            self._test_histogramdd_numpy(values, bin_ct, None, weights, density)
+
+            # Tests with caller-specified histogram range
+            bin_range_tuples = [sorted((random.uniform(-9, 9), random.uniform(-9, 9))) for dim in range(D)]
+            bin_range = [elt for t in bin_range_tuples for elt in t]
+            self._test_histogramdd_numpy(values, bin_ct, bin_range, weights, density)
+
+            # Tests with range min=max
+            for dim in range(D):
+                bin_range[2 * dim + 1] = bin_range[2 * dim]
+            self._test_histogramdd_numpy(values, bin_ct, bin_range, weights, density)
+
+            # Tests with caller-specified bin edges
+            bin_edges = [make_tensor(ct + 1, device, dtype, low=-9, high=9).msort() for ct in bin_ct]
+            if not bins_contig:
+                # Necessary because msort always produces contiguous output
+                bin_edges_noncontig = [make_tensor(ct + 1, device, dtype, noncontiguous=not bins_contig) for ct in bin_ct]
+                for dim in range(D):
+                    bin_edges_noncontig[dim].copy_(bin_edges[dim])
+                bin_edges = bin_edges_noncontig
+            for dim in range(D):
+                self.assertEqual(bin_edges[dim].is_contiguous(), bins_contig)
+            self._test_histogramdd_numpy(values, bin_edges, None, weights, density)
+
+    @onlyCPU
+    @dtypes(torch.float32)
     def test_histogram_error_handling(self, device, dtype):
-        with self.assertRaisesRegex(RuntimeError, '\"histogram_cpu\" not implemented for'):
+        with self.assertRaisesRegex(RuntimeError, 'not implemented for'):
             values = make_tensor((), device, dtype=torch.int32)
             torch.histogram(values, 1)
 
         inconsistent_dtype = torch.float32 if dtype != torch.float32 else torch.float64
 
-        with self.assertRaisesRegex(RuntimeError, 'input tensor and bins tensor should have the same dtype'):
+        with self.assertRaisesRegex(RuntimeError, 'input tensor and bins tensors should have the same dtype'):
             values = make_tensor((), device, dtype=dtype)
             bins = make_tensor((), device, dtype=inconsistent_dtype)
             torch.histogram(values, bins)
@@ -2901,7 +3026,7 @@
             bin_edges = make_tensor((), device, dtype=inconsistent_dtype)
             torch.histogram(values, 1, out=(hist, bin_edges))
 
-        with self.assertRaisesRegex(RuntimeError, 'bins tensor should have dimension 1'):
+        with self.assertRaisesRegex(RuntimeError, 'bins tensor should have one dimension'):
             t = make_tensor((2, 2), device, dtype=dtype)
             torch.histogram(t, t)
 
@@ -2913,17 +3038,12 @@
             values = make_tensor((), device, dtype=dtype)
             torch.histogram(values, -1)
 
-        with self.assertRaisesRegex(RuntimeError, 'input tensor and weight tensor should have the same shape'):
+        with self.assertRaisesRegex(RuntimeError, 'if weight tensor is provided it should have the same shape \
+as the input tensor excluding its innermost dimension'):
             values = make_tensor((2, 2), device, dtype=dtype)
             weight = make_tensor((1), device, dtype=dtype)
             torch.histogram(values, 1, weight=weight)
 
-        with self.assertRaisesRegex(RuntimeError, 'hist tensor must be contiguous'):
-            values = make_tensor((), device, dtype=dtype)
-            hist = make_tensor((2), device, dtype=dtype, noncontiguous=True)
-            bin_edges = make_tensor((), device, dtype=dtype)
-            torch.histogram(values, 2, out=(hist, bin_edges))
-
         with self.assertRaisesRegex(TypeError, 'received an invalid combination of arguments'):
             values = make_tensor((), device, dtype=dtype)
             bin_edges = make_tensor((), device, dtype=dtype)
@@ -2933,7 +3053,7 @@
             values = make_tensor((), device, dtype=dtype)
             torch.histogram(values, 2, range=(1, 0))
 
-        with self.assertRaisesRegex(RuntimeError, r'range of \[nan, nan\] is not finite'):
+        with self.assertRaisesRegex(RuntimeError, r'range \[nan, nan\] is not finite'):
             values = torch.tensor([float("nan")], device=device, dtype=dtype)
             torch.histogram(values, 2)
 
diff --git a/torch/functional.py b/torch/functional.py
index 73e531d..fe127ae 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1,6 +1,8 @@
 from typing import (
     Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
 )
+from collections import namedtuple
+import itertools
 
 import torch
 import torch.nn.functional as F
@@ -26,6 +28,7 @@
     'cdist',
     'chain_matmul',
     'einsum',
+    'histogramdd',
     'istft',
     'lu',
     'norm',
@@ -326,6 +329,126 @@
 
     return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
 
+# Wrapper around _histogramdd and _histogramdd_bin_edges needed due to (Tensor, Tensor[]) return type.
+if TYPE_CHECKING:
+    # The JIT doesn't understand Union, so only add type annotation for mypy
+    def histogramdd(input: Tensor,
+                    bins: Union[List[Tensor], List[int], int],
+                    range: Optional[List[float]] = None,
+                    weight: Optional[Tensor] = None,
+                    density: bool = False):
+        pass
+else:
+    def histogramdd(input, bins, range=None, weight=None, density=False):
+        r"""
+        histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[])
+
+        Computes a multi-dimensional histogram of the values in a tensor.
+
+        Interprets the elements of an input tensor whose innermost dimension has size N
+        as a collection of N-dimensional points. Maps each of the points into a set of
+        N-dimensional bins and returns the number of points (or total weight) in each bin.
+
+        :attr:`input` must be a tensor with at least 2 dimensions.
+        If input has shape (M, N), each of its M rows defines a point in N-dimensional space.
+        If input has three or more dimensions, all but the last dimension are flattened.
+
+        Each dimension is independently associated with its own strictly increasing sequence
+        of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D
+        tensors. Alternatively, bin edges may be constructed automatically by passing a
+        sequence of integers specifying the number of equal-width bins in each dimension.
+
+        For each N-dimensional point in input:
+            - Each of its coordinates is binned independently among the bin edges
+              corresponding to its dimension
+            - Binning results are combined to identify the N-dimensional bin (if any)
+              into which the point falls
+            - If the point falls into a bin, the bin's count (or total weight) is incremented
+            - Points which do not fall into any bin do not contribute to the output
+
+        :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int.
+
+        If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences
+        of bin edges. Each 1D tensor should contain a strictly increasing sequence with at
+        least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying
+        the left and right edges of all bins. Every bin is inclusive of its left edge and
+        exclusive of its right edge; only the rightmost bin is also inclusive of its right edge.
+
+        If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins
+        in each dimension. By default, the leftmost and rightmost bin edges in each dimension
+        are determined by the minimum and maximum elements of the input tensor in the
+        corresponding dimension. The :attr:`range` argument can be provided to manually
+        specify the leftmost and rightmost bin edges in each dimension.
+
+        If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions.
+
+        .. note::
+            See also :func:`torch.histogram`, which specifically computes 1D histograms.
+            While :func:`torch.histogramdd` infers the dimensionality of its bins and
+            binned values from the shape of :attr:`input`, :func:`torch.histogram`
+            accepts and flattens :attr:`input` of any shape.
+
+        Args:
+            {input}
+            bins: Tensor[], int[], or int.
+                  If Tensor[], defines the sequences of bin edges.
+                  If int[], defines the number of equal-width bins in each dimension.
+                  If int, defines the number of equal-width bins for all dimensions.
+        Keyword args:
+            range (sequence of float): Defines the leftmost and rightmost bin edges
+                                       in each dimension.
+            weight (Tensor): By default, each value in the input has weight 1. If a weight
+                             tensor is passed, each N-dimensional coordinate in input
+                             contributes its associated weight towards its bin's result.
+                             The weight tensor should have the same shape as the :attr:`input`
+                             tensor excluding its innermost dimension N.
+            density (bool): If False (default), the result will contain the count (or total weight)
+                            in each bin. If True, each count (weight) is divided by the total count
+                            (total weight), then divided by the volume of its associated bin.
+        Returns:
+            hist (Tensor): N-dimensional Tensor containing the values of the histogram.
+            bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges.
+
+        Example::
+            >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3],
+            ...                   weight=torch.tensor([1., 2., 4., 8.]))
+                histogramdd_return_type(hist=tensor([[0., 1., 0.],
+                                                     [2., 0., 0.],
+                                                     [4., 0., 8.]]),
+                                        bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]),
+                                                   tensor([0.0000, 0.6667, 1.3333, 2.0000])))
+
+            >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2],
+            ...                   range=[0., 1., 0., 1.], density=True)
+                histogramdd_return_type(hist=tensor([[2., 0.],
+                                                     [0., 2.]]),
+                                        bin_edges=(tensor([0.0000, 0.5000, 1.0000]),
+                                                   tensor([0.0000, 0.5000, 1.0000])))
+
+        """
+        if isinstance(bins, int):
+            # If a single int is passed, repeat it for all dimensions
+            bins = list(itertools.repeat(bins, input.size()[-1]))
+
+        if bins and isinstance(bins[0], int):
+            """
+            If bins is int[], the histogram kernel runs faster knowing that the bin edges form
+            a linear progression (see comments in aten/src/ATen/native/cpu/HistogramKernel.cpp).
+            However, we end up constructing the bin edge tensors twice because
+            _histogramdd_from_bin_cts cannot pass back (Tensor, Tensor[]).
+            """
+            bin_edges = _VF._histogramdd_bin_edges(input, bins, range=range, weight=weight, density=density)
+            hist = _VF._histogramdd_from_bin_cts(input, bins, range=range, weight=weight, density=density)
+        else:
+            """
+            If bins is Tensor[], the provided bin edge tensors are returned as-is.
+            """
+            bin_edges = bins
+            hist = _VF._histogramdd_from_bin_tensors(input, bin_edges, weight=weight, density=density)
+
+        # TODO: figure out how to return torch.return_types.histogramdd
+        histogramdd_return_type = namedtuple('histogramdd_return_type', 'hist bin_edges')
+        return histogramdd_return_type(hist, bin_edges)
 
 # This wrapper exists to support variadic args.
 if TYPE_CHECKING:
diff --git a/torch/overrides.py b/torch/overrides.py
index 4133ca6..28ff321 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -531,6 +531,7 @@
         torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
         torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
         torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
+        torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
         torch.linalg.householder_product: lambda input, tau: -1,
         torch.hspmm: lambda mat1, mat2, out=None: -1,
         torch.hsplit: lambda input, indices_or_sections: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index dafb903..ed8665c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2571,6 +2571,27 @@
 
     return sample_inputs
 
+def sample_inputs_histogramdd(op_info, device, dtype, requires_grad):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    sizes = ((S, S), (S, S, S), (S, 1, S), (S, 0, S))
+    bin_ct_patterns = ((1, 1, 1, 1, 1), (2, 3, 2, 3, 2), (3, 2, 3, 2, 3))
+
+    sample_inputs = []
+    for size, bin_ct_pattern, weighted, density in product(sizes, bin_ct_patterns, [False, True], [False, True]):
+        input_tensor = make_arg(size)
+        bin_ct = bin_ct_pattern[:size[-1]]
+        weight_tensor = make_arg(size[:-1]) if weighted else None
+
+        sample_inputs.append(SampleInput(input_tensor, args=(bin_ct,),
+                                         kwargs=dict(weight=weight_tensor, density=density)))
+
+        bins_tensor = [make_arg(ct + 1) for ct in bin_ct]
+        sample_inputs.append(SampleInput(input_tensor, args=(bins_tensor,),
+                                         kwargs=dict(weight=weight_tensor, density=density)))
+
+    return sample_inputs
+
 def sample_inputs_bincount(op_info, device, dtype, requires_grad):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
 
@@ -10003,6 +10024,16 @@
                #                                          ~~~~~~ <--- HERE
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
            )),
+    OpInfo('histogramdd',
+           dtypes=_dispatch_dtypes(),  # histogramdd is only implemented on CPU
+           dtypesIfCPU=floating_types(),
+           sample_inputs_func=sample_inputs_histogramdd,
+           supports_autograd=False,
+           skips=(
+               # JIT tests don't work with Tensor keyword arguments
+               # https://github.com/pytorch/pytorch/issues/58507
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+           )),
     OpInfo('bincount',
            dtypes=integral_types_and(),
            sample_inputs_func=sample_inputs_bincount,