blob: 114ca93dae26192ca617548642063f908454207b [file] [log] [blame]
#include <ATen/native/Distance.h>
#include <numeric>
#include <iterator>
#include <algorithm>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vml.h>
namespace at { namespace native { namespace {
template<typename scalar_t>
struct Dist {
using Vec = vec256::Vec256<scalar_t>;
// Depending on the value of the pnorm, there are specific implementations
// that are much faster than std::pow(std::abs(a - b), p), but have the same
// standard loop code for how to process the input vector. To reuse the main
// outside loop while still guaranteeing that the compiler inlines every
// different function on p, we break the inner norm logic into structs with
// static functions that represent what's done differently, and template the
// outer loop on those structs.
//
// The four functions are:
// map : This tells how to modify (a - b) to form the component that
// gets summed.
// red : This tells how to sum the result of map up. This is
// separate because the inf norm actually uses max instead of
// sum.
// finish : This tells what to do with the aggregated value to compute
// the norm. Generally this is the result of val ^ (1 / p).
// backward : This is the gradient for that norm. Arguments are pretty
// self explanitory.
//
// There are a few cases where these aren't used. The 0 norm has no backward,
// because it's always 0, so that's shortcircuited earlier. There's a special
// implementation of the general backward pass when p is less than two, so
// there's a struct with only a backward pass for this case.
// TODO This is an inefficient way to compite sign, and can be much faster
// using native SSE instructions that should be added to Vec256.
static inline Vec sign(Vec val) {
return vec256::minimum(vec256::maximum(Vec(0), val.ceil()), Vec(1)) +
vec256::minimum(vec256::maximum(Vec(-1), val.floor()), Vec(0));
}
static inline Vec abs(Vec val) {
return val.abs();
}
static inline scalar_t abs(scalar_t val) {
return std::abs(val);
}
static inline Vec ceil(Vec val) {
return val.ceil();
}
static inline scalar_t ceil(scalar_t val) {
return std::ceil(val);
}
static inline Vec min(Vec val, scalar_t other) {
return vec256::minimum(val, Vec(other));
}
static inline scalar_t min(scalar_t val, scalar_t other) {
return std::min(val, other);
}
static inline Vec max(Vec val, Vec other) {
return vec256::maximum(val, other);
}
static inline scalar_t max(scalar_t val, scalar_t other) {
return std::max(val, other);
}
static inline Vec pow(Vec val, Vec p) {
return val.pow(p);
}
static inline scalar_t pow(scalar_t val, scalar_t p) {
return std::pow(val, p);
}
// Zero norm
template<typename data_t>
struct zdist_calc {
static inline data_t map(const data_t& diff, const data_t& p) { return min(ceil(abs(diff)), 1); }
static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; }
};
// One norm
template<typename data_t>
struct odist_calc {
static inline data_t map(const data_t& diff, const data_t& p) { return diff; }
static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; }
static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return Vec(grad) * sign(diff); }
};
// Special general pnorm derivative if p is less than two
struct lttdist_calc {
static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)); }
};
// Two norm
template<typename data_t>
struct tdist_calc {
// TODO This can probably use fused add multiply to get better perf
static inline data_t map(const data_t& diff, const data_t& p) { return diff * diff; }
static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return std::sqrt(agg); }
static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : Vec(grad) * diff / Vec(dist); }
};
// General p norm
template<typename data_t>
struct pdist_calc {
static inline data_t map(const data_t& diff, const data_t& p) { return pow(diff, p); }
static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return std::pow(agg, 1.0 / p); }
static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : diff * diff.abs().pow(p - Vec(2)) * Vec(grad) / Vec(dist).pow(p - Vec(1)); }
};
// Inf norm
template<typename data_t>
struct idist_calc {
static inline data_t map(const data_t& diff, const data_t& p) { return diff; }
static inline data_t red(const data_t& agg, const data_t& up) { return max(agg, up); }
static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; }
// TODO This backward pass uses a very complext expression to compute (diff
// == dist) that could be much faster if using SSE instructions.
static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return Vec(grad) * sign(diff) * (Vec(1) - vec256::minimum(Vec(1), (diff.abs() - Vec(dist)).abs().ceil())); }
};
template <typename F>
static void run_parallel_pdist(Tensor& result, const Tensor& self, const scalar_t p) {
const scalar_t * const self_start = self.data_ptr<scalar_t>();
const scalar_t * const self_end = self_start + self.numel();
int64_t n = self.size(0);
int64_t m = self.size(1);
scalar_t * const res_start = result.data_ptr<scalar_t>();
int64_t combs = result.numel(); // n * (n - 1) / 2
// We conceptually iterate over tuples of (i, j, k) where i is the first
// vector from the input, j is the second, and k is the result index. This
// parallelizes over the range of k and infers what i and j are from the
// value of k.
parallel_for(0, combs, internal::GRAIN_SIZE / (16 * m), [p, self_start, self_end, n, m, res_start](int64_t k, int64_t end) {
const Vec pvec(p);
double n2 = n - .5;
// The -1 accounts for floating point truncation issues
int64_t i = static_cast<int64_t>((n2 - std::sqrt(n2 * n2 - 2 * k - 1)));
int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
const scalar_t * self_i = self_start + i * m;
const scalar_t * self_j = self_start + j * m;
scalar_t * res = res_start + k;
const scalar_t * const res_end = res_start + end;
while (res != res_end) {
*res = F::finish(vec256::map2_reduce_all<scalar_t>(
[&pvec](Vec a, Vec b) { return F::map((a - b).abs(), pvec); },
F::red, self_i, self_j, m), p);
res += 1;
self_j += m;
if (self_j == self_end) {
self_i += m;
self_j = self_i + m;
}
}
});
}
// Assumes self is nonempty, contiguous, and 2D
static void apply_pdist(Tensor& result, const Tensor& self, const scalar_t p) {
if (p == 0.0) {
run_parallel_pdist<zdist_calc<Vec>>(result, self, p);
} else if (p == 1.0) {
run_parallel_pdist<odist_calc<Vec>>(result, self, p);
} else if (p == 2.0) {
run_parallel_pdist<tdist_calc<Vec>>(result, self, p);
} else if (std::isinf(p)) {
run_parallel_pdist<idist_calc<Vec>>(result, self, p);
} else {
run_parallel_pdist<pdist_calc<Vec>>(result, self, p);
}
}
template <typename F>
static void run_parallel_cdist(Tensor& result, const Tensor& t1, const Tensor& t2, const scalar_t p) {
const scalar_t * const t1_start = t1.data_ptr<scalar_t>();
const scalar_t * const t2_start = t2.data_ptr<scalar_t>();
int64_t d = t1.size(0);
int64_t r1 = t1.size(-2);
int64_t r2 = t2.size(-2);
int64_t m = t1.size(-1);
scalar_t * const res_start = result.data_ptr<scalar_t>();
int64_t combs = r1 * r2;
int64_t size1 = r1 * m;
int64_t size2 = r2 * m;
parallel_for(0, combs * d, internal::GRAIN_SIZE / (16 * m), [=](int64_t start, int64_t end) {
scalar_t * res = res_start + start;
const scalar_t * const res_end = res_start + end;
int64_t l = start / combs;
int64_t k = start % combs;
int64_t i = k / r2;
int64_t j = k % r2;
i = i * m;
j = j * m;
while (res != res_end) {
const scalar_t * self_i = t1_start + size1 * l + i;
const scalar_t * self_j = t2_start + size2 * l + j;
scalar_t agg = 0;
for (int x = 0; x < m; x++) {
scalar_t a = *(self_i + x);
scalar_t b = *(self_j + x);
agg = F::red(agg, F::map(std::abs(a-b), p));
}
*res = F::finish(agg, p);
res += 1;
j += m;
if (j == size2) {
j = 0;
i += m;
if (i == size1) {
i = 0;
l += 1;
}
}
}
});
}
static void apply_cdist(Tensor& result, const Tensor& x1, const Tensor& x2, const scalar_t p) {
if (p == 0.0) {
run_parallel_cdist<zdist_calc<scalar_t>>(result, x1, x2, p);
} else if (p == 1.0) {
run_parallel_cdist<odist_calc<scalar_t>>(result, x1, x2, p);
} else if (p == 2.0) {
run_parallel_cdist<tdist_calc<scalar_t>>(result, x1, x2, p);
} else if (std::isinf(p)) {
run_parallel_cdist<idist_calc<scalar_t>>(result, x1, x2, p);
} else {
run_parallel_cdist<pdist_calc<scalar_t>>(result, x1, x2, p);
}
}
// This does a backward pass down a Vec column of the input
template <typename F>
inline static void backward_down_column_pdist(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {
const Vec self_vec_i = Vec::loadu(self_i, count);
Vec res_vec_i = Vec::loadu(res_i, count);
const scalar_t * self_j = self_i + m;
scalar_t * res_j = res_i + m;
for (; self_j != self_end; self_j += m, res_j += m, grad_k += gs, dist_k += 1) {
const Vec self_vec_j = Vec::loadu(self_j, count);
Vec res_vec_j = Vec::loadu(res_j, count);
Vec res = F::backward(self_vec_i - self_vec_j, *grad_k, *dist_k, pvec);
res_vec_i = res_vec_i + res;
res_vec_j = res_vec_j - res;
res_vec_j.store(res_j, count);
}
res_vec_i.store(res_i, count);
}
}
template <typename F>
static void run_backward_parallel_pdist(Tensor& result, const Tensor & grad, const Tensor & self, const scalar_t p, const Tensor& dist) {
const int64_t n = self.size(0);
const int64_t m = self.size(1);
const int64_t gs = grad.stride(0);
const scalar_t * const grad_start = grad.data_ptr<scalar_t>();
const scalar_t * const dist_start = dist.data_ptr<scalar_t>();
const scalar_t * const self_start = self.data_ptr<scalar_t>();
scalar_t * const res_start = result.data_ptr<scalar_t>();
// The only way to parallelize and avoid locking requires parallelizing
// over the columns of the input, i.e. we compute the gradient for the
// first section of each vector independentaly of the second section, etc.
at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (8 * n * n), [p, n, m, gs, grad_start, dist_start, self_start, res_start](int64_t l, int64_t end) {
const Vec pvec(p);
const scalar_t * self_l = self_start + l * Vec::size();
scalar_t * res_l = res_start + l * Vec::size();
for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; self_l += Vec::size(), res_l += Vec::size()) {
backward_down_column_pdist<F>(self_l, res_l, grad_start, dist_start, pvec, n, m, gs);
}
});
const int64_t remainder = m % Vec::size();
if (remainder) {
backward_down_column_pdist<F>(self_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, Vec(p), n, m, gs, remainder);
}
}
// Assumes self is nonempty, contiguous, and 2D and dist is also contiguous
static void apply_backward_pdist(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
result.fill_(0);
if (p == 0.0) {
} else if (p == 1.0) {
run_backward_parallel_pdist<odist_calc<Vec>>(result, grad, self, p, dist);
} else if (p < 2.0) {
run_backward_parallel_pdist<lttdist_calc>(result, grad, self, p, dist);
} else if (p == 2.0) {
run_backward_parallel_pdist<tdist_calc<Vec>>(result, grad, self, p, dist);
} else if (std::isinf(p)) {
run_backward_parallel_pdist<idist_calc<Vec>>(result, grad, self, p, dist);
} else {
run_backward_parallel_pdist<pdist_calc<Vec>>(result, grad, self, p, dist);
}
}
static void apply_backward_cdist(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
result.fill_(0);
if (p == 0.0) {
} else if (p == 1.0) {
run_backward_parallel_cdist<odist_calc<Vec>>(result, grad, x1, x2, p, dist);
} else if (p < 2.0) {
run_backward_parallel_cdist<lttdist_calc>(result, grad, x1, x2, p, dist);
} else if (p == 2.0) {
run_backward_parallel_cdist<tdist_calc<Vec>>(result, grad, x1, x2, p, dist);
} else if (std::isinf(p)) {
run_backward_parallel_cdist<idist_calc<Vec>>(result, grad, x1, x2, p, dist);
} else {
run_backward_parallel_cdist<pdist_calc<Vec>>(result, grad, x1, x2, p, dist);
}
}
template <typename F>
static void run_backward_parallel_cdist(Tensor& result, const Tensor & grad, const Tensor & t1, const Tensor & t2, const scalar_t p, const Tensor& dist) {
const int64_t r1 = t1.size(-2);
const int64_t r2 = t2.size(-2);
const int64_t m = t1.size(-1);
const int64_t d = result.size(0);
const int64_t l1_size = r1 * m;
const int64_t l2_size = r2 * m;
//current implementation supports only tensor that can be collapsed to 1D. However, to avoid checking if grad satisfies this assumption,
//we call .contiguous() on grad before backward, thus stride is guaranteed to be 1
//don't use grad.stride(-1), because if last dimension is 1, stride can be bogus.
const int64_t gs = 1;
const scalar_t * const grad_start = grad.data_ptr<scalar_t>();
const scalar_t * const dist_start = dist.data_ptr<scalar_t>();
const scalar_t * const t1_start = t1.data_ptr<scalar_t>();
const scalar_t * const t2_start = t2.data_ptr<scalar_t>();
scalar_t * const res_start = result.data_ptr<scalar_t>();
at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (16 * r1), [=](int64_t l, int64_t end) {
const Vec pvec(p);
const scalar_t * i = t1_start + l * Vec::size();
const scalar_t * j = t2_start + l * Vec::size();
scalar_t * res_l = res_start + l * Vec::size();
for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; i += Vec::size(), j += Vec::size(), res_l += Vec::size()) {
backward_down_column_cdist<F>(i, j, res_l, grad_start, dist_start, pvec, r1, r2, m, d, gs, l1_size, l2_size);
}
});
const int64_t remainder = m % Vec::size();
if (remainder) {
backward_down_column_cdist<F>(t1_start + (m - remainder), t2_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, Vec(p), r1, r2, m, d, gs, l1_size, l2_size, remainder);
}
}
template <typename F>
inline static void backward_down_column_cdist(const scalar_t * t1, const scalar_t * t2, scalar_t * res, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t r1, int64_t r2, int64_t m, int64_t d, int64_t gs, int64_t l1_size, int64_t l2_size, int64_t count = Vec::size()) {
const scalar_t * t1_end = t1 + l1_size;
const scalar_t * t2_end = t2 + l2_size;
for (int64_t l = 0; l < d; l++) {
for (; t1 != t1_end; t1 += m, res += m) {
const Vec vec_t1 = Vec::loadu(t1, count);
Vec res_vec = Vec::loadu(res, count);
for (const scalar_t * t2_curr = t2; t2_curr != t2_end; t2_curr += m, grad_k += gs, dist_k += 1) {
const Vec vec_t2 = Vec::loadu(t2_curr, count);
Vec res = F::backward(vec_t1 - vec_t2, *grad_k, *dist_k, pvec);
res_vec = res_vec + res;
}
res_vec.store(res, count);
}
t1_end += l1_size;
t2_end += l2_size;
t2 += l2_size;
}
}
};
void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist", [&] {
Dist<scalar_t>::apply_pdist(result, self, p);
});
}
static void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_backward", [&] {
Dist<scalar_t>::apply_backward_pdist(result, grad, self, p, dist);
});
}
static void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const double p) {
AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist", [&] {
Dist<scalar_t>::apply_cdist(result, x1, x2, p);
});
}
static void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_backward", [&] {
Dist<scalar_t>::apply_backward_cdist(result, grad, x1, x2, p, dist);
});
}
} // anonymous namespace
REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
REGISTER_DISPATCH(cdist_backward_stub, &cdist_backward_kernel_impl);
}} // namespace at::native