Refactor logprob computation to return the real logprob used in sampling (#2664)
This commit is contained in:
@@ -974,12 +974,10 @@ class Scheduler:
|
||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
else:
|
||||
# Move next_token_ids and logprobs to cpu
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
if batch.return_logprob:
|
||||
logits_output.next_token_logprobs = (
|
||||
logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].tolist()
|
||||
logits_output.next_token_logprobs.tolist()
|
||||
)
|
||||
logits_output.input_token_logprobs = (
|
||||
logits_output.input_token_logprobs.tolist()
|
||||
@@ -987,7 +985,6 @@ class Scheduler:
|
||||
logits_output.normalized_prompt_logprobs = (
|
||||
logits_output.normalized_prompt_logprobs.tolist()
|
||||
)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
|
||||
# Check finish conditions
|
||||
logprob_pt = 0
|
||||
@@ -1064,13 +1061,9 @@ class Scheduler:
|
||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
next_token_logprobs = logits_output.next_token_logprobs
|
||||
else:
|
||||
# Move next_token_ids and logprobs to cpu
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].tolist()
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||
|
||||
self.token_to_kv_pool.free_group_begin()
|
||||
|
||||
@@ -1095,10 +1088,10 @@ class Scheduler:
|
||||
req.output_token_logprobs_idx.append(next_token_id)
|
||||
if req.top_logprobs_num > 0:
|
||||
req.output_top_logprobs_val.append(
|
||||
logits_output.output_top_logprobs_val[i]
|
||||
logits_output.next_token_top_logprobs_val[i]
|
||||
)
|
||||
req.output_top_logprobs_idx.append(
|
||||
logits_output.output_top_logprobs_idx[i]
|
||||
logits_output.next_token_top_logprobs_idx[i]
|
||||
)
|
||||
|
||||
if req.grammar is not None:
|
||||
@@ -1200,8 +1193,9 @@ class Scheduler:
|
||||
req.output_top_logprobs_idx.extend(
|
||||
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
|
||||
)
|
||||
req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
|
||||
req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
|
||||
|
||||
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
|
||||
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
|
||||
|
||||
return num_input_logprobs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user