support return logprobs for pipeline (#7356)
Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
@@ -812,6 +812,23 @@ class Scheduler(
|
|||||||
result.next_token_ids,
|
result.next_token_ids,
|
||||||
result.bid,
|
result.bid,
|
||||||
)
|
)
|
||||||
|
if self.cur_batch.return_logprob:
|
||||||
|
pp_outputs = PPProxyTensors(
|
||||||
|
{
|
||||||
|
"next_token_ids": next_token_ids,
|
||||||
|
"extend_input_len_per_req": result.extend_input_len_per_req,
|
||||||
|
"extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
|
||||||
|
}
|
||||||
|
| (
|
||||||
|
{
|
||||||
|
f"logits_output.{k}": v
|
||||||
|
for k, v in result.logits_output.__dict__.items()
|
||||||
|
}
|
||||||
|
if result.logits_output is not None
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
pp_outputs = PPProxyTensors(
|
pp_outputs = PPProxyTensors(
|
||||||
{
|
{
|
||||||
"next_token_ids": next_token_ids,
|
"next_token_ids": next_token_ids,
|
||||||
@@ -833,12 +850,25 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
|
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
|
||||||
|
logits_output_args = {
|
||||||
|
k[len("logits_output.") :]: v
|
||||||
|
for k, v in next_pp_outputs.tensors.items()
|
||||||
|
if k.startswith("logits_output.")
|
||||||
|
}
|
||||||
|
if len(logits_output_args) > 0:
|
||||||
|
logits_output = LogitsProcessorOutput(**logits_output_args)
|
||||||
|
else:
|
||||||
|
logits_output = None
|
||||||
output_result = GenerationBatchResult(
|
output_result = GenerationBatchResult(
|
||||||
logits_output=None,
|
logits_output=logits_output,
|
||||||
pp_hidden_states_proxy_tensors=None,
|
pp_hidden_states_proxy_tensors=None,
|
||||||
next_token_ids=next_pp_outputs["next_token_ids"],
|
next_token_ids=next_pp_outputs["next_token_ids"],
|
||||||
extend_input_len_per_req=None,
|
extend_input_len_per_req=next_pp_outputs.tensors.get(
|
||||||
extend_logprob_start_len_per_req=None,
|
"extend_input_len_per_req", None
|
||||||
|
),
|
||||||
|
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
|
||||||
|
"extend_logprob_start_len_per_req", None
|
||||||
|
),
|
||||||
bid=bids[next_mb_id],
|
bid=bids[next_mb_id],
|
||||||
can_run_cuda_graph=result.can_run_cuda_graph,
|
can_run_cuda_graph=result.can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user