blob: b4fc8fe248c1c1518059916177f3461adae43c18 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <algorithm>
#include <vector>
class BasicSampler {
public:
int64_t sample(const std::vector<float>& logits) {
// Find the token with the highest log probability.
int64_t max_index =
std::max_element(logits.begin(), logits.end()) - logits.begin();
return max_index;
}
};