Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)

Author: Lianmin Zheng
Date: 2024-07-27 19:50:34 -07:00
Committed by: GitHub
Parent: 0a409bd438
Commit: 30db99b3d9
16 changed files with 188 additions and 184 deletions
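For reference, the rename maps four per-request fields and the matching meta_info keys one-to-one. The helper below is only an illustrative sketch, not part of this commit; the name remap_logprob_keys is invented.

# Illustrative sketch, not part of this commit: maps the old key names
# to the new ones introduced by this rename.
RENAMED_KEYS = {
    "prefill_token_logprobs": "input_token_logprobs",
    "decode_token_logprobs": "output_token_logprobs",
    "prefill_top_logprobs": "input_top_logprobs",
    "decode_top_logprobs": "output_top_logprobs",
}

def remap_logprob_keys(meta_info: dict) -> dict:
    """Translate a meta_info dict that still uses the pre-#776 key names."""
    return {RENAMED_KEYS.get(key, key): value for key, value in meta_info.items()}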


@@ -455,7 +455,7 @@ class ModelTpServer:
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
-            output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+            output.input_token_logprobs = output.input_token_logprobs.tolist()
             output.normalized_prompt_logprobs = (
                 output.normalized_prompt_logprobs.tolist()
             )
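The context lines above gather, for each request in the batch, the logprob of the token that was actually sampled. A minimal standalone sketch of that advanced-indexing pattern, with invented shapes and values (not taken from this file):

import torch

# Invented example: 4 requests, vocabulary of 32000 tokens.
logits = torch.randn(4, 32000)
next_token_logprobs = torch.log_softmax(logits, dim=-1)
next_token_ids = torch.tensor([11, 7, 523, 90])  # one sampled token id per request

# Row i is paired with column next_token_ids[i]: the logprob of each sampled token.
sampled_logprobs = next_token_logprobs[
    torch.arange(len(next_token_ids), device=next_token_ids.device),
    next_token_ids,
].tolist()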
@@ -481,24 +481,24 @@ class ModelTpServer:
                 if req.normalized_prompt_logprob is None:
                     req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-                if req.prefill_token_logprobs is None:
+                if req.input_token_logprobs is None:
                     # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                    req.prefill_token_logprobs = list(
+                    req.input_token_logprobs = list(
                         zip(
-                            output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                            output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
                             req.input_ids[-req.extend_input_len + 1 :],
                         )
                     )
                     if req.logprob_start_len == 0:
-                        req.prefill_token_logprobs = [
+                        req.input_token_logprobs = [
                             (None, req.input_ids[0])
-                        ] + req.prefill_token_logprobs
+                        ] + req.input_token_logprobs

                 if req.last_update_decode_tokens != 0:
-                    req.decode_token_logprobs.extend(
+                    req.output_token_logprobs.extend(
                         list(
                             zip(
-                                output.prefill_token_logprobs[
+                                output.input_token_logprobs[
                                     pt
                                     + req.extend_input_len
                                     - req.last_update_decode_tokens : pt
@@ -510,21 +510,21 @@ class ModelTpServer:
                             )
                         )

-                    req.decode_token_logprobs.append(
+                    req.output_token_logprobs.append(
                         (output.next_token_logprobs[i], next_token_ids[i])
                     )

                 if req.top_logprobs_num > 0:
-                    if req.prefill_top_logprobs is None:
-                        req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                    if req.input_top_logprobs is None:
+                        req.input_top_logprobs = output.input_top_logprobs[i]
                         if req.logprob_start_len == 0:
-                            req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+                            req.input_top_logprobs = [None] + req.input_top_logprobs

                     if req.last_update_decode_tokens != 0:
-                        req.decode_top_logprobs.extend(
-                            output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                        req.output_top_logprobs.extend(
+                            output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                         )
-                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+                    req.output_top_logprobs.append(output.output_top_logprobs[i])

     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
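The hunks above build req.input_token_logprobs as a list of (logprob, token_id) pairs aligned with the prompt, prepending (None, first_token_id) when logprob_start_len == 0 because the first token has no preceding context to be scored against. A simplified toy version of that pairing, with invented values and without the pt / extend_input_len bookkeeping:

# Toy illustration of the structure built above; values are invented.
input_ids = [1, 2023, 338, 263]        # prompt token ids
prompt_logprobs = [-2.1, -0.4, -1.3]   # logprob of each prompt token after the first

input_token_logprobs = [(None, input_ids[0])] + list(zip(prompt_logprobs, input_ids[1:]))
# -> [(None, 1), (-2.1, 2023), (-0.4, 338), (-1.3, 263)]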
@@ -589,11 +589,11 @@ class ModelTpServer:
             req.check_finished()

             if req.return_logprob:
-                req.decode_token_logprobs.append(
+                req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
                 )

                 if req.top_logprobs_num > 0:
-                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+                    req.output_top_logprobs.append(output.output_top_logprobs[i])

         self.handle_finished_requests(batch)
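In the decode path above, each step appends one (logprob, token_id) pair to req.output_token_logprobs, so a sequence-level score can later be recovered by summing the first elements. A small sketch under that assumption, with invented values:

# Invented example of the per-step accumulation shown above.
output_token_logprobs = [(-0.7, 13), (-1.9, 450), (-0.2, 2)]

total_logprob = sum(lp for lp, _ in output_token_logprobs)  # cumulative logprob of the output
generated_ids = [tok for _, tok in output_token_logprobs]   # the generated token ids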
@@ -645,16 +645,16 @@ class ModelTpServer:
                 }
                 if req.return_logprob:
                     (
-                        meta_info["prefill_token_logprobs"],
-                        meta_info["decode_token_logprobs"],
-                        meta_info["prefill_top_logprobs"],
-                        meta_info["decode_top_logprobs"],
+                        meta_info["input_token_logprobs"],
+                        meta_info["output_token_logprobs"],
+                        meta_info["input_top_logprobs"],
+                        meta_info["output_top_logprobs"],
                         meta_info["normalized_prompt_logprob"],
                     ) = (
-                        req.prefill_token_logprobs,
-                        req.decode_token_logprobs,
-                        req.prefill_top_logprobs,
-                        req.decode_top_logprobs,
+                        req.input_token_logprobs,
+                        req.output_token_logprobs,
+                        req.input_top_logprobs,
+                        req.output_top_logprobs,
                         req.normalized_prompt_logprob,
                     )
                 output_meta_info.append(meta_info)
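On the consumer side, callers that previously read the prefill_*/decode_* keys from meta_info now read the input_*/output_* keys. A hedged sketch of that access; the meta_info dict below is fabricated for illustration, not real server output:

# Fabricated meta_info using the key names introduced by this commit.
meta_info = {
    "input_token_logprobs": [(None, 1), (-2.1, 2023), (-0.4, 338)],
    "output_token_logprobs": [(-0.7, 13), (-0.2, 2)],
    "normalized_prompt_logprob": -1.25,
}

prompt_logprob_sum = sum(lp for lp, _ in meta_info["input_token_logprobs"] if lp is not None)
output_logprob_sum = sum(lp for lp, _ in meta_info["output_token_logprobs"])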