add a fast path for EmbeddingBag calling FBGEMM (#36679)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36679
Test Plan:
Imported from OSS
Unit tests:
python test/run_test.py -i test_nn -- TestNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_failures_cpu
python test/run_test.py -i test_nn -- TestNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_offsets_cpu
python test/run_test.py -i test_nn -- TestNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_new_offsets_cpu
python test/run_test.py -i test_nn -- TestNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_no_offsets_cpu
python test/test_nn.py TestNN.test_embeddingbag_from_pretrained
python test/test_nn.py TestNN.test_embeddingbag_from_pretrained_options
Finally run: python test/test_nn.py
Reviewed By: supriyar
Differential Revision: D21058034
Pulled By: xing-liu
fbshipit-source-id: 8fef39078132f63c406976d6b76c51f9ce573f90
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index 29cf2da..2ad6d20 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -5,7 +5,11 @@
#include <TH/THBlasUtils.h>
+#ifdef USE_FBGEMM
+#include <fbgemm/Fbgemm.h>
+#else
#include <caffe2/perfkernels/embedding_lookup_idx.h>
+#endif
#include <cstring>
#include <iostream>
@@ -100,8 +104,30 @@
offsets_data = offsets_include_last.data();
}
+#ifdef USE_FBGEMM
+ auto kernel_fp32_i64 =
+ fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int64_t>(
+ /* block_size */ddim,
+ /* has_weight */false,
+ /* normalize_by_lengths */false,
+ /* prefetch */16,
+ /* is_weight_positional */false,
+ /* use_offsets */true
+ );
+#endif
at::parallel_for(
0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
+#ifdef USE_FBGEMM
+ kernel_fp32_i64(
+ /* output_size */end_idx - start_idx,
+ /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
+ /* data_size */src.size(0),
+ /* input */src_data,
+ /* indices */select_indices_data + offsets_data[start_idx],
+ /* offsets_or_lengths */offsets_data + start_idx,
+ /* weights */nullptr,
+ /* output */output_data + start_idx * ddim);
+#else
caffe2::EmbeddingLookupIdx(
/*block_size=*/ddim,
/*output_size=*/end_idx - start_idx,
@@ -114,6 +140,7 @@
/*scale_bias=*/nullptr,
/*normalize_by_lengths=*/false,
/*out=*/output_data + start_idx * ddim);
+#endif
});
} else {
AT_ASSERT(select_indices.numel() == add_indices.numel());
@@ -204,8 +231,30 @@
offsets_data = offsets_include_last.data();
}
+#ifdef USE_FBGEMM
+ auto kernel_fp32_i64 =
+ fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int64_t>(
+ /* block_size */ddim,
+ /* has_weight */true,
+ /* normalize_by_lengths */false,
+ /* prefetch */16,
+ /* is_weight_positional */false,
+ /* use_offsets */true
+ );
+#endif
at::parallel_for(
0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
+#ifdef USE_FBGEMM
+ kernel_fp32_i64(
+ /* output_size */end_idx - start_idx,
+ /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
+ /* data_size */src.size(0),
+ /* input */src_data,
+ /* indices */select_indices_data + offsets_data[start_idx],
+ /* offsets_or_lengths */offsets_data + start_idx,
+ /* weights */scale_data + offsets_data[start_idx],
+ /* output */output_data + start_idx * ddim);
+#else
caffe2::EmbeddingLookupIdx(
/*block_size=*/ddim,
/*output_size=*/end_idx - start_idx,
@@ -218,6 +267,7 @@
/*scale_bias=*/nullptr,
/*normalize_by_lengths=*/false,
/*out=*/output_data + start_idx * ddim);
+#endif
});
} else {
AT_ASSERT(select_indices.numel() == add_indices.numel());