diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 53e9b02e3..fabb7e16e 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -812,11 +812,28 @@ class Scheduler(
                             result.next_token_ids,
                             result.bid,
                         )
-                        pp_outputs = PPProxyTensors(
-                            {
-                                "next_token_ids": next_token_ids,
-                            }
-                        )
+                        if self.cur_batch.return_logprob:
+                            pp_outputs = PPProxyTensors(
+                                {
+                                    "next_token_ids": next_token_ids,
+                                    "extend_input_len_per_req": result.extend_input_len_per_req,
+                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+                                }
+                                | (
+                                    {
+                                        f"logits_output.{k}": v
+                                        for k, v in result.logits_output.__dict__.items()
+                                    }
+                                    if result.logits_output is not None
+                                    else {}
+                                )
+                            )
+                        else:
+                            pp_outputs = PPProxyTensors(
+                                {
+                                    "next_token_ids": next_token_ids,
+                                }
+                            )
                         # send the output from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
                             pp_outputs.tensors,
@@ -833,12 +850,25 @@ class Scheduler(
                             )
                         )
                         mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                        logits_output_args = {
+                            k[len("logits_output.") :]: v
+                            for k, v in next_pp_outputs.tensors.items()
+                            if k.startswith("logits_output.")
+                        }
+                        if len(logits_output_args) > 0:
+                            logits_output = LogitsProcessorOutput(**logits_output_args)
+                        else:
+                            logits_output = None
                         output_result = GenerationBatchResult(
-                            logits_output=None,
+                            logits_output=logits_output,
                             pp_hidden_states_proxy_tensors=None,
                             next_token_ids=next_pp_outputs["next_token_ids"],
-                            extend_input_len_per_req=None,
-                            extend_logprob_start_len_per_req=None,
+                            extend_input_len_per_req=next_pp_outputs.tensors.get(
+                                "extend_input_len_per_req", None
+                            ),
+                            extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
+                                "extend_logprob_start_len_per_req", None
+                            ),
                             bid=bids[next_mb_id],
                             can_run_cuda_graph=result.can_run_cuda_graph,
                         )
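
Note: the heart of this patch is a serialization trick. A structured object like LogitsProcessorOutput cannot cross a pipeline stage directly, so the last rank flattens its fields into the same flat tensor dict that send_tensor_dict already ships, under a "logits_output." key prefix; the receiving side filters on that prefix to rebuild the object, or falls back to None when nothing was sent. Below is a minimal sketch of that round trip, with a hypothetical Payload dataclass and flatten/rebuild helpers standing in for the inlined dict comprehensions in the diff (neither helper is part of sglang's API):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Payload:
    # Stand-in for LogitsProcessorOutput: a bag of optional tensor fields.
    next_token_logprobs: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None


PREFIX = "payload."  # analogous to the "logits_output." prefix in the diff


def flatten(out: Optional[Payload]) -> dict:
    # Prefix every field name so the object can ride along in the same
    # flat {name: tensor} dict that a send_tensor_dict-style API expects.
    if out is None:
        return {}
    return {f"{PREFIX}{k}": v for k, v in out.__dict__.items()}


def rebuild(tensors: dict) -> Optional[Payload]:
    # Pick out the prefixed keys, strip the prefix, and reconstruct the
    # dataclass; no matching keys means the sender skipped the payload.
    args = {k[len(PREFIX):]: v for k, v in tensors.items() if k.startswith(PREFIX)}
    return Payload(**args) if args else None


# Round trip: the receiving rank rebuilds exactly what the last rank flattened.
sent = {"next_token_ids": torch.tensor([1, 2])} | flatten(
    Payload(next_token_logprobs=torch.tensor([-0.1, -0.2]))
)
assert rebuild(sent).next_token_logprobs is not None
assert rebuild({"next_token_ids": torch.tensor([1, 2])}) is None

The prefix doubles as a presence flag: an empty match after filtering reproduces the patch's logits_output = None branch, so batches that do not request logprobs send nothing extra and decode exactly as before.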