From ceba0ce4f661722198f6568a54ba20cf06b7e033 Mon Sep 17 00:00:00 2001
From: strgrb
Date: Fri, 20 Jun 2025 14:50:45 +0800
Subject: [PATCH] support return logprobs for pipeline (#7356)

Co-authored-by: Zhang Kaihong
---
 python/sglang/srt/managers/scheduler.py | 46 ++++++++++++++++++++-----
 1 file changed, 38 insertions(+), 8 deletions(-)

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 53e9b02e3..fabb7e16e 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -812,11 +812,28 @@ class Scheduler(
                     result.next_token_ids,
                     result.bid,
                 )
-                pp_outputs = PPProxyTensors(
-                    {
-                        "next_token_ids": next_token_ids,
-                    }
-                )
+                if self.cur_batch.return_logprob:
+                    pp_outputs = PPProxyTensors(
+                        {
+                            "next_token_ids": next_token_ids,
+                            "extend_input_len_per_req": result.extend_input_len_per_req,
+                            "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+                        }
+                        | (
+                            {
+                                f"logits_output.{k}": v
+                                for k, v in result.logits_output.__dict__.items()
+                            }
+                            if result.logits_output is not None
+                            else {}
+                        )
+                    )
+                else:
+                    pp_outputs = PPProxyTensors(
+                        {
+                            "next_token_ids": next_token_ids,
+                        }
+                    )
                 # send the output from the last round to let the next stage worker run post processing
                 self.pp_group.send_tensor_dict(
                     pp_outputs.tensors,
@@ -833,12 +850,25 @@ class Scheduler(
                         )
                     )
                     mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    logits_output_args = {
+                        k[len("logits_output.") :]: v
+                        for k, v in next_pp_outputs.tensors.items()
+                        if k.startswith("logits_output.")
+                    }
+                    if len(logits_output_args) > 0:
+                        logits_output = LogitsProcessorOutput(**logits_output_args)
+                    else:
+                        logits_output = None
                     output_result = GenerationBatchResult(
-                        logits_output=None,
+                        logits_output=logits_output,
                         pp_hidden_states_proxy_tensors=None,
                         next_token_ids=next_pp_outputs["next_token_ids"],
-                        extend_input_len_per_req=None,
-                        extend_logprob_start_len_per_req=None,
+                        extend_input_len_per_req=next_pp_outputs.tensors.get(
+                            "extend_input_len_per_req", None
+                        ),
+                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
+                            "extend_logprob_start_len_per_req", None
+                        ),
                         bid=bids[next_mb_id],
                         can_run_cuda_graph=result.can_run_cuda_graph,
                     )