From e23e280e1623c6b16fa0b45f77942c31782b878f Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Sun, 28 Sep 2025 00:09:38 +0800 Subject: [PATCH] Add support for topk metadata transferring for PD (#10616) Signed-off-by: Shangming Cai --- python/sglang/srt/disaggregation/decode.py | 4 ++ .../decode_schedule_batch_mixin.py | 38 +++++++++++-------- python/sglang/srt/disaggregation/prefill.py | 2 + python/sglang/srt/disaggregation/utils.py | 30 +++++++++++++-- python/sglang/srt/managers/schedule_batch.py | 2 + python/sglang/srt/managers/scheduler.py | 4 +- 6 files changed, 60 insertions(+), 20 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 32128f480..1db475f15 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -614,12 +614,16 @@ class DecodeTransferQueue: output_token_logprobs_idx, output_top_logprobs_val, output_top_logprobs_idx, + output_topk_p, + output_topk_index, output_hidden_states, ) = self.metadata_buffers.get_buf(idx) decode_req.req.output_ids.append(output_id[0].item()) decode_req.req.cached_tokens = cached_tokens[0].item() if not self.spec_algorithm.is_none(): + decode_req.req.output_topk_p = output_topk_p + decode_req.req.output_topk_index = output_topk_index decode_req.req.hidden_states_tensor = output_hidden_states if decode_req.req.return_logprob: decode_req.req.output_token_logprobs_val.append( diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index be0383eec..e2ae55780 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -125,25 +125,33 @@ class ScheduleBatchDisaggregationDecodeMixin: req.grammar.finished = req.finished() self.output_ids = torch.tensor(self.output_ids, device=self.device) - # Simulate the eagle run. We add mock data to hidden states for the - # ease of implementation now meaning the first token will have acc rate - # of 0. - if not self.spec_algorithm.is_none(): + # Simulate the eagle run. + if self.spec_algorithm.is_eagle(): b = len(self.reqs) - topk_p = torch.arange( - b * server_args.speculative_eagle_topk, - 0, - -1, - device=self.device, - dtype=torch.float32, + topk = server_args.speculative_eagle_topk + topk_p = torch.stack( + [ + torch.as_tensor( + req.output_topk_p[:topk], + device=self.device, + dtype=torch.float32, + ) + for req in self.reqs + ], + dim=0, ) - topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk) - topk_p /= b * server_args.speculative_eagle_topk - topk_index = torch.arange( - b * server_args.speculative_eagle_topk, device=self.device + topk_index = torch.stack( + [ + torch.as_tensor( + req.output_topk_index[:topk], + device=self.device, + dtype=torch.int64, + ) + for req in self.reqs + ], + dim=0, ) - topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk) hidden_states_list = [req.hidden_states_tensor for req in self.reqs] hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 5b9255e31..3f794ea3a 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -421,6 +421,8 @@ class SchedulerDisaggregationPrefillMixin: last_hidden_index = ( hidden_state_offset + extend_input_len_per_req[i] - 1 ) + req.output_topk_p = batch.spec_info.topk_p[i] + req.output_topk_index = batch.spec_info.topk_index[i] if self.spec_algorithm.is_eagle3(): req.hidden_states_tensor = ( batch.spec_info.hidden_states[i].cpu().clone() diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 1ea1cc6c6..fe4e7fb9f 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -85,7 +85,7 @@ class MetadataBuffers: self, size: int, hidden_size: int, - dtype: torch.dtype, + hidden_states_dtype: torch.dtype, max_top_logprobs_num: int = 128, custom_mem_pool: torch.cuda.MemPool = None, ): @@ -122,8 +122,15 @@ class MetadataBuffers: self.output_top_logprobs_idx = torch.zeros( (size, max_top_logprobs_num), dtype=torch.int32, device=device ) + # For PD + spec decode + self.output_topk_p = torch.zeros( + (size, 16), dtype=torch.float32, device=device + ) + self.output_topk_index = torch.zeros( + (size, 16), dtype=torch.int64, device=device + ) self.output_hidden_states = torch.zeros( - (size, hidden_size), dtype=dtype, device=device + (size, hidden_size), dtype=hidden_states_dtype, device=device ) def get_buf_infos(self): @@ -134,6 +141,8 @@ class MetadataBuffers: self.output_token_logprobs_idx.data_ptr(), self.output_top_logprobs_val.data_ptr(), self.output_top_logprobs_idx.data_ptr(), + self.output_topk_p.data_ptr(), + self.output_topk_index.data_ptr(), self.output_hidden_states.data_ptr(), ] data_lens = [ @@ -143,6 +152,8 @@ class MetadataBuffers: self.output_token_logprobs_idx.nbytes, self.output_top_logprobs_val.nbytes, self.output_top_logprobs_idx.nbytes, + self.output_topk_p.nbytes, + self.output_topk_index.nbytes, self.output_hidden_states.nbytes, ] item_lens = [ @@ -152,6 +163,8 @@ class MetadataBuffers: self.output_token_logprobs_idx[0].nbytes, self.output_top_logprobs_val[0].nbytes, self.output_top_logprobs_idx[0].nbytes, + self.output_topk_p[0].nbytes, + self.output_topk_index[0].nbytes, self.output_hidden_states[0].nbytes, ] return ptrs, data_lens, item_lens @@ -164,6 +177,8 @@ class MetadataBuffers: self.output_token_logprobs_idx[idx], self.output_top_logprobs_val[idx], self.output_top_logprobs_idx[idx], + self.output_topk_p[idx], + self.output_topk_index[idx], self.output_hidden_states[idx], ) @@ -193,8 +208,17 @@ class MetadataBuffers: ] = torch.tensor( req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu" ) - # for PD + spec decode + # For PD + spec decode if req.hidden_states_tensor is not None: + # speculative_eagle_topk should not be greater than 16 currently + topk = req.output_topk_p.size(0) + + self.output_topk_p[req.metadata_buffer_index, :topk].copy_( + req.output_topk_p + ) + self.output_topk_index[req.metadata_buffer_index, :topk].copy_( + req.output_topk_index + ) self.output_hidden_states[req.metadata_buffer_index].copy_( req.hidden_states_tensor ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3a3a6b06b..6457307c1 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -607,6 +607,8 @@ class Req: ) = None self.hidden_states: List[List[float]] = [] self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP + self.output_topk_p = None + self.output_topk_index = None # Embedding (return values) self.embedding = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4404e1fc6..94cd8e16f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -806,7 +806,7 @@ class Scheduler( self.disagg_metadata_buffers = MetadataBuffers( buffer_size, hidden_size=self.model_config.hf_text_config.hidden_size, - dtype=self.model_config.dtype, + hidden_states_dtype=self.model_config.dtype, custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), ) @@ -855,7 +855,7 @@ class Scheduler( self.disagg_metadata_buffers = MetadataBuffers( buffer_size, hidden_size=self.model_config.hf_text_config.hidden_size, - dtype=self.model_config.dtype, + hidden_states_dtype=self.model_config.dtype, custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), )