Adjust max prefix len (#980)

This commit is contained in:
Liangsheng Yin
2024-08-07 17:41:26 -07:00
committed by GitHub
parent 7623091d97
commit 2b8257f325
2 changed files with 14 additions and 21 deletions

View File

@@ -163,11 +163,21 @@ class Req:
return self.finished_reason is not None
def adjust_max_prefix_ids(self):
    """Return the longest prefix of ``input_ids`` usable as a cached prefix.

    The prefix must be shortened so that the extend (non-cached) part of
    the request keeps enough tokens for downstream computations:

    - decoding (``max_new_tokens > 0``) needs at least 1 extend token to
      compute logits;
    - prompt logprobs must start no earlier than ``logprob_start_len``;
    - a normalized prompt logprob needs at least 2 extend tokens.

    Returns:
        The list slice ``self.input_ids[:max_prefix_len]``.
    """
    # NOTE(review): diff rendering had interleaved the pre-change body
    # (old `max_prefix_ids` variable and early return) with the new one;
    # this is the reconstructed post-commit implementation.
    input_len = len(self.input_ids)
    max_prefix_len = input_len

    if self.sampling_params.max_new_tokens > 0:
        # Need at least one token to compute logits
        max_prefix_len = min(max_prefix_len, input_len - 1)

    if self.return_logprob:
        max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        if self.normalized_prompt_logprob is None:
            # Need at least two tokens to compute normalized logprob
            max_prefix_len = min(max_prefix_len, input_len - 2)

    return self.input_ids[:max_prefix_len]
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
def init_incremental_detokenize(self):

View File

@@ -387,23 +387,6 @@ class ModelTpServer:
for req in self.waiting_queue:
# FIXME: Move this code into adjust_max_prefix_len
if req.return_logprob and req.normalized_prompt_logprob is None:
# Need at least two tokens to compute normalized logprob
if req.extend_input_len < 2:
delta = 2 - req.extend_input_len
req.extend_input_len += delta
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
req.image_offset += delta
if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
# Need at least one token to compute logits
req.extend_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None:
req.image_offset += 1
res = adder.add_one_req(req)
if (
not res