[caffe2] use JIT'ed fp32 SLS (#32413)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32413

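When built with USE_FBGEMM, use FBGEMM's JIT-generated fp32 SparseLengthsSum (SLS) kernels for float inputs in the Caffe2 CPU sparse-lengths reduction operators.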

Test Plan: CI

Reviewed By: jianyuh

Differential Revision: D19460555

fbshipit-source-id: 4f29d34523efb6ea1e4c324cc8c93c96990c6aad
diff --git a/caffe2/operators/lengths_reducer_ops.h b/caffe2/operators/lengths_reducer_ops.h
index c5f3428..ecd2dd5 100644
--- a/caffe2/operators/lengths_reducer_ops.h
+++ b/caffe2/operators/lengths_reducer_ops.h
@@ -2,6 +2,9 @@
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
 #include "caffe2/perfkernels/embedding_lookup.h"
+#ifdef USE_FBGEMM
+#include "fbgemm/Fbgemm.h"
+#endif
 
 namespace caffe2 {
 
@@ -9,10 +12,10 @@
 template <
     typename T, // output type
     class InputTypes, // supported input types, such as TensorTypes<float>
-    bool USE_WEIGHT = 0, // Whether it is SparseLengthsWeightedSum
-    bool USE_MEAN = 0, // Whether this is SparseLengthsMean
-    bool USE_POSITIONAL_WEIGHT = 0
-    // USE_WEIGHT = 1 and USE_POSITIONAL_WEIGHT = 1
+    bool USE_WEIGHT = false, // Whether it is SparseLengthsWeightedSum
+    bool USE_MEAN = false, // Whether this is SparseLengthsMean
+    bool USE_POSITIONAL_WEIGHT = false
+    // USE_WEIGHT = true and USE_POSITIONAL_WEIGHT = true
     // -> SparseLengthsPositionalWeightedSum
     >
 class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
@@ -76,6 +79,83 @@
       in_weight = weightInput.template data<T>();
     }
 
+#ifdef USE_FBGEMM
+    if (std::is_same<InputType, float>::value) {
+      // Generate a kernel on the first call, or again if the block size has
+      // changed since the last call (which should never actually happen).
+      if (D != last_block_size) {
+        last_block_size = D;
+        if (std::is_same<IndexType, std::int32_t>::value) {
+          kernel32_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
+              D,
+              USE_WEIGHT,
+              USE_MEAN,
+              /*prefetch distance*/ 16,
+              USE_POSITIONAL_WEIGHT);
+        } else {
+          CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
+          kernel64_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
+              D,
+              USE_WEIGHT,
+              USE_MEAN,
+              /*prefetch distance*/ 16,
+              USE_POSITIONAL_WEIGHT);
+        }
+      }
+
+      bool success;
+      if (std::is_same<IndexType, std::int32_t>::value) {
+        success = kernel32_(
+            M,
+            indices_size,
+            N,
+            reinterpret_cast<const float*>(in_data),
+            indicesInput.template data<std::int32_t>(),
+            lengths,
+            in_weight,
+            out_data);
+      } else {
+        success = kernel64_(
+            M,
+            indices_size,
+            N,
+            reinterpret_cast<const float*>(in_data),
+            indicesInput.template data<std::int64_t>(),
+            lengths,
+            in_weight,
+            out_data);
+      }
+
+      if (success) {
+        return true;
+      }
+
+      int64_t current = 0;
+      for (int m = 0; m < M; ++m) {
+        for (int i = 0; i < lengths[m]; ++i) {
+          CAFFE_ENFORCE_LT(current, indices_size);
+          IndexType idx = indices[current];
+          CAFFE_ENFORCE(
+              0 <= idx && idx < N,
+              "Index ",
+              current,
+              " is out of bounds: ",
+              idx,
+              ", range 0 to ",
+              N);
+          ++current;
+        }
+      }
+      CAFFE_ENFORCE_EQ(
+          current,
+          indices_size,
+          "Your input seems to be incorrect: the sum of lengths values should be "
+          "the size of the indices tensor, but it appears not.");
+
+      return false;
+    }
+#endif
+
     // delegate work to perfkernel that branches based on architecture
     EmbeddingLookup<IndexType, InputType, T, USE_POSITIONAL_WEIGHT>(
         D,
@@ -100,6 +180,13 @@
     LENGTHS = 2 + USE_WEIGHT, // 2 in SparseLengths[Sum, Mean],
                               // 3 in SparseLengthsWeightedSum
   };
+
+#ifdef USE_FBGEMM
+ private:
+  std::int64_t last_block_size{-1};
+  fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type kernel32_;
+  fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type kernel64_;
+#endif
 };
 
 } // namespace caffe2