Support returning logprobs for pipeline parallelism (#7356)
Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
@@ -812,11 +812,28 @@ class Scheduler(
|
||||
result.next_token_ids,
|
||||
result.bid,
|
||||
)
|
||||
pp_outputs = PPProxyTensors(
|
||||
{
|
||||
"next_token_ids": next_token_ids,
|
||||
}
|
||||
)
|
||||
if self.cur_batch.return_logprob:
|
||||
pp_outputs = PPProxyTensors(
|
||||
{
|
||||
"next_token_ids": next_token_ids,
|
||||
"extend_input_len_per_req": result.extend_input_len_per_req,
|
||||
"extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
|
||||
}
|
||||
| (
|
||||
{
|
||||
f"logits_output.{k}": v
|
||||
for k, v in result.logits_output.__dict__.items()
|
||||
}
|
||||
if result.logits_output is not None
|
||||
else {}
|
||||
)
|
||||
)
|
||||
else:
|
||||
pp_outputs = PPProxyTensors(
|
||||
{
|
||||
"next_token_ids": next_token_ids,
|
||||
}
|
||||
)
|
||||
# Send the output from the last round so that the next-stage worker can run post-processing.
|
||||
self.pp_group.send_tensor_dict(
|
||||
pp_outputs.tensors,
|
||||
@@ -833,12 +850,25 @@ class Scheduler(
|
||||
)
|
||||
)
|
||||
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
|
||||
logits_output_args = {
|
||||
k[len("logits_output.") :]: v
|
||||
for k, v in next_pp_outputs.tensors.items()
|
||||
if k.startswith("logits_output.")
|
||||
}
|
||||
if len(logits_output_args) > 0:
|
||||
logits_output = LogitsProcessorOutput(**logits_output_args)
|
||||
else:
|
||||
logits_output = None
|
||||
output_result = GenerationBatchResult(
|
||||
logits_output=None,
|
||||
logits_output=logits_output,
|
||||
pp_hidden_states_proxy_tensors=None,
|
||||
next_token_ids=next_pp_outputs["next_token_ids"],
|
||||
extend_input_len_per_req=None,
|
||||
extend_logprob_start_len_per_req=None,
|
||||
extend_input_len_per_req=next_pp_outputs.tensors.get(
|
||||
"extend_input_len_per_req", None
|
||||
),
|
||||
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
|
||||
"extend_logprob_start_len_per_req", None
|
||||
),
|
||||
bid=bids[next_mb_id],
|
||||
can_run_cuda_graph=result.can_run_cuda_graph,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user