#pragma once

#include <cstdint>

namespace caffe2 {

// clang-format off
/**
 * Embedding lookup with reduction.
 *
 * `input` of size data_size * block_size
 * `indices` of size index_size
 * `offsets` of size output_size + 1 (offsets[i] and offsets[i+1] delimit the
 *     indices used for output row i)
 * `weights` nullptr or array of size index_size (indexed by position within
 *     the segment when IS_WEIGHT_POSITIONAL is true)
 * `out` of size output_size * block_size
 *
 * Behavior is roughly equivalent to pseudocode:
 *
 * pos = 0
 * for (i = 0..output_size-1)
 *   for (k = 0..block_size-1)
 *     out[i*block_size + k] = 0
 *   start_offset = offsets[i]
 *   end_offset = offsets[i+1]
 *   length = end_offset - start_offset
 *   for (j = start_offset..end_offset-1)
 *     for (k = 0..block_size-1)
 *       out[i*block_size + k] += input[indices[pos]*block_size + k] *
 *           (weights ? weights[IS_WEIGHT_POSITIONAL ? j - start_offset : pos] : 1.0)
 *     pos += 1
 *   if (normalize_by_lengths && length > 0)
 *     for (k = 0..block_size-1)
 *       out[i*block_size + k] /= length
 *
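 * Concrete example (purely illustrative values): with block_size = 2,
 * output_size = 2, index_size = 3, indices = {5, 2, 7}, offsets = {0, 2, 3},
 * weights == nullptr and normalize_by_lengths == false:
 *
 *   out[0..1] = input[5*2 .. 5*2+1] + input[2*2 .. 2*2+1]   (row 0: indices 5, 2)
 *   out[2..3] = input[7*2 .. 7*2+1]                          (row 1: index 7)
 *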
 * Note: unlike the length-based EmbeddingLookup API, this variant takes
 *       "offsets", matching the API of PyTorch's EmbeddingBag.
 */
// clang-format on
template <
    typename IndexType,
    typename InType,
    typename OutType,
    bool IS_WEIGHT_POSITIONAL = false>
void EmbeddingLookupIdx(
    const std::int64_t block_size,
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t data_size,
    const InType* input,
    const IndexType* indices,
    const IndexType* offsets,
    const float* weights, // optional, can be null for non-weighted sum
    const float* scale_bias, // optional scale & bias params for uint8 input
    bool normalize_by_lengths,
    OutType* out);
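
// Illustrative usage sketch (not part of this header's API). It assumes a
// specialization for IndexType = std::int64_t, InType = float and
// OutType = float is provided by the implementation that accompanies this
// declaration; the variable names below are example data only.
//
//   #include <vector>
//
//   const std::int64_t block_size = 4;   // embedding dimension
//   const std::int64_t data_size = 10;   // number of rows in the table
//   std::vector<float> table(data_size * block_size, 0.5f);
//
//   std::vector<std::int64_t> indices = {5, 2, 7};  // index_size = 3
//   std::vector<std::int64_t> offsets = {0, 2, 3};  // output_size = 2, plus end offset
//   std::vector<float> out(2 * block_size);
//
//   caffe2::EmbeddingLookupIdx<std::int64_t, float, float>(
//       block_size,
//       /*output_size=*/2,
//       /*index_size=*/3,
//       data_size,
//       table.data(),
//       indices.data(),
//       offsets.data(),
//       /*weights=*/nullptr,
//       /*scale_bias=*/nullptr,   // only used for uint8 inputs
//       /*normalize_by_lengths=*/false,
//       out.data());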

} // namespace caffe2