[PD][Spec] Fix hidden state transfer for spec decode (#7516)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -579,11 +579,11 @@ class DecodeTransferQueue:
|
||||
idx = decode_req.metadata_buffer_index
|
||||
(
|
||||
output_id,
|
||||
output_hidden_states,
|
||||
output_token_logprobs_val,
|
||||
output_token_logprobs_idx,
|
||||
output_top_logprobs_val,
|
||||
output_top_logprobs_idx,
|
||||
output_hidden_states,
|
||||
) = self.metadata_buffers.get_buf(idx)
|
||||
|
||||
decode_req.req.output_ids.append(output_id[0].item())
|
||||
|
||||
@@ -291,15 +291,21 @@ class MooncakeKVManager(BaseKVManager):
|
||||
dst_aux_ptrs: list[int],
|
||||
dst_aux_index: int,
|
||||
):
|
||||
aux_item_len = self.kv_args.aux_item_lens[0]
|
||||
prefill_aux_addr = (
|
||||
self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
|
||||
src_addr_list = []
|
||||
dst_addr_list = []
|
||||
length_list = []
|
||||
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
||||
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
||||
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
|
||||
length = prefill_aux_item_lens[i]
|
||||
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
||||
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
||||
src_addr_list.append(src_addr)
|
||||
dst_addr_list.append(dst_addr)
|
||||
length_list.append(length)
|
||||
return self.engine.batch_transfer_sync(
|
||||
mooncake_session_id, src_addr_list, dst_addr_list, length_list
|
||||
)
|
||||
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
|
||||
status = self.engine.transfer_sync(
|
||||
mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
|
||||
)
|
||||
return status
|
||||
|
||||
def sync_status_to_decode_endpoint(
|
||||
self, remote: str, dst_port: int, room: int, status: int
|
||||
|
||||
@@ -107,9 +107,6 @@ class MetadataBuffers:
|
||||
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
||||
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
||||
|
||||
self.output_hidden_states = torch.zeros(
|
||||
(size, hidden_size), dtype=dtype, device=device
|
||||
)
|
||||
self.output_token_logprobs_val = torch.zeros(
|
||||
(size, 16), dtype=torch.float32, device=device
|
||||
)
|
||||
@@ -122,51 +119,50 @@ class MetadataBuffers:
|
||||
self.output_top_logprobs_idx = torch.zeros(
|
||||
(size, max_top_logprobs_num), dtype=torch.int32, device=device
|
||||
)
|
||||
self.output_hidden_states = torch.zeros(
|
||||
(size, hidden_size), dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def get_buf_infos(self):
|
||||
ptrs = [
|
||||
self.output_ids.data_ptr(),
|
||||
self.output_hidden_states.data_ptr(), # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
|
||||
self.output_token_logprobs_val.data_ptr(),
|
||||
self.output_token_logprobs_idx.data_ptr(),
|
||||
self.output_top_logprobs_val.data_ptr(),
|
||||
self.output_top_logprobs_idx.data_ptr(),
|
||||
self.output_hidden_states.data_ptr(),
|
||||
]
|
||||
data_lens = [
|
||||
self.output_ids.nbytes,
|
||||
self.output_hidden_states.nbytes,
|
||||
self.output_token_logprobs_val.nbytes,
|
||||
self.output_token_logprobs_idx.nbytes,
|
||||
self.output_top_logprobs_val.nbytes,
|
||||
self.output_top_logprobs_idx.nbytes,
|
||||
self.output_hidden_states.nbytes,
|
||||
]
|
||||
item_lens = [
|
||||
self.output_ids[0].nbytes,
|
||||
self.output_hidden_states[0].nbytes,
|
||||
self.output_token_logprobs_val[0].nbytes,
|
||||
self.output_token_logprobs_idx[0].nbytes,
|
||||
self.output_top_logprobs_val[0].nbytes,
|
||||
self.output_top_logprobs_idx[0].nbytes,
|
||||
self.output_hidden_states[0].nbytes,
|
||||
]
|
||||
return ptrs, data_lens, item_lens
|
||||
|
||||
def get_buf(self, idx: int):
|
||||
return (
|
||||
self.output_ids[idx],
|
||||
self.output_hidden_states[idx],
|
||||
self.output_token_logprobs_val[idx],
|
||||
self.output_token_logprobs_idx[idx],
|
||||
self.output_top_logprobs_val[idx],
|
||||
self.output_top_logprobs_idx[idx],
|
||||
self.output_hidden_states[idx],
|
||||
)
|
||||
|
||||
def set_buf(self, req: Req):
|
||||
|
||||
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
||||
if req.hidden_states_tensor is not None:
|
||||
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
||||
req.hidden_states_tensor
|
||||
)
|
||||
if req.return_logprob:
|
||||
if req.output_token_logprobs_val: # not none or empty list
|
||||
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
||||
@@ -189,6 +185,11 @@ class MetadataBuffers:
|
||||
] = torch.tensor(
|
||||
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
# for PD + spec decode
|
||||
if req.hidden_states_tensor is not None:
|
||||
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
||||
req.hidden_states_tensor
|
||||
)
|
||||
|
||||
|
||||
#########################
|
||||
|
||||
Reference in New Issue
Block a user