| #pragma once |
| |
| #if defined(__AVX__) && !defined(__NVCC__) && \ |
| (defined(__x86_64__) || defined(_M_X64) || defined(__i386__)) |
| #define CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| #include <immintrin.h> |
| #endif |
#include <cmath>
#include <cstdint>

#include <c10/util/Half.h>
| |
| namespace caffe2 { |
| |
| namespace internal { |
| |
// The following functions inside the internal namespace are inlined because
// they are performance critical.
| |
// Scalar Adagrad update over a block of N parameters:
//   nh[i] = decay * h[i] + g[i] * g[i]
//   nw[i] = w[i] + lr * g[i] / (sqrt(nh[i]) + epsilon)
template <typename T>
static inline void adagrad_update_base_inlined(
| int N, |
| const T* w, |
| const float* g, |
| const T* h, |
| T* nw, |
| T* nh, |
| float decay, |
| float epsilon, |
| float lr) { |
| for (auto i = 0; i < N; ++i) { |
| float gi = g[i]; |
| float hi = decay * h[i] + gi * gi; |
| nh[i] = hi; |
| nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); |
| } |
| } |
| |
// version with prefetching
// TODO(msmelyan)
// The crux of the computation is a / (sqrt(b) + epsilon), where a and b are
// vectors and epsilon is very small (e.g., 10^-5) and does not change. Today
// it is computed with a vector sqrt followed by a vector divide, which is
// slow. We can instead take advantage of the fast VRSQRTPS instruction, which
// computes approximate reciprocals of square roots of a vector and is about
// 6x faster than the vsqrt and vdiv combination. Since epsilon is only added
// to avoid division by zero, we can approximate a / (sqrt(b) + epsilon) by
// a / sqrt(b + sqrt(epsilon)) and then use VRSQRTPS. VRSQRTPS is not very
// accurate: on random numbers between 0.1 and 1, the absolute error was about
// 10^-3 compared to the slower but more accurate vsqrt/vdiv combination.
// Extend Marat's function with more Newton-Raphson (NR) iterations to get
// more accuracy for training. (See the illustrative sketch after
// adagrad_update_prefetch_inlined below.)
// TODO(msmelyan)
// Explore streaming stores, but this needs unique indices (deduplication).
| inline void adagrad_update_prefetch_inlined( |
| int N, |
| const float* w, |
| #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| const float* w_n, // prefetch ptr |
| #else |
| const float* /* unused */, |
| #endif |
| |
| const float* g, |
| |
| const float* h, |
| #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| const float* h_n, // prefetch ptr |
| #else |
| const float* /* unused */, |
| #endif |
| |
| float* nw, |
| #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| float* nw_n, // prefetch ptr |
| #else |
| float* /* unused */, |
| #endif |
| |
| float* nh, |
| #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| float* nh_n, // prefetch ptr |
| #else |
| float* /* unused */, |
| #endif |
| |
| float epsilon, |
| float lr) { |
| auto i = 0; |
| |
| #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| constexpr int kSize = 8; |
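  // One AVX register (8 floats) per iteration. The *_n pointers reference a
  // row that will be processed shortly; prefetching it here hides the memory
  // latency for that later update.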
| for (; i + kSize <= N; i += kSize) { |
| _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0); |
| _mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0); |
| _mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0); |
| _mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0); |
| |
| __m256 gi = _mm256_loadu_ps(g + i); |
| __m256 hi = _mm256_loadu_ps(h + i); |
| __m256 wi = _mm256_loadu_ps(w + i); |
| |
| __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi)); |
| _mm256_storeu_ps(nh + i, nhi); |
| __m256 vtmp = _mm256_div_ps( |
| gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon))); |
| _mm256_storeu_ps( |
| nw + i, _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp))); |
| } |
| #endif |
| |
  // Scalar epilogue: handles the remaining (N % 8) elements, and the whole
  // range when AVX intrinsics are unavailable.
  adagrad_update_base_inlined(
      N - i, w + i, g + i, h + i, nw + i, nh + i, 1.0f, epsilon, lr);
| } |
| |
// Row-wise Adagrad: *h is a single momentum value shared by the whole row. It
// is incremented by the mean of the squared gradients over the row, and w is
// then updated in place with one scalar step applied to all N elements.
inline void rowwise_adagrad_update_inlined(
| int N, |
| float* w, |
| #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| float* w_n, // prefetch ptr |
| #else |
| float* /* unused */, |
| #endif |
| |
| const float* g, |
| |
| float* h, |
| #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| float* h_n, // prefetch ptr |
| #else |
| float* /* unused */, |
| #endif |
| |
| float epsilon, |
| float lr) { |
| auto i = 0; |
| |
| #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| constexpr int kSize = 8; |
| _mm_prefetch(reinterpret_cast<const char*>(h_n), _MM_HINT_T0); |
| __m256 partial_sum = _mm256_setzero_ps(); |
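  // Accumulate the sum of squared gradients over the row, 8 floats at a time;
  // the scalar loop below picks up any remaining elements.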
| for (; i + kSize <= N; i += kSize) { |
| __m256 gi = _mm256_loadu_ps(g + i); |
| partial_sum = _mm256_add_ps(partial_sum, _mm256_mul_ps(gi, gi)); |
| } |
  // Horizontally reduce the 8 partial sums to a single scalar
| __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum); |
| __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2); |
| float final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) + |
| _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1)); |
| #else |
| float final_sum = 0.0f; |
| #endif |
| |
| for (; i < N; ++i) { |
| final_sum += g[i] * g[i]; |
| } |
| final_sum /= N; |
| |
| float hi = *h = *h + final_sum; |
| float float_step = lr / (std::sqrt(hi) + epsilon); |
| |
| i = 0; |
| #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| __m256 step = _mm256_set1_ps(float_step); |
| |
| for (i = 0; i + kSize <= N; i += kSize) { |
| _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0); |
| |
| __m256 gi = _mm256_loadu_ps(g + i); |
| __m256 wi = _mm256_loadu_ps(w + i); |
| |
| _mm256_storeu_ps(w + i, _mm256_add_ps(wi, _mm256_mul_ps(gi, step))); |
| } |
| #endif |
| |
| for (; i < N; ++i) { |
| float gi = g[i]; |
| w[i] = w[i] + gi * float_step; |
| } |
| } |
| |
| } // namespace internal |
| |
// version with prefetching
// See the TODO notes above internal::adagrad_update_prefetch_inlined about
// replacing the vsqrt/vdiv combination with VRSQRTPS and about streaming
// stores.
| void adagrad_update_prefetch( |
| int N, |
| const float* w, |
| const float* w_n, // prefetch ptr |
| |
| const float* g, |
| |
| const float* h, |
| const float* h_n, // prefetch ptr |
| |
| float* nw, |
| float* nw_n, // prefetch ptr |
| |
| float* nh, |
| float* nh_n, // prefetch ptr |
| |
| float epsilon, |
| float lr); |
| |
| // Version with prefetching for embeddings and |
| // momentum using fp16 |
| void adagrad_fp16_update_prefetch( |
| int N, |
| const at::Half* w, |
| const at::Half* w_n, // prefetch ptr |
| const float* g, |
| const at::Half* h, |
| const at::Half* h_n, // prefetch ptr |
| at::Half* nw, |
| at::Half* nw_n, // prefetch ptr |
| at::Half* nh, |
| at::Half* nh_n, // prefetch ptr |
| float epsilon, |
| float lr); |
| |
// Version with prefetching. h holds a single momentum value per row; both w
// and h are updated in place.
void rowwise_adagrad_update(
| int N, |
| float* w, |
| float* w_n, // prefetch ptr |
| |
| const float* g, |
| |
| float* h, |
| float* h_n, // prefetch ptr |
| |
| float epsilon, |
| float lr); |
| |
| // version without prefetching |
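// Example (illustrative; the pointer names and hyperparameter values are
// hypothetical):
//   caffe2::adagrad_update(
//       N, w, g, h, nw, nh, /*epsilon=*/1e-5f, /*decay=*/1.0f, /*lr=*/0.01f);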
| void adagrad_update( |
| int N, |
| const float* w, |
| const float* g, |
| const float* h, |
| float* nw, |
| float* nh, |
| float epsilon, |
| float decay, |
| float lr); |
| |
| /** |
| * @return num_rows if succeeds otherwise return the row idx where we pass |
| * the boundary of param_size |
| */ |
| template <typename SIndex> |
| int sparse_adagrad( |
    int num_rows, // number of rows to read
    int block_size, // number of parameters per row
| std::uint64_t param_size, // total number of parameters |
| const float* w, // input parameters |
| const float* g, // input gradients |
| const float* h, // input momentums |
| const SIndex* indices, // indices of each row |
| float* nw, // output parameters |
| float* nh, // output momentums |
| float epsilon, |
| float lr); |
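
// Example (illustrative; assumes an int32_t instantiation of sparse_adagrad
// is linked in, and that w/g/h/nw/nh are hypothetical buffers sized for a
// 10 x 16 parameter table):
//   std::vector<std::int32_t> indices = {3, 7};
//   int rows_done = caffe2::sparse_adagrad(
//       /*num_rows=*/2, /*block_size=*/16, /*param_size=*/10 * 16,
//       w, g, h, indices.data(), nw, nh, /*epsilon=*/1e-5f, /*lr=*/0.01f);
//   // rows_done == 2 on success; otherwise the index of the offending row.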
| |
// Stamps out an ISA-specific sparse_adagrad kernel named
// sparse_adagrad_<SIndex>__<ISA>. It relies on a matching
// adagrad_update_prefetch__<ISA> function for the block_size > 1 case.
#define SPARSE_ADAGRAD_SPECIALIZATION(SIndex, ISA)                      \
| int sparse_adagrad_##SIndex##__##ISA( \ |
| int num_rows, \ |
| int block_size, \ |
| std::uint64_t param_size, \ |
| const float* w, \ |
| const float* g, \ |
| const float* h, \ |
| const SIndex* indices, \ |
| float* nw, \ |
| float* nh, \ |
| float epsilon, \ |
| float lr) { \ |
| for (int i = 0; i < num_rows; ++i) { \ |
| std::uint64_t idx = indices[i]; \ |
| auto offsetI = i * block_size; \ |
| auto offsetIdx = idx * block_size; \ |
| \ |
| if (block_size + offsetIdx > param_size) { \ |
| return i; \ |
| } \ |
| \ |
| if (block_size == 1) { \ |
| float gi = g[i]; \ |
| float hi = nh[idx] = h[idx] + gi * gi; \ |
| nw[idx] = w[idx] + lr * gi / (std::sqrt(hi) + epsilon); \ |
| } else { \ |
| const int prefdist_T0 = 16; \ |
| int i_pref = (i < num_rows - prefdist_T0) ? i + prefdist_T0 : i; \ |
| std::uint64_t idx_pref = indices[i_pref]; \ |
| \ |
| adagrad_update_prefetch__##ISA( \ |
| block_size, \ |
| w + offsetIdx, \ |
| &w[idx_pref * block_size], \ |
| g + offsetI, \ |
| h + offsetIdx, \ |
| &h[idx_pref * block_size], \ |
| nw + offsetIdx, \ |
| &nw[idx_pref * block_size], \ |
| nh + offsetIdx, \ |
| &nh[idx_pref * block_size], \ |
| epsilon, \ |
| lr); \ |
| } \ |
| } \ |
| return num_rows; \ |
| }; |
| |
| } // namespace caffe2 |
| |
| #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| #undef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC |
| #endif |