add a fast path for EmbeddingBag calling FBGEMM (#36679)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36679

Test Plan:
Imported from OSS

Unit tests:
python test/run_test.py -i test_nn -- TestNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_failures_cpu
python test/run_test.py -i test_nn -- TestNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_offsets_cpu
python test/run_test.py -i test_nn -- TestNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_new_offsets_cpu
python test/run_test.py -i test_nn -- TestNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_no_offsets_cpu
python test/test_nn.py TestNN.test_embeddingbag_from_pretrained
python test/test_nn.py TestNN.test_embeddingbag_from_pretrained_options

Finally run: python test/test_nn.py

Reviewed By: supriyar

Differential Revision: D21058034

Pulled By: xing-liu

fbshipit-source-id: 8fef39078132f63c406976d6b76c51f9ce573f90
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index 29cf2da..2ad6d20 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -5,7 +5,11 @@
 
 #include <TH/THBlasUtils.h>
 
+#ifdef USE_FBGEMM
+#include <fbgemm/Fbgemm.h>
+#else
 #include <caffe2/perfkernels/embedding_lookup_idx.h>
+#endif
 
 #include <cstring>
 #include <iostream>
@@ -100,8 +104,30 @@
       offsets_data = offsets_include_last.data();
     }
 
+#ifdef USE_FBGEMM
+    auto kernel_fp32_i64 =
+      fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int64_t>(
+        /* block_size */ddim,
+        /* has_weight */false,
+        /* normalize_by_lengths */false,
+        /* prefetch */16,
+        /* is_weight_positional */false,
+        /* use_offsets */true
+      );
+#endif
     at::parallel_for(
         0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
+#ifdef USE_FBGEMM
+          kernel_fp32_i64(
+            /* output_size */end_idx - start_idx,
+            /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
+            /* data_size */src.size(0),
+            /* input */src_data,
+            /* indices */select_indices_data + offsets_data[start_idx],
+            /* offsets_or_lengths */offsets_data + start_idx,
+            /* weights */nullptr,
+            /* output */output_data + start_idx * ddim);
+#else
           caffe2::EmbeddingLookupIdx(
               /*block_size=*/ddim,
               /*output_size=*/end_idx - start_idx,
@@ -114,6 +140,7 @@
               /*scale_bias=*/nullptr,
               /*normalize_by_lengths=*/false,
               /*out=*/output_data + start_idx * ddim);
+#endif
         });
   } else {
     AT_ASSERT(select_indices.numel() == add_indices.numel());
@@ -204,8 +231,30 @@
       offsets_data = offsets_include_last.data();
     }
 
+#ifdef USE_FBGEMM
+    auto kernel_fp32_i64 =
+      fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int64_t>(
+        /* block_size */ddim,
+        /* has_weight */true,
+        /* normalize_by_lengths */false,
+        /* prefetch */16,
+        /* is_weight_positional */false,
+        /* use_offsets */true
+      );
+#endif
     at::parallel_for(
         0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
+#ifdef USE_FBGEMM
+          kernel_fp32_i64(
+            /* output_size */end_idx - start_idx,
+            /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
+            /* data_size */src.size(0),
+            /* input */src_data,
+            /* indices */select_indices_data + offsets_data[start_idx],
+            /* offsets_or_lengths */offsets_data + start_idx,
+            /* weights */scale_data + offsets_data[start_idx],
+            /* output */output_data + start_idx * ddim);
+#else
           caffe2::EmbeddingLookupIdx(
               /*block_size=*/ddim,
               /*output_size=*/end_idx - start_idx,
@@ -218,6 +267,7 @@
               /*scale_bias=*/nullptr,
               /*normalize_by_lengths=*/false,
               /*out=*/output_data + start_idx * ddim);
+#endif
         });
   } else {
     AT_ASSERT(select_indices.numel() == add_indices.numel());