fix runner max seq len (#2688)
Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/2688
When the max seq len arg is not passed, the runner uses the max seq len from the model.
This means the number of tokens generated should equal the KV cache size.
However, the generate loop tries to generate one extra token because pos, a 0-based
index, is used as the number of tokens.
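
A minimal standalone sketch of the off-by-one (not the actual runner code; `kv_cache_size` is a hypothetical stand-in for the model's max seq len): since `pos` is 0-based, `pos + 1` is the number of tokens already in the cache, so the loop must stop at `pos + 1 < kv_cache_size`. With the old condition `pos < kv_cache_size`, the last step would run at `pos == kv_cache_size - 1` and write one slot past the cache.

```cpp
#include <cstdio>

int main() {
  const int kv_cache_size = 8;  // stand-in for the model's max seq len
  int pos = 0;                  // 0-based position of the current token

  // Fixed condition: stop while pos + 1 (tokens in cache) is below capacity.
  while (pos + 1 < kv_cache_size) {
    // Each step consumes the token at `pos` and produces the token at `pos + 1`.
    printf("step at pos %d -> writes cache slot %d\n", pos, pos + 1);
    pos++;
  }

  // Cache is now exactly full: pos + 1 == kv_cache_size.
  printf("tokens in cache: %d (capacity %d)\n", pos + 1, kv_cache_size);
  return 0;
}
```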
Reviewed By: mergennachin, digantdesai
Differential Revision: D55369776
fbshipit-source-id: 7beb38177a23449649e96184b0b0a0bb507c199f
diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 4db94b8..cbefd69 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -238,7 +238,7 @@
}
// create a 1xN int tensor with next as value
- while (pos < seq_len) {
+ while (pos + 1 < seq_len) {
// ET_LOG(Info, "Generating step %d...", pos);
// set the current token in the tensor
std::vector<EValue> inputs;
@@ -339,11 +339,11 @@
timers_.inference_end_ms = util::time_in_ms();
printf("\n");
- if (pos == seq_len) {
+ if (pos + 1 == seq_len) {
ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
}
- timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
+ timers_.printReport(num_prompt_tokens, (pos + 1) - num_prompt_tokens);
delete[] prompt_tokens;
return Error::Ok;