Improve sampling time by replacing C qsort with std::sort in Sampler::sample_topp
Differential Revision: D60742125
Pull Request resolved: https://github.com/pytorch/executorch/pull/4644
diff --git a/extension/llm/sampler/sampler.cpp b/extension/llm/sampler/sampler.cpp
index 047526c..6b0f155 100644
--- a/extension/llm/sampler/sampler.cpp
+++ b/extension/llm/sampler/sampler.cpp
@@ -33,6 +33,7 @@
*/
#include <executorch/extension/llm/sampler/sampler.h>
+#include <algorithm>
namespace torch {
namespace executor {
@@ -67,18 +68,6 @@
}
template <typename T>
-static int32_t compare(const void* a, const void* b) {
- ProbIndex<T>* a_ = (ProbIndex<T>*)a;
- ProbIndex<T>* b_ = (ProbIndex<T>*)b;
- if (a_->prob > b_->prob) {
- return -1;
- } else if (a_->prob < b_->prob) {
- return 1;
- }
- return 0;
-}
-
-template <typename T>
int32_t Sampler::sample_topp(T* probabilities, float coin) {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability topp. This way we never sample tokens that
@@ -100,7 +89,11 @@
n0++;
}
}
- qsort(probindex.get(), n0, sizeof(ProbIndex<T>), compare<T>);
+
+ auto compare = [](const ProbIndex<T>& a, const ProbIndex<T>& b) {
+ return a.prob > b.prob;
+ };
+ std::sort(probindex.get(), probindex.get() + n0, compare);
// truncate the list where cumulative probability exceeds topp
T cumulative_prob = 0;