diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index c8a0067c0..a71631596 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -579,11 +579,11 @@ class DecodeTransferQueue: idx = decode_req.metadata_buffer_index ( output_id, - output_hidden_states, output_token_logprobs_val, output_token_logprobs_idx, output_top_logprobs_val, output_top_logprobs_idx, + output_hidden_states, ) = self.metadata_buffers.get_buf(idx) decode_req.req.output_ids.append(output_id[0].item()) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 92e182dfd..faccd9d3d 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -291,15 +291,21 @@ class MooncakeKVManager(BaseKVManager): dst_aux_ptrs: list[int], dst_aux_index: int, ): - aux_item_len = self.kv_args.aux_item_lens[0] - prefill_aux_addr = ( - self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len + src_addr_list = [] + dst_addr_list = [] + length_list = [] + prefill_aux_ptrs = self.kv_args.aux_data_ptrs + prefill_aux_item_lens = self.kv_args.aux_item_lens + for i, dst_aux_ptr in enumerate(dst_aux_ptrs): + length = prefill_aux_item_lens[i] + src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index + dst_addr = dst_aux_ptrs[i] + length * dst_aux_index + src_addr_list.append(src_addr) + dst_addr_list.append(dst_addr) + length_list.append(length) + return self.engine.batch_transfer_sync( + mooncake_session_id, src_addr_list, dst_addr_list, length_list ) - decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len - status = self.engine.transfer_sync( - mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len - ) - return status def sync_status_to_decode_endpoint( self, remote: str, dst_port: int, room: int, status: int diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 4a60ba0f9..96110b6cf 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -107,9 +107,6 @@ class MetadataBuffers: # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device) - self.output_hidden_states = torch.zeros( - (size, hidden_size), dtype=dtype, device=device - ) self.output_token_logprobs_val = torch.zeros( (size, 16), dtype=torch.float32, device=device ) @@ -122,51 +119,50 @@ class MetadataBuffers: self.output_top_logprobs_idx = torch.zeros( (size, max_top_logprobs_num), dtype=torch.int32, device=device ) + self.output_hidden_states = torch.zeros( + (size, hidden_size), dtype=dtype, device=device + ) def get_buf_infos(self): ptrs = [ self.output_ids.data_ptr(), - self.output_hidden_states.data_ptr(), # TODO: set None to avoid transfer hidden_states when spec_algorithm is None self.output_token_logprobs_val.data_ptr(), self.output_token_logprobs_idx.data_ptr(), self.output_top_logprobs_val.data_ptr(), self.output_top_logprobs_idx.data_ptr(), + self.output_hidden_states.data_ptr(), ] data_lens = [ self.output_ids.nbytes, - self.output_hidden_states.nbytes, self.output_token_logprobs_val.nbytes, self.output_token_logprobs_idx.nbytes, self.output_top_logprobs_val.nbytes, self.output_top_logprobs_idx.nbytes, + self.output_hidden_states.nbytes, ] item_lens = [ self.output_ids[0].nbytes, - self.output_hidden_states[0].nbytes, self.output_token_logprobs_val[0].nbytes, self.output_token_logprobs_idx[0].nbytes, self.output_top_logprobs_val[0].nbytes, self.output_top_logprobs_idx[0].nbytes, + self.output_hidden_states[0].nbytes, ] return ptrs, data_lens, item_lens def get_buf(self, idx: int): return ( self.output_ids[idx], - self.output_hidden_states[idx], self.output_token_logprobs_val[idx], self.output_token_logprobs_idx[idx], self.output_top_logprobs_val[idx], self.output_top_logprobs_idx[idx], + self.output_hidden_states[idx], ) def set_buf(self, req: Req): self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0] - if req.hidden_states_tensor is not None: - self.output_hidden_states[req.metadata_buffer_index].copy_( - req.hidden_states_tensor - ) if req.return_logprob: if req.output_token_logprobs_val: # not none or empty list self.output_token_logprobs_val[req.metadata_buffer_index][0] = ( @@ -189,6 +185,11 @@ class MetadataBuffers: ] = torch.tensor( req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu" ) + # for PD + spec decode + if req.hidden_states_tensor is not None: + self.output_hidden_states[req.metadata_buffer_index].copy_( + req.hidden_states_tensor + ) #########################