| #include <ATen/ATen.h> |
| #include <ATen/cuda/Exceptions.h> |
| #include <THC/THCTensorMathReduce.cuh> |
| #include <math.h> |
| |
| #include <ATen/native/Distance.h> |
| |
| |
| namespace at { namespace native { |
| |
| namespace { |
| |
// Threads per block for the forward kernels: each block reduces one output distance.
static const int forward_threads = 256;
| |
| #ifdef __HIP_PLATFORM_HCC__ |
| static const int WARP_SIZE = 64; |
| #else |
| static const int WARP_SIZE = 32; |
| #endif |
| |
// Precision-matched device square root: sqrtf for float, sqrt for double.
template <typename scalar_t>
static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);
| |
| template <> |
| __forceinline__ __device__ float device_sqrt(float val) { |
| return ::sqrtf(val); |
| } |
| |
| template <> |
| __forceinline__ __device__ double device_sqrt(double val) { |
| return ::sqrt(val); |
| } |
| |
// Per-norm functor bundles. Each provides some of: inc (accumulate one |x_i - y_i| term into the
// running aggregate), finish (turn the final aggregate into the distance), agg (combine two partial
// aggregates during the block reduction), and backward (gradient of the distance with respect to a
// single signed coordinate difference, scaled by the incoming gradient).
template <typename scalar_t>
struct dists {
| |
| static __forceinline__ __device__ scalar_t sign(scalar_t val) { |
| return (0 < val) - (val < 0); |
| } |
| |
// Zero norm: counts the non-zero coordinate differences. No backward is defined; the host wrappers
// fill the gradient with zeros when p == 0.
| struct zero { |
| static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t p) { agg += diff != 0.0; } |
| static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; } |
| static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; } |
| }; |
| |
| // One norm |
| struct one { |
| static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t p) { agg += diff; } |
| static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; } |
| static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; } |
| static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return grad * sign(diff); } |
| }; |
| |
// Special-case backward for p < 2 (p == 0 and p == 1 are dispatched separately); uses
// |diff|^(p - 1) / dist^(p - 1) rather than the general |diff|^(p - 2) formula.
| struct lt_two { |
| static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return dist == 0.0 ? 0 : sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1); } |
| }; |
| |
| // Two norm |
| struct two { |
| static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t p) { agg += diff * diff; } |
| static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t p) { return device_sqrt<scalar_t>(agg); } |
| static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; } |
| static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return dist == 0.0 ? 0 : grad * diff / dist; } |
| }; |
| |
| // General p norm |
| struct p { |
| static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t p) { agg += std::pow(diff, p); } |
| static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t p) { return std::pow(agg, static_cast<scalar_t>(1) / p); } |
| static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; } |
| static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return dist == 0.0 ? 0 : diff * std::pow(std::abs(diff), p - 2) * grad / std::pow(dist, p - 1); } |
| }; |
| |
| // Inf norm |
| struct inf { |
| static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t p) { if (diff > agg) { agg = diff; } } |
| static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; } |
| static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { if (other > update) { update = other; } } |
| static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return grad * sign(diff) * (std::abs(diff) == dist); } |
| }; |
| |
| }; |
| |
| template <typename scalar_t, typename F> |
| __device__ static inline scalar_t reduce_agg(scalar_t agg) { |
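// Block-wide reduction of per-thread partial aggregates: first a shuffle reduction within each
// warp, then warp 0 combines the per-warp results staged in shared memory.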
| for (int offset = warpSize / 2; offset > 0; offset /= 2) { |
| F::agg(agg, WARP_SHFL_DOWN(agg, offset)); |
| } |
| |
| __shared__ scalar_t shared[forward_threads]; |
| int lane = threadIdx.x % warpSize; |
| int warp_id = threadIdx.x / warpSize; |
| if (lane == 0) { |
| shared[warp_id] = agg; |
| } |
| |
| __syncthreads(); |
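// Warp 0 folds the per-warp partials together. Lanes beyond the warp count load the identity 0,
// which is also safe for the inf norm's max since every aggregate here is non-negative. This step
// assumes blockDim.x is a multiple of warpSize, which the launch sites below guarantee by rounding
// block sizes up to WARP_SIZE.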
| agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0; |
| if (warp_id == 0) { |
| for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) { |
| F::agg(agg, WARP_SHFL_DOWN(agg, offset)); |
| } |
| } |
| return agg; |
| } |
| |
| template <typename scalar_t, typename F> |
| __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p, |
| const double n2, const double n2_squared_minus_1) { |
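// One block per entry of the condensed distance vector: block k handles the pair (i, j) derived
// below, and its threads stride across the m coordinates before reduce_agg folds the partials.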
| const int k = blockIdx.x; |
| const int stride = blockDim.x; |
| |
// Recover the strict-upper-triangle pair (i, j) from the linear index k; the -1 baked into
// n2_squared_minus_1 accounts for floating point truncation issues.
| int64_t i = static_cast<int64_t>((n2 - device_sqrt<double>(n2_squared_minus_1 - 2 * k))); |
| int64_t j = k - n * i + i * (i + 1) / 2 + i + 1; |
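// For example, with n = 4 the linear indices k = 0..5 map to
// (i, j) = (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).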
| |
| const scalar_t * const start = self + i * m; |
| const scalar_t * const end = start + m; |
| const scalar_t * a = start + threadIdx.x; |
| const scalar_t * b = self + j * m + threadIdx.x; |
| scalar_t agg = 0.0; |
| for (; a < end; a += stride, b += stride) { |
| F::inc(agg, std::abs(*a - *b), p); |
| } |
| |
| agg = reduce_agg<scalar_t, F>(agg); |
| if (threadIdx.x == 0) { |
| result[k] = F::finish(agg, p); |
| } |
| } |
| |
| template <typename scalar_t, typename F> |
| __global__ static void cdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * x1, const scalar_t * x2, const scalar_t * dist, int64_t gs, |
| const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m, const int64_t count) { |
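// Grid layout: the y dimension enumerates the r1 * r2 output distances; the x dimension, together
// with the grid-stride loop below, walks the m coordinates of each row pair.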
| const int k = blockIdx.y * blockDim.y + threadIdx.y; |
| const int init = blockIdx.x * blockDim.x + threadIdx.x; |
| const int stride = blockDim.x * gridDim.x; |
| |
| if (k >= count) { |
| return; |
| } |
| |
| int64_t i = k / r2; |
| int64_t j = k % r2; |
| |
| const scalar_t grad_k = grad[k * gs]; |
| const scalar_t dist_k = dist[k]; |
| |
| const scalar_t * const start = x1 + i * m; |
| const scalar_t * const end = start + m; |
| const scalar_t * self_i = start + init; |
| const scalar_t * self_j = x2 + j * m + init; |
| |
| scalar_t * buff_i = buffer + (r1 * j + i) * m + init; |
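// Per-coordinate gradients for the pair (i, j) go into slice (j, i) of the (r2, r1, m) buffer;
// the host wrapper reduces over the leading dimension afterwards.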
| |
| for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride) { |
| const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p); |
| *buff_i = res; |
| } |
| } |
| |
| template <typename scalar_t, typename F> |
| __global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, int64_t gs, const int64_t n, const int64_t m, const int64_t combs, const scalar_t p, |
| const double n2, const double n2_squared_minus_1) { |
| const int k = blockIdx.y * blockDim.y + threadIdx.y; |
| const int init = blockIdx.x * blockDim.x + threadIdx.x; |
| const int stride = blockDim.x * gridDim.x; |
| |
| if (k >= combs) { |
| return; |
| } |
| |
// Recover the strict-upper-triangle pair (i, j) from the linear index k, exactly as in the forward
// kernel; the -1 baked into n2_squared_minus_1 accounts for floating point truncation issues.
| int64_t i = static_cast<int64_t>((n2 - device_sqrt<double>(n2_squared_minus_1 - 2 * k))); |
| int64_t j = k - n * i + i * (i + 1) / 2 + i + 1; |
| int64_t ib = j - i - 1; |
| int64_t jb = n - 2 - i; |
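// Each pair contributes +res to row i and -res to row j of the gradient. ib and jb pick slices of
// the (n - 1, n, m) buffer so that no two pairs write the same location; the host wrapper sums
// over the leading dimension afterwards.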
| |
| const scalar_t grad_k = grad[k * gs]; |
| const scalar_t dist_k = dist[k]; |
| |
| const scalar_t * const start = self + i * m; |
| const scalar_t * const end = start + m; |
| const scalar_t * self_i = start + init; |
| const scalar_t * self_j = self + j * m + init; |
| scalar_t * buff_i = buffer + (ib * n + i) * m + init; |
| scalar_t * buff_j = buffer + (jb * n + j) * m + init; |
| for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride, buff_j += stride) { |
| const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p); |
| *buff_i = res; |
| *buff_j = -res; |
| } |
| } |
| |
| template <typename scalar_t, typename F> |
| __global__ static void cdist_kernel_cuda_impl(scalar_t * result, const scalar_t * x1, const scalar_t * x2, const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m) { |
| const int k = blockIdx.x; |
| const int64_t i = k / r2; |
| const int64_t j = k % r2; |
| const int stride = blockDim.x; |
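// Same reduction pattern as pdist_kernel_cuda_impl: block k owns output entry (i, j), its threads
// stride across the m coordinates, and reduce_agg folds the per-thread partials.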
| |
| const scalar_t * const start = x1 + i * m; |
| const scalar_t * const end = start + m; |
| const scalar_t * a = start + threadIdx.x; |
| const scalar_t * b = x2 + j * m + threadIdx.x; |
| |
| scalar_t agg = 0.0; |
| for (; a < end; a += stride, b += stride) { |
| F::inc(agg, std::abs(*a - *b), p); |
| } |
| agg = reduce_agg<scalar_t, F>(agg); |
| if (threadIdx.x == 0) { |
| result[k] = F::finish(agg, p); |
| } |
| } |
| |
// Forward cdist: result[i][j] = ||x1[i] - x2[j]||_p, one CUDA block per output entry.
void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, double p) {
| int64_t r1 = x1.size(-2); |
| int64_t r2 = x2.size(-2); |
| int64_t m = x1.size(-1); |
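// One block per output distance (r1 * r2 of them); the block size is m rounded up to a multiple of
// WARP_SIZE and capped at forward_threads, keeping reduce_agg's warp arithmetic exact.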
| const dim3 grid(r1*r2); |
| const dim3 block(std::min((int64_t)forward_threads, ((m - 1) / WARP_SIZE + 1) * WARP_SIZE)); |
| |
| AT_DISPATCH_FLOATING_TYPES(x1.scalar_type(), "cdist_cuda", [&] { |
| if (p == 0.0) { |
| cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m); |
| } else if (p == 1.0) { |
| cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m); |
| } else if (p == 2.0) { |
| cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m); |
| } else if (std::isinf(p)) { |
| cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m); |
| } else { |
| cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m); |
| } |
| }); |
| AT_CUDA_CHECK(cudaGetLastError()); |
| } |
| |
// Forward pdist: result[k] = ||self[i] - self[j]||_p over all pairs i < j, one CUDA block per pair.
void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) {
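// One block per condensed distance; result.numel() == n * (n - 1) / 2.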
| const dim3 grid(result.numel()); |
| const dim3 block(forward_threads); |
| int64_t n = self.size(0); |
| int64_t m = self.size(1); |
| // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do |
| // some math in fp64 -- this is just minimizing the amount of fp64 math we do on the device. |
| const double n2 = n - .5; |
| const double n2_squared_minus_1 = n2 * n2 - 1; |
| |
| AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda", [&] { |
| if (p == 0.0) { |
| pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1); |
| } else if (p == 1.0) { |
| pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1); |
| } else if (p == 2.0) { |
| pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1); |
| } else if (std::isinf(p)) { |
| pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1); |
| } else { |
| pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1); |
| } |
| }); |
| AT_CUDA_CHECK(cudaGetLastError()); |
| } |
| |
// Backward pdist: accumulates the per-coordinate gradient contributions from every pair into result (shape n x m).
void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
| if (p == 0.0 || grad.numel() == 0 || self.numel() == 0) { |
| result.fill_(0); |
| return; |
| } |
| |
| const int64_t n = result.size(0); |
| int64_t m = self.size(1); |
| const int block_x = 64; |
| // NB: be careful with changing block_y; as it's currently written, grid_y is limited to be 2^16. |
| // From binary search, block_y of 16 gives us max pdist dim0 of 1449, |
| // block_y of 4 gives us max pdist dim0 of 725. |
| const int block_y = 16; |
| const int grid_x = (m + block_x * 8 - 1) / (block_x * 8); |
| const int grid_y = (dist.numel() + block_y - 1) / block_y; |
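// x covers the feature dimension, sized so each thread handles roughly eight coordinates through
// the grid-stride loop; y covers the n * (n - 1) / 2 condensed distances.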
| const dim3 grid(grid_x, grid_y); |
| const dim3 block(block_x, block_y); |
| // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do |
| // some math in fp64 -- this is just minimizing the amount of fp64 math we do on the device. |
| const double n2 = n - .5; |
| const double n2_squared_minus_1 = n2 * n2 - 1; |
| |
| Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options()); |
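// Per-pair contributions land in separate slices of this (n - 1, n, m) buffer; the sum over the
// leading dimension below folds them into result.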
| AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda_backward", [&] { |
| if (p == 1.0) { |
| pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); |
| } else if (p < 2.0) { |
| pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); |
| } else if (p == 2.0) { |
| pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); |
| } else if (std::isinf(p)) { |
| pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); |
| } else { |
| pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); |
| } |
| }); |
| AT_CUDA_CHECK(cudaGetLastError()); |
| |
| at::sum_out(result, buffer, 0); |
| } |
| |
// Backward cdist: gradient with respect to x1 (shape r1 x m), accumulated over all rows of x2.
void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
| if (p == 0.0 || grad.numel() == 0 || x1.numel() == 0 || x2.numel() == 0) { |
| result.fill_(0); |
| return; |
| } |
| |
| const int64_t r1 = x1.size(-2); |
| const int64_t r2 = x2.size(-2); |
| const int64_t m = x1.size(-1); |
| const int block_x = 64; |
| const int block_y = 16; |
| const int grid_x = (m + block_x * 8 - 1) / (block_x * 8); |
| const int grid_y = (dist.numel() + block_y - 1) / block_y; |
| |
| const dim3 grid(grid_x, grid_y); |
| const dim3 block(block_x, block_y); |
| |
| const int64_t count = dist.numel(); |
| |
| Tensor buffer = at::empty({r2, r1, m}, result.options()); |
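// The kernel stores the per-coordinate gradient for pair (i, j) at slice (j, i); summing over the
// leading dimension below yields the (r1, m) gradient with respect to x1.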
| AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_cuda_backward", [&] { |
| if (p == 1.0) { |
| cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count); |
| } else if (p < 2.0) { |
| cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count); |
| } else if (p == 2.0) { |
| cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count); |
| } else if (std::isinf(p)) { |
| cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count); |
| } else { |
| cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count); |
| } |
| }); |
| AT_CUDA_CHECK(cudaGetLastError()); |
| |
| at::sum_out(result, buffer, 0); |
| } |
| |
| |
| } // anonymous namespace |
| |
| REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl); |
| REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl); |
| REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl); |
| REGISTER_DISPATCH(cdist_backward_stub, &cdist_backward_kernel_impl); |
| |
| }} // at::native |