[caffe2] use JIT'ed fp16 SLS (#32432)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32432
Use the JIT'ed fp16 SLS kernels added in D19477209 from the Caffe2 operators.
Test Plan: CI
Reviewed By: jianyuh
Differential Revision: D19477208
fbshipit-source-id: ef2ccba10f5f4c475166141bf09c266dedb92d38
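
For context, a rough standalone sketch of the pattern the operator now follows for fp16 inputs (not part of this diff; the include path and the exact GenerateEmbeddingSpMDM argument list are assumptions and may differ across FBGEMM versions): generate a JIT'ed EmbeddingSpMDM kernel once per block size, cache it, and invoke it with the same argument order used in the hunks below.

#include <cstdint>
#include "fbgemm/FbgemmEmbedding.h" // assumed header; name/location may vary by FBGEMM version

// Sketch: generate (or reuse) a JIT'ed fp16 SparseLengthsSum kernel and run it.
// Returns false when the kernel rejects the input (e.g. an out-of-range index),
// mirroring the operator's fallback to the error-reporting path.
bool RunFp16SLS(
    std::int64_t block_size,      // D: embedding dimension
    std::int64_t output_size,     // M: number of output rows / lengths entries
    std::int64_t index_size,      // total number of indices
    std::int64_t data_size,       // N: number of rows in the embedding table
    const fbgemm::float16* table, // fp16 table, data_size x block_size
    const std::int64_t* indices,
    const int* lengths,
    float* out) {
  // The operator caches the kernel as a member keyed on last_block_size;
  // statics keep this sketch self-contained.
  static std::int64_t cached_block_size = -1;
  static fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int64_t>::Type
      kernel;
  if (block_size != cached_block_size) {
    cached_block_size = block_size;
    kernel = fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int64_t>(
        block_size,
        /*has_weight=*/false,
        /*normalize_by_lengths=*/false,
        /*prefetch distance*/ 16,
        /*is_weight_positional=*/false);
  }
  return kernel(
      output_size,
      index_size,
      data_size,
      table,
      indices,
      lengths,
      /*weights=*/nullptr,
      out);
}

The operator keeps four such cached kernels (fp32/fp16 inputs x int32/int64 indices) as members and only regenerates the active one when the block size D changes.
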
diff --git a/caffe2/operators/lengths_reducer_ops.h b/caffe2/operators/lengths_reducer_ops.h
index 3dbbd99..518ced5 100644
--- a/caffe2/operators/lengths_reducer_ops.h
+++ b/caffe2/operators/lengths_reducer_ops.h
@@ -80,32 +80,56 @@
}
#ifdef USE_FBGEMM
- if (std::is_same<InputType, float>::value) {
- // If this is the first call or block size has changed (should never
- // happen actually), generate a kernel.
- if (D != last_block_size) {
- last_block_size = D;
+ // If this is the first call or block size has changed (should never
+ // happen actually), generate a kernel.
+ if (D != last_block_size) {
+ last_block_size = D;
+ if (std::is_same<InputType, float>::value) {
if (std::is_same<IndexType, std::int32_t>::value) {
- kernel32_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
- D,
- USE_WEIGHT,
- USE_MEAN,
- /*prefetch distance*/ 16,
- USE_POSITIONAL_WEIGHT);
+ kernel_fp32_i32_ =
+ fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
+ D,
+ USE_WEIGHT,
+ USE_MEAN,
+ /*prefetch distance*/ 16,
+ USE_POSITIONAL_WEIGHT);
} else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
- kernel64_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
- D,
- USE_WEIGHT,
- USE_MEAN,
- /*prefetch distance*/ 16,
- USE_POSITIONAL_WEIGHT);
+ kernel_fp32_i64_ =
+ fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
+ D,
+ USE_WEIGHT,
+ USE_MEAN,
+ /*prefetch distance*/ 16,
+ USE_POSITIONAL_WEIGHT);
+ }
+ } else {
+ CAFFE_ENFORCE((std::is_same<InputType, at::Half>::value));
+ if (std::is_same<IndexType, std::int32_t>::value) {
+ kernel_fp16_i32_ =
+ fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int32_t>(
+ D,
+ USE_WEIGHT,
+ USE_MEAN,
+ /*prefetch distance*/ 16,
+ USE_POSITIONAL_WEIGHT);
+ } else {
+ CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
+ kernel_fp16_i64_ =
+ fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int64_t>(
+ D,
+ USE_WEIGHT,
+ USE_MEAN,
+ /*prefetch distance*/ 16,
+ USE_POSITIONAL_WEIGHT);
}
}
+ }
- bool success;
+ bool success;
+ if (std::is_same<InputType, float>::value) {
if (std::is_same<IndexType, std::int32_t>::value) {
- success = kernel32_(
+ success = kernel_fp32_i32_(
M,
indices_size,
N,
@@ -115,7 +139,7 @@
in_weight,
out_data);
} else {
- success = kernel64_(
+ success = kernel_fp32_i64_(
M,
indices_size,
N,
@@ -125,39 +149,61 @@
in_weight,
out_data);
}
-
- if (success) {
- return true;
+ } else {
+ if (std::is_same<IndexType, std::int32_t>::value) {
+ success = kernel_fp16_i32_(
+ M,
+ indices_size,
+ N,
+ reinterpret_cast<const fbgemm::float16*>(in_data),
+ indicesInput.template data<std::int32_t>(),
+ lengths,
+ in_weight,
+ out_data);
+ } else {
+ success = kernel_fp16_i64_(
+ M,
+ indices_size,
+ N,
+ reinterpret_cast<const fbgemm::float16*>(in_data),
+ indicesInput.template data<std::int64_t>(),
+ lengths,
+ in_weight,
+ out_data);
}
-
- int64_t current = 0;
- for (int m = 0; m < M; ++m) {
- for (int i = 0; i < lengths[m]; ++i) {
- CAFFE_ENFORCE_LT(
- current,
- indices_size,
- "Your input seems to be incorrect: the sum of lengths values "
- "should be the size of the indices tensor, but it appears not.");
- IndexType idx = indices[current];
- CAFFE_ENFORCE(
- 0 <= idx && idx < N,
- "Index ",
- current,
- " is out of bounds: ",
- idx,
- ", range 0 to ",
- N);
- ++current;
- }
- }
- CAFFE_ENFORCE_EQ(
- current,
- indices_size,
- "Your input seems to be incorrect: the sum of lengths values should be "
- "the size of the indices tensor, but it appears not.");
-
- return false;
}
+
+ if (success) {
+ return true;
+ }
+
+ int64_t current = 0;
+ for (int m = 0; m < M; ++m) {
+ for (int i = 0; i < lengths[m]; ++i) {
+ CAFFE_ENFORCE_LT(
+ current,
+ indices_size,
+ "Your input seems to be incorrect: the sum of lengths values "
+ "should be the size of the indices tensor, but it appears not.");
+ IndexType idx = indices[current];
+ CAFFE_ENFORCE(
+ 0 <= idx && idx < N,
+ "Index ",
+ current,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ N);
+ ++current;
+ }
+ }
+ CAFFE_ENFORCE_EQ(
+ current,
+ indices_size,
+ "Your input seems to be incorrect: the sum of lengths values should be "
+ "the size of the indices tensor, but it appears not.");
+
+ return false;
#endif
// delegate work to perfkernel that branches based on architecture
@@ -188,8 +234,14 @@
#ifdef USE_FBGEMM
private:
std::int64_t last_block_size{-1};
- fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type kernel32_;
- fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type kernel64_;
+ fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type
+ kernel_fp32_i32_;
+ fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type
+ kernel_fp32_i64_;
+ fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int32_t>::Type
+ kernel_fp16_i32_;
+ fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int64_t>::Type
+ kernel_fp16_i64_;
#endif
};