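Fix input logprob index (#9841)
Guard the input-logprob bookkeeping on `batch.return_logprob` instead of `req.return_logprob`, and apply the per-request `req.return_logprob` check only to the `add_logprob_return_values` / `add_input_logprob_return_values` calls, presumably so that `logprob_pt` still advances past requests in the batch that did not ask for logprobs.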
Co-authored-by: Sheng Shen <sheng.s@berkeley.edu>
This commit is contained in:
@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin:
|
|||||||
# This updates radix so others can match
|
# This updates radix so others can match
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
if req.return_logprob:
|
if batch.return_logprob:
|
||||||
assert extend_logprob_start_len_per_req is not None
|
assert extend_logprob_start_len_per_req is not None
|
||||||
assert extend_input_len_per_req is not None
|
assert extend_input_len_per_req is not None
|
||||||
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
||||||
extend_input_len = extend_input_len_per_req[i]
|
extend_input_len = extend_input_len_per_req[i]
|
||||||
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
||||||
self.add_logprob_return_values(
|
if req.return_logprob:
|
||||||
i,
|
self.add_logprob_return_values(
|
||||||
req,
|
i,
|
||||||
logprob_pt,
|
req,
|
||||||
next_token_ids,
|
logprob_pt,
|
||||||
num_input_logprobs,
|
next_token_ids,
|
||||||
logits_output,
|
num_input_logprobs,
|
||||||
)
|
logits_output,
|
||||||
|
)
|
||||||
logprob_pt += num_input_logprobs
|
logprob_pt += num_input_logprobs
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
skip_stream_req = req
|
skip_stream_req = req
|
||||||
|
|
||||||
# Incrementally update input logprobs.
|
# Incrementally update input logprobs.
|
||||||
if req.return_logprob:
|
if batch.return_logprob:
|
||||||
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
||||||
extend_input_len = extend_input_len_per_req[i]
|
extend_input_len = extend_input_len_per_req[i]
|
||||||
if extend_logprob_start_len < extend_input_len:
|
if extend_logprob_start_len < extend_input_len:
|
||||||
@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin:
|
|||||||
num_input_logprobs = (
|
num_input_logprobs = (
|
||||||
extend_input_len - extend_logprob_start_len
|
extend_input_len - extend_logprob_start_len
|
||||||
)
|
)
|
||||||
self.add_input_logprob_return_values(
|
if req.return_logprob:
|
||||||
i,
|
self.add_input_logprob_return_values(
|
||||||
req,
|
i,
|
||||||
logits_output,
|
req,
|
||||||
logprob_pt,
|
logits_output,
|
||||||
num_input_logprobs,
|
logprob_pt,
|
||||||
last_prefill_chunk=False,
|
num_input_logprobs,
|
||||||
)
|
last_prefill_chunk=False,
|
||||||
|
)
|
||||||
logprob_pt += num_input_logprobs
|
logprob_pt += num_input_logprobs
|
||||||
|
|
||||||
self.set_next_batch_sampling_info_done(batch)
|
self.set_next_batch_sampling_info_done(batch)
|
||||||
|
|||||||
Reference in New Issue
Block a user