[caffe2] use JIT'ed fp32 SLS (#32413)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32413
Use JIT'ed fp32 SLS in Caffe2 operators
Test Plan: CI
Reviewed By: jianyuh
Differential Revision: D19460555
fbshipit-source-id: 4f29d34523efb6ea1e4c324cc8c93c96990c6aad
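
For context, fbgemm's GenerateEmbeddingSpMDM<InType, IndexType> JIT-compiles a
SparseLengthsSum-style kernel specialized for a fixed block size (the embedding
dimension) and returns it as a std::function; the kernel reports out-of-bounds
indices by returning false. Below is a minimal, self-contained sketch of that
API as the diff uses it; the sizes and data are illustrative, not from the PR:

  #include <cstdint>
  #include "fbgemm/Fbgemm.h"

  int main() {
    constexpr std::int64_t block_size = 4; // embedding dimension (D below)
    constexpr std::int64_t data_size = 3;  // embedding table rows (N below)
    float table[data_size * block_size] = {}; // fill with real embeddings
    std::int64_t indices[] = {0, 2, 1, 1};    // 4 lookups in total
    int lengths[] = {2, 2};                   // 2 output rows, 2 lookups each
    float out[2 * block_size];

    // JIT a kernel specialized for this block size: plain sum (no per-sample
    // weights), no mean normalization, prefetch distance 16.
    auto kernel = fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
        block_size,
        /*has_weight=*/false,
        /*normalize_by_lengths=*/false,
        /*prefetch=*/16,
        /*is_weight_positional=*/false);

    // Returns false if any index falls outside [0, data_size); the operator
    // below relies on this to trigger its detailed error reporting.
    bool ok = kernel(
        /*output_size=*/2,
        /*index_size=*/4,
        data_size,
        table,
        indices,
        lengths,
        /*weights=*/nullptr,
        out);
    return ok ? 0 : 1;
  }

Because code generation has a one-time cost, the operator caches the generated
kernel in a member and regenerates it only when the block size changes.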
diff --git a/caffe2/operators/lengths_reducer_ops.h b/caffe2/operators/lengths_reducer_ops.h
index c5f3428..ecd2dd5 100644
--- a/caffe2/operators/lengths_reducer_ops.h
+++ b/caffe2/operators/lengths_reducer_ops.h
@@ -2,6 +2,9 @@
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/perfkernels/embedding_lookup.h"
+#ifdef USE_FBGEMM
+#include "fbgemm/Fbgemm.h"
+#endif
namespace caffe2 {
@@ -9,10 +12,10 @@
template <
typename T, // output type
class InputTypes, // supported input types, such as TensorTypes<float>
- bool USE_WEIGHT = 0, // Whether it is SparseLengthsWeightedSum
- bool USE_MEAN = 0, // Whether this is SparseLengthsMean
- bool USE_POSITIONAL_WEIGHT = 0
- // USE_WEIGHT = 1 and USE_POSITIONAL_WEIGHT = 1
+ bool USE_WEIGHT = false, // Whether it is SparseLengthsWeightedSum
+ bool USE_MEAN = false, // Whether this is SparseLengthsMean
+ bool USE_POSITIONAL_WEIGHT = false
+ // USE_WEIGHT = true and USE_POSITIONAL_WEIGHT = true
// -> SparseLengthsPositionalWeightedSum
>
class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
@@ -76,6 +79,87 @@
in_weight = weightInput.template data<T>();
}
+#ifdef USE_FBGEMM
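+    // The JIT'ed fbgemm path currently covers only fp32 embedding tables.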
+ if (std::is_same<InputType, float>::value) {
+    // Generate a kernel on the first call, or if the block size has changed
+    // (which should never actually happen).
+ if (D != last_block_size) {
+ last_block_size = D;
+ if (std::is_same<IndexType, std::int32_t>::value) {
+ kernel32_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
+ D,
+ USE_WEIGHT,
+ USE_MEAN,
+ /*prefetch distance*/ 16,
+ USE_POSITIONAL_WEIGHT);
+ } else {
+ CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
+ kernel64_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
+ D,
+ USE_WEIGHT,
+ USE_MEAN,
+ /*prefetch distance*/ 16,
+ USE_POSITIONAL_WEIGHT);
+ }
+ }
+
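+    // Invoke the JIT'ed kernel that matches the runtime index type.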
+ bool success;
+ if (std::is_same<IndexType, std::int32_t>::value) {
+ success = kernel32_(
+ M,
+ indices_size,
+ N,
+ reinterpret_cast<const float*>(in_data),
+ indicesInput.template data<std::int32_t>(),
+ lengths,
+ in_weight,
+ out_data);
+ } else {
+ success = kernel64_(
+ M,
+ indices_size,
+ N,
+ reinterpret_cast<const float*>(in_data),
+ indicesInput.template data<std::int64_t>(),
+ lengths,
+ in_weight,
+ out_data);
+ }
+
+ if (success) {
+ return true;
+ }
+
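+    // The JIT'ed kernel returns false when it encounters an out-of-bounds
+    // index; re-walk the indices to produce a detailed error message.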
+ int64_t current = 0;
+ for (int m = 0; m < M; ++m) {
+ for (int i = 0; i < lengths[m]; ++i) {
+ CAFFE_ENFORCE_LT(current, indices_size);
+ IndexType idx = indices[current];
+ CAFFE_ENFORCE(
+ 0 <= idx && idx < N,
+ "Index ",
+ current,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ N);
+ ++current;
+ }
+ }
+    CAFFE_ENFORCE_EQ(
+        current,
+        indices_size,
+        "Your input seems to be incorrect: the sum of lengths values should "
+        "equal the size of the indices tensor, but it does not.");
+
+ return false;
+ }
+#endif
+
// delegate work to perfkernel that branches based on architecture
EmbeddingLookup<IndexType, InputType, T, USE_POSITIONAL_WEIGHT>(
D,
@@ -100,6 +184,14 @@
LENGTHS = 2 + USE_WEIGHT, // 2 in SparseLengths[Sum, Mean],
// 3 in SparseLengthsWeightedSum
};
+
+#ifdef USE_FBGEMM
+ private:
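+  // Cached JIT'ed kernels, regenerated only when the block size changes.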
+ std::int64_t last_block_size{-1};
+ fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type kernel32_;
+ fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type kernel64_;
+#endif
};
} // namespace caffe2