#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Histogram.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/sum.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like_ops.h>
#endif

#include <algorithm>
#include <numeric>
#include <functional>
namespace at::native {

namespace {

constexpr int64_t HISTOGRAM_GRAIN_SIZE = 200;
/* The main algorithm. Expects that the input tensor has shape (N, D).
 * Expects that bin_edges contains D one-dimensional tensors, each specifying
 * an increasing sequence of bin edges.
 *
 * Interprets the input as N different D-dimensional coordinates and maps them
 * into the D-dimensional bins defined by bin_edges, accumulating a D-dimensional
 * histogram in the hist tensor.
 *
 * Accepts a template argument of type BIN_SELECTION_ALGORITHM specifying how
 * the scalars in each dimension should be mapped into the dimension's bins:
 *
 *     - LINEAR_INTERPOLATION: each bin edge sequence must form a linear progression.
 *       Scalars are mapped to bins by computing
 *           (element - leftmost_edge)/(rightmost_edge - leftmost_edge) * bin_ct
 *       and truncating the result to an integer.
 *
 *       This is the fastest option, but its results may not be perfectly consistent
 *       with the boundaries specified in bin_edges due to precision issues.
 *
 *       Used by torch.histc, which doesn't need consistency with bin_edges as it does not
 *       return bin_edges. Additionally, this implementation is identical to the legacy histc
 *       implementation, which was replaced when histogram was implemented.
 *
 *     - LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH: Also expects that each bin edge sequence
 *       forms a linear progression. For each scalar, if 'pos' is the bin selected by the
 *       LINEAR_INTERPOLATION approach, this approach inspects the boundaries in bin_edges to
 *       place the scalar into pos - 1, pos, or pos + 1. The "local search" over neighboring
 *       bins allows for correction of misclassifications due to precision issues (a scalar
 *       very close to a bin_edge may be misclassified by LINEAR_INTERPOLATION).
 *
 *       Should produce the same output as the general case BINARY_SEARCH, but run about
 *       3x faster asymptotically.
 *
 *       Used by torch.histogram for cases in which bin_edges is constructed using
 *       torch.linspace. The behavior of LINEAR_INTERPOLATION may not perfectly align
 *       with linspace bin_edges due to precision issues. torch.histogram returns both
 *       the hist and bin_edges tensors as output, so the "local search" is needed to
 *       keep its output internally consistent.
 *
 *     - BINARY_SEARCH: Handles torch.histogram's general case by searching over the
 *       elements of bin_edges. Implemented using std::upper_bound.
 *
 * See discussion at https://github.com/pytorch/pytorch/pull/58780#discussion_r648604866
 * for further details on relative performance of the bin selection algorithms.
 */
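/* Illustrative example (edge values assumed for clarity): with bin_edges = [0., 1., 2., 3., 4.]
 * (bin_ct = 4), LINEAR_INTERPOLATION maps the element 2.5 to
 * trunc((2.5 - 0.)/(4. - 0.) * 4) = bin 2. An element lying within a floating-point
 * rounding error of an interior edge may land one bin away from where the edges place it;
 * the local search and binary search variants consult the actual edges to avoid this.
 */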
enum BIN_SELECTION_ALGORITHM {
    LINEAR_INTERPOLATION,
    LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH,
    BINARY_SEARCH,
};
template<typename input_t, BIN_SELECTION_ALGORITHM algorithm>
void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges,
        const Tensor& input, const c10::optional<Tensor>& weight) {
    TORCH_INTERNAL_ASSERT(input.dim() == 2);

    const int64_t N = input.size(0);
    if (weight.has_value()) {
        TORCH_INTERNAL_ASSERT(weight.value().dim() == 1 && weight.value().numel() == N);
    }

    const int64_t D = input.size(1);
    TORCH_INTERNAL_ASSERT(int64_t(bin_edges.size()) == D);
    for (const auto dim : c10::irange(D)) {
        TORCH_INTERNAL_ASSERT(bin_edges[dim].is_contiguous());
        TORCH_INTERNAL_ASSERT(hist.size(dim) + 1 == bin_edges[dim].numel());
    }

    if (D == 0) {
        // hist is an empty tensor in this case; nothing to do here
        return;
    }

    TensorAccessor<input_t, 2> accessor_in = input.accessor<input_t, 2>();
    /* Constructs a c10::optional<TensorAccessor> containing an accessor iff
     * the optional weight tensor has a value.
     */
    const auto accessor_wt = weight.has_value()
            ? c10::optional<TensorAccessor<input_t, 1>>(weight.value().accessor<input_t, 1>())
            : c10::optional<TensorAccessor<input_t, 1>>();
    std::vector<input_t*> bin_seq(D);
    std::vector<int64_t> num_bin_edges(D);
    std::vector<input_t> leftmost_edge(D), rightmost_edge(D);

    for (const auto dim : c10::irange(D)) {
        bin_seq[dim] = bin_edges[dim].data_ptr<input_t>();
        num_bin_edges[dim] = bin_edges[dim].numel();
        leftmost_edge[dim] = bin_seq[dim][0];
        rightmost_edge[dim] = bin_seq[dim][num_bin_edges[dim] - 1];
    }

    int64_t GRAIN_SIZE = std::max(int64_t(1), HISTOGRAM_GRAIN_SIZE / D);

    /* Parallelizes processing of input using at::parallel_for.
     * Each thread accumulates a local result into its own slice of
     * thread_histograms; the slices are summed together at the end.
     */
    const auto num_threads = at::get_num_threads();
    const auto hist_sizes = hist.sizes();
    DimVector thread_hist_sizes(hist_sizes.size() + 1);
    thread_hist_sizes[0] = num_threads;
    std::copy(hist_sizes.begin(), hist_sizes.end(),
              thread_hist_sizes.begin() + 1);
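    /* Illustrative example (shapes assumed for clarity): with 8 threads and a
     * hist of shape (3, 4), thread_histograms has shape (8, 3, 4); slice [tid]
     * is thread tid's private histogram, so no synchronization is needed.
     */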
    Tensor thread_histograms = at::zeros(thread_hist_sizes, hist.dtype());
    TORCH_INTERNAL_ASSERT(thread_histograms.is_contiguous());
    at::parallel_for(0, N, GRAIN_SIZE, [&](int64_t start, int64_t end) {
        const auto tid = at::get_thread_num();
        auto hist_strides = thread_histograms.strides();
        input_t *hist_local_data = thread_histograms.data_ptr<input_t>();

        // View only this thread's local results
        hist_local_data += hist_strides[0] * tid;
        hist_strides = hist_strides.slice(1);

        for (const auto i : c10::irange(start, end)) {
            bool skip_elt = false;
            int64_t hist_index = 0;

            for (const auto dim : c10::irange(D)) {
                const input_t elt = accessor_in[i][dim];

                // Skips elements which fall outside the specified bins and NaN elements
                if (!(elt >= leftmost_edge[dim] && elt <= rightmost_edge[dim])) {
                    skip_elt = true;
                    break;
                }

                int64_t pos = -1;

                if (algorithm == BINARY_SEARCH) {
                    // Handles the general case via binary search on the bin edges.
                    pos = std::upper_bound(bin_seq[dim], bin_seq[dim] + num_bin_edges[dim], elt)
                            - bin_seq[dim] - 1;
                } else if (algorithm == LINEAR_INTERPOLATION
                        || algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
                    /* When bin_edges is known to be a linear progression, maps elt to
                     * the appropriate bin via simple division.
                     */
                    pos = static_cast<int64_t>((elt - leftmost_edge[dim])
                            * (num_bin_edges[dim] - 1)
                            / (rightmost_edge[dim] - leftmost_edge[dim]));
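                    // e.g. (illustrative values) with edges [0., 1., 2., 3., 4.], an element
                    // 2.5 yields (2.5 - 0.) * 4 / 4. = 2.5, truncated to pos 2.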
                    /* Ensures consistency with bin_edges by checking the bins to the left and right
                     * of the selected position. Necessary for cases in which an element very close
                     * to a bin edge may be misclassified by simple division.
                     */
                    if (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
                        int64_t pos_min = std::max(static_cast<int64_t>(0), pos - 1);
                        int64_t pos_max = std::min(pos + 2, num_bin_edges[dim]);
                        pos = std::upper_bound(bin_seq[dim] + pos_min, bin_seq[dim] + pos_max, elt)
                                - bin_seq[dim] - 1;
                    }
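                    // e.g. (illustrative values) if pos == 2 with 5 bin edges, the window
                    // [pos_min, pos_max) == [1, 4) lets the element settle into bin 1, 2, or 3.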
                } else {
                    TORCH_INTERNAL_ASSERT(false);
                }

                // Unlike other bins, the rightmost bin includes its right boundary
                if (pos == (num_bin_edges[dim] - 1)) {
                    pos -= 1;
                }
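                // e.g. (illustrative values) with edges [0., 1., 2., 3., 4.], an element
                // equal to 4. selects pos 4 above and is folded into the last bin, pos 3.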
                hist_index += hist_strides[dim] * pos;
            }

            if (!skip_elt) {
                // In the unweighted case, the default weight is 1
                input_t wt = accessor_wt.has_value() ? accessor_wt.value()[i] : static_cast<input_t>(1);
                hist_local_data[hist_index] += wt;
            }
        }
    });

    at::sum_out(hist, thread_histograms, /*dim=*/{0});
}
/* Some pre- and post-processing steps for the main algorithm.
 * Initializes hist to 0, calls into the main algorithm, and normalizes output if necessary.
 */
template<BIN_SELECTION_ALGORITHM bin_algorithm>
void histogramdd_out_cpu_template(const Tensor& self, const c10::optional<Tensor>& weight, bool density,
        Tensor& hist, const TensorList& bin_edges) {
    hist.fill_(0);

    const int64_t N = self.size(-1);
    const int64_t M = std::accumulate(self.sizes().begin(), self.sizes().end() - 1,
            (int64_t)1, std::multiplies<int64_t>());

    const Tensor reshaped_input = self.reshape({M, N});

    const auto reshaped_weight = weight.has_value()
            ? c10::optional<Tensor>(weight.value().reshape({M}))
            : c10::optional<Tensor>();
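    /* Illustrative example (shapes assumed for clarity): a (5, 6, 3) input is
     * treated as M = 30 coordinates in N = 3 dimensions, so it is reshaped to
     * (30, 3) and any weight tensor is flattened to (30,).
     */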
    std::vector<Tensor> bin_edges_contig(bin_edges.size());
    for (const auto dim : c10::irange(bin_edges_contig.size())) {
        bin_edges_contig[dim] = bin_edges[dim].contiguous();
    }

    AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, self.scalar_type(), "histogram_cpu", [&]() {
        histogramdd_cpu_contiguous<scalar_t, bin_algorithm>(
            hist, bin_edges_contig, reshaped_input, reshaped_weight);
    });
    /* Divides each bin's value by the total count/weight in all bins,
     * and by the bin's volume.
     */
    if (density) {
        const auto hist_sum = hist.sum().item();
        hist.div_(hist_sum);

        /* For each dimension, divides each bin's value
         * by the bin's length in that dimension.
         */
        for (const auto dim : c10::irange(N)) {
            const auto bin_lengths = bin_edges[dim].diff();

            // Used to reshape bin_lengths to align with the corresponding dimension of hist.
            std::vector<int64_t> shape(N, 1);
            shape[dim] = bin_lengths.numel();

            hist.div_(bin_lengths.reshape(shape));
        }
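        /* Illustrative 1-D example (values assumed): counts [1., 3.] with edges
         * [0., 1., 3.] give hist_sum 4. and bin lengths [1., 2.], so the final
         * density values are [0.25, 0.375], which integrate to 1 over the bins.
         */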
    }
}
/* The general implementation of the histogram kernel. Maps each element of the input tensor
 * to its corresponding bin by performing a binary search over the elements of bin_edges.
 *
 * Refer to histogramdd_out_cpu_template for more details.
 */
static void histogramdd_kernel_impl(const Tensor& self, const c10::optional<Tensor>& weight, bool density,
        Tensor& hist, const TensorList& bin_edges) {
    histogramdd_out_cpu_template<BINARY_SEARCH>(self, weight, density, hist, bin_edges);
}
/* A faster version of the histogram kernel for cases in which bin_edges are known
 * to form a linear progression.
 *
 * Refer to histogramdd_out_cpu_template for more details.
 */
static void histogramdd_linear_kernel_impl(const Tensor& self, const c10::optional<Tensor>& weight,
        bool density, Tensor& hist, const TensorList& bin_edges, bool local_search) {
    if (local_search) {
        // histogramdd codepath: both hist and bin_edges are eventually returned as output,
        // so we'll keep them consistent
        histogramdd_out_cpu_template<LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH>(
            self, weight, density, hist, bin_edges);
    } else {
        // histc codepath: bin_edges are not returned to the caller
        histogramdd_out_cpu_template<LINEAR_INTERPOLATION>(
            self, weight, density, hist, bin_edges);
    }
}
} // namespace
REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel_impl);
REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel_impl);
} // namespace at::native