[caffe2] use JIT'ed fp16 SLS (#32432)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32432

Use the JIT'ed fp16 SLS (SparseLengthsSum) kernels added in D19477209 from the Caffe2 operators. The operator now caches one generated fbgemm kernel per (input type, index type) combination (kernel_fp32_i32_, kernel_fp32_i64_, kernel_fp16_i32_, kernel_fp16_i64_), regenerating a kernel only when the block size changes.
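
For context, a minimal sketch of the generate-and-cache pattern this diff
applies to the fp16 path. The fbgemm calls mirror the diff below; the header
name and the parameter-name comments are assumptions, and only the
fp16/int64 combination is shown:

  #include <cstdint>
  #include "fbgemm/FbgemmEmbedding.h"

  // One cached kernel per (input type, index type) combination; shown here
  // for fp16 input with 64-bit indices.
  fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int64_t>::Type
      kernel_fp16_i64;
  std::int64_t last_block_size = -1;

  void regenerate_if_needed(std::int64_t D) {
    // Regenerate only when the embedding block size D changes; in practice
    // this happens once, on the first call.
    if (D != last_block_size) {
      last_block_size = D;
      kernel_fp16_i64 =
          fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int64_t>(
              D,
              /*has_weight=*/false,
              /*normalize_by_lengths=*/false,
              /*prefetch distance*/ 16,
              /*is_weight_positional=*/false);
    }
  }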

Test Plan: CI

Reviewed By: jianyuh

Differential Revision: D19477208

fbshipit-source-id: ef2ccba10f5f4c475166141bf09c266dedb92d38
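
For reference, the call-site shape for the fp16 kernels, taken from the call
sites in the diff (variable names follow the operator code; at::Half and
fbgemm::float16 share the same 16-bit layout, so the input pointer is
reinterpreted rather than converted, and the output is accumulated in fp32):

  bool success = kernel_fp16_i64(
      M,            // number of output segments (rows of out_data)
      indices_size, // total number of indices across all segments
      N,            // number of rows in the embedding table
      reinterpret_cast<const fbgemm::float16*>(in_data), // at::Half* input
      indices,      // std::int64_t* indices
      lengths,      // int* per-segment lengths
      in_weight,    // optional float* per-index weights (may be nullptr)
      out_data);    // float* output, M rows of D floats
  if (!success) {
    // Fall back to validating lengths/indices and raising a descriptive
    // error, as in the diff's fallback path.
  }
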
diff --git a/caffe2/operators/lengths_reducer_ops.h b/caffe2/operators/lengths_reducer_ops.h
index 3dbbd99..518ced5 100644
--- a/caffe2/operators/lengths_reducer_ops.h
+++ b/caffe2/operators/lengths_reducer_ops.h
@@ -80,32 +80,56 @@
     }
 
 #ifdef USE_FBGEMM
-    if (std::is_same<InputType, float>::value) {
-      // If this is the first call or block size has changed (should never
-      // happen actually), generate a kernel.
-      if (D != last_block_size) {
-        last_block_size = D;
+    // If this is the first call or the block size has changed (the latter
+    // should never actually happen), generate a kernel.
+    if (D != last_block_size) {
+      last_block_size = D;
+      if (std::is_same<InputType, float>::value) {
         if (std::is_same<IndexType, std::int32_t>::value) {
-          kernel32_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
-              D,
-              USE_WEIGHT,
-              USE_MEAN,
-              /*prefetch distance*/ 16,
-              USE_POSITIONAL_WEIGHT);
+          kernel_fp32_i32_ =
+              fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
+                  D,
+                  USE_WEIGHT,
+                  USE_MEAN,
+                  /*prefetch distance*/ 16,
+                  USE_POSITIONAL_WEIGHT);
         } else {
           CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
-          kernel64_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
-              D,
-              USE_WEIGHT,
-              USE_MEAN,
-              /*prefetch distance*/ 16,
-              USE_POSITIONAL_WEIGHT);
+          kernel_fp32_i64_ =
+              fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
+                  D,
+                  USE_WEIGHT,
+                  USE_MEAN,
+                  /*prefetch distance*/ 16,
+                  USE_POSITIONAL_WEIGHT);
+        }
+      } else {
+        CAFFE_ENFORCE((std::is_same<InputType, at::Half>::value));
+        if (std::is_same<IndexType, std::int32_t>::value) {
+          kernel_fp16_i32_ =
+              fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int32_t>(
+                  D,
+                  USE_WEIGHT,
+                  USE_MEAN,
+                  /*prefetch distance*/ 16,
+                  USE_POSITIONAL_WEIGHT);
+        } else {
+          CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
+          kernel_fp16_i64_ =
+              fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int64_t>(
+                  D,
+                  USE_WEIGHT,
+                  USE_MEAN,
+                  /*prefetch distance*/ 16,
+                  USE_POSITIONAL_WEIGHT);
         }
       }
+    }
 
-      bool success;
+    bool success;
+    if (std::is_same<InputType, float>::value) {
       if (std::is_same<IndexType, std::int32_t>::value) {
-        success = kernel32_(
+        success = kernel_fp32_i32_(
             M,
             indices_size,
             N,
@@ -115,7 +139,7 @@
             in_weight,
             out_data);
       } else {
-        success = kernel64_(
+        success = kernel_fp32_i64_(
             M,
             indices_size,
             N,
@@ -125,39 +149,61 @@
             in_weight,
             out_data);
       }
-
-      if (success) {
-        return true;
+    } else {
+      if (std::is_same<IndexType, std::int32_t>::value) {
+        success = kernel_fp16_i32_(
+            M,
+            indices_size,
+            N,
+            reinterpret_cast<const fbgemm::float16*>(in_data),
+            indicesInput.template data<std::int32_t>(),
+            lengths,
+            in_weight,
+            out_data);
+      } else {
+        success = kernel_fp16_i64_(
+            M,
+            indices_size,
+            N,
+            reinterpret_cast<const fbgemm::float16*>(in_data),
+            indicesInput.template data<std::int64_t>(),
+            lengths,
+            in_weight,
+            out_data);
       }
-
-      int64_t current = 0;
-      for (int m = 0; m < M; ++m) {
-        for (int i = 0; i < lengths[m]; ++i) {
-          CAFFE_ENFORCE_LT(
-              current,
-              indices_size,
-              "Your input seems to be incorrect: the sum of lengths values "
-              "should be the size of the indices tensor, but it appears not.");
-          IndexType idx = indices[current];
-          CAFFE_ENFORCE(
-              0 <= idx && idx < N,
-              "Index ",
-              current,
-              " is out of bounds: ",
-              idx,
-              ", range 0 to ",
-              N);
-          ++current;
-        }
-      }
-      CAFFE_ENFORCE_EQ(
-          current,
-          indices_size,
-          "Your input seems to be incorrect: the sum of lengths values should be "
-          "the size of the indices tensor, but it appears not.");
-
-      return false;
     }
+
+    if (success) {
+      return true;
+    }
+
+    int64_t current = 0;
+    for (int m = 0; m < M; ++m) {
+      for (int i = 0; i < lengths[m]; ++i) {
+        CAFFE_ENFORCE_LT(
+            current,
+            indices_size,
+            "Your input seems to be incorrect: the sum of lengths values "
+            "should be the size of the indices tensor, but it appears not.");
+        IndexType idx = indices[current];
+        CAFFE_ENFORCE(
+            0 <= idx && idx < N,
+            "Index ",
+            current,
+            " is out of bounds: ",
+            idx,
+            ", range 0 to ",
+            N);
+        ++current;
+      }
+    }
+    CAFFE_ENFORCE_EQ(
+        current,
+        indices_size,
+        "Your input seems to be incorrect: the sum of lengths values should be "
+        "the size of the indices tensor, but it appears not.");
+
+    return false;
 #endif
 
     // delegate work to perfkernel that branches based on architecture
@@ -188,8 +234,14 @@
 #ifdef USE_FBGEMM
  private:
   std::int64_t last_block_size{-1};
-  fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type kernel32_;
-  fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type kernel64_;
+  fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type
+      kernel_fp32_i32_;
+  fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type
+      kernel_fp32_i64_;
+  fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int32_t>::Type
+      kernel_fp16_i32_;
+  fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int64_t>::Type
+      kernel_fp16_i64_;
 #endif
 };