Refactor logprob computation to return the real logprob used in sampling (#2664)

This commit is contained in:
Lianmin Zheng
2024-12-30 04:51:38 -08:00
committed by GitHub
parent b02da24a5b
commit 9c6ba2484f
9 changed files with 305 additions and 312 deletions

View File

@@ -974,12 +974,10 @@ class Scheduler:
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
else:
# Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].tolist()
logits_output.next_token_logprobs.tolist()
)
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
@@ -987,7 +985,6 @@ class Scheduler:
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
# Check finish conditions
logprob_pt = 0
@@ -1064,13 +1061,9 @@ class Scheduler:
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
next_token_logprobs = logits_output.next_token_logprobs
else:
# Move next_token_ids and logprobs to cpu
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist()
self.token_to_kv_pool.free_group_begin()
@@ -1095,10 +1088,10 @@ class Scheduler:
req.output_token_logprobs_idx.append(next_token_id)
if req.top_logprobs_num > 0:
req.output_top_logprobs_val.append(
logits_output.output_top_logprobs_val[i]
logits_output.next_token_top_logprobs_val[i]
)
req.output_top_logprobs_idx.append(
logits_output.output_top_logprobs_idx[i]
logits_output.next_token_top_logprobs_idx[i]
)
if req.grammar is not None:
@@ -1200,8 +1193,9 @@ class Scheduler:
req.output_top_logprobs_idx.extend(
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
return num_input_logprobs