#pragma once
#include <ATen/AccumulateType.h>
#include <cmath>
| |
| namespace at { |
| namespace native { |
| |
| /* |
| * The following function was converted to CUDA form from code that comes |
| * with the following copyright notice. It has been released under the BSD license. |
| * |
| * Cephes Math Library Release 2.8: June, 2000 |
| * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier |
| */ |
/*
 * Digamma function psi(x) = d/dx ln Gamma(x), evaluated in accscalar_t
 * (the accumulate type of scalar_t) and rounded back to scalar_t.
 *
 * Algorithm (Cephes `psi`):
 *   1. x < 0: reflection psi(x) = psi(1 - x) - pi / tan(pi * x).
 *   2. Recurrence psi(x) = psi(x + 1) - 1/x until x >= 10.
 *   3. x == 10: tabulated value psi(10).
 *   4. Otherwise the asymptotic expansion
 *        psi(x) ~ ln x - 1/(2x) - sum_k B_{2k} / (2k * x^{2k}),
 *      with the sum evaluated as a polynomial in z = 1/x^2. The sum is
 *      skipped for x >= 1e17, where it is negligible next to ln x.
 *
 * Special values:
 *   +0 -> -inf, -0 -> +inf   (psi(x) ~ -1/x near the pole at 0)
 *   negative integers, -inf  -> NaN (one-sided limits have opposite signs)
 *   NaN -> NaN, +inf -> +inf
 */
template <typename scalar_t>
static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
  constexpr double PI_f64 = 3.14159265358979323846;
  // psi(10), returned when the recurrence lands exactly on x == 10.
  const accscalar_t PSI_10 = 2.25175258906672110764;
  // Asymptotic-series coefficients B_{2k} / (2k), highest power of z first
  // (B_14/14, B_12/12, ..., B_2/2), as in Cephes `psi`.
  const accscalar_t A[] = {
      8.33333333333333333333E-2,
      -2.10927960927960927961E-2,
      7.57575757575757575758E-3,
      -4.16666666666666666667E-3,
      3.96825396825396825397E-3,
      -8.33333333333333333333E-3,
      8.33333333333333333333E-2,
  };

  accscalar_t x = static_cast<accscalar_t>(in);
  if (x == 0) {
    // Pole at 0 with psi(x) ~ -1/x: +0 -> -inf, -0 -> +inf.
    return static_cast<scalar_t>(
        ::copysign(static_cast<accscalar_t>(INFINITY), -x));
  }

  const bool x_is_integer = x == ::floor(x);
  accscalar_t result = 0;
  if (x < 0) {
    if (x_is_integer) {
      // Poles at the negative integers (floor(-inf) == -inf lands here too):
      // the limits from the left and right have opposite signs, so there is
      // no meaningful value.
      return static_cast<scalar_t>(NAN);
    }
    // Reflection: psi(x) = psi(1 - x) - pi / tan(pi * x).
    // Rounding errors in tan's input can really affect the output for
    // extreme values, so this is always computed in double. tan has period
    // pi, so tan(pi * x) == tan(pi * frac(x)). Using the fractional part
    // avoids the rounding error of forming pi * x when |x| is large;
    // x - trunc(x) is exact in double.
    const double xd = static_cast<double>(x);
    const double frac = xd - ::trunc(xd);
    result = static_cast<accscalar_t>(-PI_f64 / ::tan(PI_f64 * frac));
    x = 1 - x;
  }

  // Recurrence psi(x) = psi(x + 1) - 1/x moves x into the range where the
  // asymptotic series is accurate.
  while (x < 10) {
    result -= 1 / x;
    x += 1;
  }
  if (x == 10) {
    return static_cast<scalar_t>(result + PSI_10);
  }

  // Literals are kept in accscalar_t so float/half inputs are not silently
  // promoted to double arithmetic on the device.
  accscalar_t y = 0;
  if (x < static_cast<accscalar_t>(1.0e17)) {
    const accscalar_t z = 1 / (x * x);

    // Horner evaluation of the coefficient polynomial in z.
    accscalar_t polevl_result = 0;
    for (int i = 0; i <= 6; i++) {
      polevl_result = polevl_result * z + A[i];
    }
    y = z * polevl_result;
  }

  return static_cast<scalar_t>(
      ::log(x) - (static_cast<accscalar_t>(0.5) / x) - y + result);
}
| |
/*
 * Trigamma function psi'(x) = d^2/dx^2 ln Gamma(x), evaluated in accscalar_t
 * (the accumulate type of scalar_t) and rounded back to scalar_t.
 *
 * Algorithm:
 *   1. x < 0.5: reflection psi'(x) = pi^2 / sin^2(pi x) - psi'(1 - x).
 *   2. Six steps of the recurrence psi'(x) = psi'(x + 1) + 1/x^2.
 *   3. Asymptotic series
 *        psi'(x) ~ 1/x + 1/(2x^2) + 1/(6x^3) - 1/(30x^5) + 1/(42x^7).
 *
 * Special values: non-positive integers (poles) -> +inf; NaN, -inf -> NaN.
 */
template <typename scalar_t>
static inline __host__ __device__ scalar_t calc_trigamma(scalar_t in) {
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
  constexpr double PI_f64 = 3.14159265358979323846;
  accscalar_t x = static_cast<accscalar_t>(in);
  accscalar_t sign = +1;
  accscalar_t result = 0;
  if (x < 0.5f) {
    // Reflection. The final sum is negated via `sign`, so accumulate
    // -pi^2 / sin^2(pi x) here and +psi'(1 - x) below.
    sign = -1;
    // Like the reflection in calc_digamma, this is computed in double.
    // sin^2(pi * x) == sin^2(pi * frac(x)), and x - trunc(x) is exact. This
    // avoids the rounding error of forming pi * x for large |x|. It also makes
    // sin exactly 0 at x = 0, -1, -2, ..., which yields +inf there. For
    // x == -inf, frac is NaN, so NaN propagates to the result.
    const double xd = static_cast<double>(x);
    const double frac = xd - ::trunc(xd);
    const double sin_pi_x = ::sin(PI_f64 * frac);
    result -= static_cast<accscalar_t>(
        (PI_f64 * PI_f64) / (sin_pi_x * sin_pi_x));
    x = 1 - x;
  }
  // Recurrence: shift x up by 6 so the asymptotic series below is accurate.
  for (int i = 0; i < 6; ++i) {
    result += 1 / (x * x);
    x += 1;
  }
  const accscalar_t one = static_cast<accscalar_t>(1);
  const accscalar_t ixx = 1 / (x * x);
  result += (1 + 1 / (2 * x) +
             ixx * (one / 6 - ixx * (one / 30 - ixx * (one / 42)))) / x;
  return static_cast<scalar_t>(sign * result);
}
| |
| } |
| } |