Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)
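This is a mechanical rename: every prefill_* field becomes input_* and every decode_* field becomes output_*, including the keys of the meta_info dict returned to clients (last hunk below). As a minimal compatibility sketch, assuming a response whose meta_info uses the keys shown in this diff, a client could prefer the new names and fall back to the old ones; read_logprob_fields is a hypothetical helper, not part of the codebase:

def read_logprob_fields(meta_info: dict):
    # Prefer the post-rename keys; fall back to the pre-rename keys for
    # responses produced by servers older than this commit.
    input_lp = meta_info.get("input_token_logprobs",
                             meta_info.get("prefill_token_logprobs"))
    output_lp = meta_info.get("output_token_logprobs",
                              meta_info.get("decode_token_logprobs"))
    return input_lp, output_lp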
@@ -455,7 +455,7 @@ class ModelTpServer:
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
-            output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+            output.input_token_logprobs = output.input_token_logprobs.tolist()
             output.normalized_prompt_logprobs = (
                 output.normalized_prompt_logprobs.tolist()
             )
@@ -481,24 +481,24 @@ class ModelTpServer:
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-        if req.prefill_token_logprobs is None:
+        if req.input_token_logprobs is None:
             # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-            req.prefill_token_logprobs = list(
+            req.input_token_logprobs = list(
                 zip(
-                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
                     req.input_ids[-req.extend_input_len + 1 :],
                 )
             )
             if req.logprob_start_len == 0:
-                req.prefill_token_logprobs = [
+                req.input_token_logprobs = [
                     (None, req.input_ids[0])
-                ] + req.prefill_token_logprobs
+                ] + req.input_token_logprobs

         if req.last_update_decode_tokens != 0:
-            req.decode_token_logprobs.extend(
+            req.output_token_logprobs.extend(
                 list(
                     zip(
-                        output.prefill_token_logprobs[
+                        output.input_token_logprobs[
                             pt
                             + req.extend_input_len
                             - req.last_update_decode_tokens : pt
@@ -510,21 +510,21 @@ class ModelTpServer:
             )
         )

-        req.decode_token_logprobs.append(
+        req.output_token_logprobs.append(
             (output.next_token_logprobs[i], next_token_ids[i])
         )

         if req.top_logprobs_num > 0:
-            if req.prefill_top_logprobs is None:
-                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+            if req.input_top_logprobs is None:
+                req.input_top_logprobs = output.input_top_logprobs[i]
                 if req.logprob_start_len == 0:
-                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+                    req.input_top_logprobs = [None] + req.input_top_logprobs

             if req.last_update_decode_tokens != 0:
-                req.decode_top_logprobs.extend(
-                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                req.output_top_logprobs.extend(
+                    output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                 )
-            req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+            req.output_top_logprobs.append(output.output_top_logprobs[i])

     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
@@ -589,11 +589,11 @@ class ModelTpServer:
             req.check_finished()

             if req.return_logprob:
-                req.decode_token_logprobs.append(
+                req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
                 )
                 if req.top_logprobs_num > 0:
-                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+                    req.output_top_logprobs.append(output.output_top_logprobs[i])

         self.handle_finished_requests(batch)

@@ -645,16 +645,16 @@ class ModelTpServer:
                 }
                 if req.return_logprob:
                     (
-                        meta_info["prefill_token_logprobs"],
-                        meta_info["decode_token_logprobs"],
-                        meta_info["prefill_top_logprobs"],
-                        meta_info["decode_top_logprobs"],
+                        meta_info["input_token_logprobs"],
+                        meta_info["output_token_logprobs"],
+                        meta_info["input_top_logprobs"],
+                        meta_info["output_top_logprobs"],
                         meta_info["normalized_prompt_logprob"],
                     ) = (
-                        req.prefill_token_logprobs,
-                        req.decode_token_logprobs,
-                        req.prefill_top_logprobs,
-                        req.decode_top_logprobs,
+                        req.input_token_logprobs,
+                        req.output_token_logprobs,
+                        req.input_top_logprobs,
+                        req.output_top_logprobs,
                         req.normalized_prompt_logprob,
                     )
                 output_meta_info.append(meta_info)
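For reference, a toy illustration of the (logprob, token_id) pairing that the renamed fields carry, mirroring the zip(...) and the (None, req.input_ids[0]) prepend in the hunk at -481,24; all token ids and logprob values below are made up:

prompt_ids = [101, 2023, 2003]   # hypothetical token ids
prompt_logprobs = [-1.9, -0.4]   # one logprob per prompt token after the first
input_token_logprobs = [(None, prompt_ids[0])] + list(zip(prompt_logprobs, prompt_ids[1:]))
# -> [(None, 101), (-1.9, 2023), (-0.4, 2003)]
# The first prompt token gets a None logprob because nothing precedes it;
# output_token_logprobs then grows by one (logprob, token_id) pair per decoded token.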