diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 12781784e..b31bb77f8 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -541,6 +541,7 @@ class DecodeTransferQueue: self.metadata_buffers = metadata_buffers self.scheduler = scheduler self.tree_cache = tree_cache + self.spec_algorithm = scheduler.spec_algorithm def add(self, decode_req: DecodeRequest) -> None: self.queue.append(decode_req) @@ -582,6 +583,7 @@ class DecodeTransferQueue: idx = decode_req.metadata_buffer_index ( output_id, + output_hidden_states, output_token_logprobs_val, output_token_logprobs_idx, output_top_logprobs_val, @@ -589,7 +591,8 @@ class DecodeTransferQueue: ) = self.metadata_buffers.get_buf(idx) decode_req.req.output_ids.append(output_id[0].item()) - + if not self.spec_algorithm.is_none(): + decode_req.req.hidden_states_tensor = output_hidden_states if decode_req.req.return_logprob: decode_req.req.output_token_logprobs_val.append( output_token_logprobs_val[0].item() diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index 38e936106..e1d6f61cc 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin: ) topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk) + hidden_states_list = [req.hidden_states_tensor for req in self.reqs] + hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device) + # local import to avoid circular import from sglang.srt.speculative.eagle_utils import EagleDraftInput spec_info = EagleDraftInput( topk_p=topk_p, topk_index=topk_index, - hidden_states=torch.ones( - (b, model_config.hidden_size), device=self.device - ), + hidden_states=hidden_states, verified_id=self.output_ids, ) spec_info.prepare_for_extend(self) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 2bb6312eb..3fc27f841 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -393,6 +393,8 @@ class SchedulerDisaggregationPrefillMixin: logits_output.input_token_logprobs = tuple( logits_output.input_token_logprobs.tolist() ) + + hidden_state_offset = 0 for i, (req, next_token_id) in enumerate( zip(batch.reqs, next_token_ids, strict=True) ): @@ -402,6 +404,16 @@ class SchedulerDisaggregationPrefillMixin: req.output_ids.append(next_token_id) self.tree_cache.cache_unfinished_req(req) # update the tree and lock self.disagg_prefill_inflight_queue.append(req) + if logits_output.hidden_states is not None: + last_hidden_index = ( + hidden_state_offset + extend_input_len_per_req[i] - 1 + ) + req.hidden_states_tensor = ( + logits_output.hidden_states[last_hidden_index].cpu().clone() + ) + hidden_state_offset += extend_input_len_per_req[i] + else: + req.hidden_states_tensor = None if req.return_logprob: assert extend_logprob_start_len_per_req is not None assert extend_input_len_per_req is not None diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index cd46846ff..4a60ba0f9 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -88,6 +88,8 @@ class MetadataBuffers: def __init__( self, size: int, + hidden_size: int, + dtype: torch.dtype, max_top_logprobs_num: int = 128, custom_mem_pool: torch.cuda.MemPool = None, ): @@ -104,6 +106,10 @@ class MetadataBuffers: # We transfer the metadata of first output token to decode # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device) + + self.output_hidden_states = torch.zeros( + (size, hidden_size), dtype=dtype, device=device + ) self.output_token_logprobs_val = torch.zeros( (size, 16), dtype=torch.float32, device=device ) @@ -120,6 +126,7 @@ class MetadataBuffers: def get_buf_infos(self): ptrs = [ self.output_ids.data_ptr(), + self.output_hidden_states.data_ptr(), # TODO: set None to avoid transfer hidden_states when spec_algorithm is None self.output_token_logprobs_val.data_ptr(), self.output_token_logprobs_idx.data_ptr(), self.output_top_logprobs_val.data_ptr(), @@ -127,6 +134,7 @@ class MetadataBuffers: ] data_lens = [ self.output_ids.nbytes, + self.output_hidden_states.nbytes, self.output_token_logprobs_val.nbytes, self.output_token_logprobs_idx.nbytes, self.output_top_logprobs_val.nbytes, @@ -134,6 +142,7 @@ class MetadataBuffers: ] item_lens = [ self.output_ids[0].nbytes, + self.output_hidden_states[0].nbytes, self.output_token_logprobs_val[0].nbytes, self.output_token_logprobs_idx[0].nbytes, self.output_top_logprobs_val[0].nbytes, @@ -144,6 +153,7 @@ class MetadataBuffers: def get_buf(self, idx: int): return ( self.output_ids[idx], + self.output_hidden_states[idx], self.output_token_logprobs_val[idx], self.output_token_logprobs_idx[idx], self.output_top_logprobs_val[idx], @@ -153,6 +163,10 @@ class MetadataBuffers: def set_buf(self, req: Req): self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0] + if req.hidden_states_tensor is not None: + self.output_hidden_states[req.metadata_buffer_index].copy_( + req.hidden_states_tensor + ) if req.return_logprob: if req.output_token_logprobs_val: # not none or empty list self.output_token_logprobs_val[req.metadata_buffer_index][0] = ( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 74a396aa3..6143c5575 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -584,6 +584,7 @@ class Req: self.output_token_ids_logprobs_idx ) = None self.hidden_states: List[List[float]] = [] + self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP # Embedding (return values) self.embedding = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b2484f090..a0c2997fe 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -627,6 +627,8 @@ class Scheduler( ) self.disagg_metadata_buffers = MetadataBuffers( buffer_size, + hidden_size=self.model_config.hf_text_config.hidden_size, + dtype=self.model_config.dtype, custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), ) @@ -677,6 +679,8 @@ class Scheduler( ) self.disagg_metadata_buffers = MetadataBuffers( buffer_size, + hidden_size=self.model_config.hf_text_config.hidden_size, + dtype=self.model_config.dtype, custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), ) @@ -1681,13 +1685,15 @@ class Scheduler( # These 2 values are needed for processing the output, but the values can be # modified by overlap schedule. So we have to copy them here so that # we can use the correct values in output processing. - if batch.return_logprob: + if batch.return_logprob or self.spec_algorithm.is_eagle(): extend_input_len_per_req = [req.extend_input_len for req in batch.reqs] + else: + extend_input_len_per_req = None + if batch.return_logprob: extend_logprob_start_len_per_req = [ req.extend_logprob_start_len for req in batch.reqs ] else: - extend_input_len_per_req = None extend_logprob_start_len_per_req = None ret = GenerationBatchResult(