[PD][Spec] Fix hidden state transfer for spec decode (#7516)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -579,11 +579,11 @@ class DecodeTransferQueue:
|
|||||||
idx = decode_req.metadata_buffer_index
|
idx = decode_req.metadata_buffer_index
|
||||||
(
|
(
|
||||||
output_id,
|
output_id,
|
||||||
output_hidden_states,
|
|
||||||
output_token_logprobs_val,
|
output_token_logprobs_val,
|
||||||
output_token_logprobs_idx,
|
output_token_logprobs_idx,
|
||||||
output_top_logprobs_val,
|
output_top_logprobs_val,
|
||||||
output_top_logprobs_idx,
|
output_top_logprobs_idx,
|
||||||
|
output_hidden_states,
|
||||||
) = self.metadata_buffers.get_buf(idx)
|
) = self.metadata_buffers.get_buf(idx)
|
||||||
|
|
||||||
decode_req.req.output_ids.append(output_id[0].item())
|
decode_req.req.output_ids.append(output_id[0].item())
|
||||||
|
|||||||
@@ -291,15 +291,21 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
dst_aux_ptrs: list[int],
|
dst_aux_ptrs: list[int],
|
||||||
dst_aux_index: int,
|
dst_aux_index: int,
|
||||||
):
|
):
|
||||||
aux_item_len = self.kv_args.aux_item_lens[0]
|
src_addr_list = []
|
||||||
prefill_aux_addr = (
|
dst_addr_list = []
|
||||||
self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
|
length_list = []
|
||||||
|
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
||||||
|
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
||||||
|
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
|
||||||
|
length = prefill_aux_item_lens[i]
|
||||||
|
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
||||||
|
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
||||||
|
src_addr_list.append(src_addr)
|
||||||
|
dst_addr_list.append(dst_addr)
|
||||||
|
length_list.append(length)
|
||||||
|
return self.engine.batch_transfer_sync(
|
||||||
|
mooncake_session_id, src_addr_list, dst_addr_list, length_list
|
||||||
)
|
)
|
||||||
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
|
|
||||||
status = self.engine.transfer_sync(
|
|
||||||
mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
|
|
||||||
)
|
|
||||||
return status
|
|
||||||
|
|
||||||
def sync_status_to_decode_endpoint(
|
def sync_status_to_decode_endpoint(
|
||||||
self, remote: str, dst_port: int, room: int, status: int
|
self, remote: str, dst_port: int, room: int, status: int
|
||||||
|
|||||||
@@ -107,9 +107,6 @@ class MetadataBuffers:
|
|||||||
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
||||||
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
||||||
|
|
||||||
self.output_hidden_states = torch.zeros(
|
|
||||||
(size, hidden_size), dtype=dtype, device=device
|
|
||||||
)
|
|
||||||
self.output_token_logprobs_val = torch.zeros(
|
self.output_token_logprobs_val = torch.zeros(
|
||||||
(size, 16), dtype=torch.float32, device=device
|
(size, 16), dtype=torch.float32, device=device
|
||||||
)
|
)
|
||||||
@@ -122,51 +119,50 @@ class MetadataBuffers:
|
|||||||
self.output_top_logprobs_idx = torch.zeros(
|
self.output_top_logprobs_idx = torch.zeros(
|
||||||
(size, max_top_logprobs_num), dtype=torch.int32, device=device
|
(size, max_top_logprobs_num), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
self.output_hidden_states = torch.zeros(
|
||||||
|
(size, hidden_size), dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
def get_buf_infos(self):
|
def get_buf_infos(self):
|
||||||
ptrs = [
|
ptrs = [
|
||||||
self.output_ids.data_ptr(),
|
self.output_ids.data_ptr(),
|
||||||
self.output_hidden_states.data_ptr(), # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
|
|
||||||
self.output_token_logprobs_val.data_ptr(),
|
self.output_token_logprobs_val.data_ptr(),
|
||||||
self.output_token_logprobs_idx.data_ptr(),
|
self.output_token_logprobs_idx.data_ptr(),
|
||||||
self.output_top_logprobs_val.data_ptr(),
|
self.output_top_logprobs_val.data_ptr(),
|
||||||
self.output_top_logprobs_idx.data_ptr(),
|
self.output_top_logprobs_idx.data_ptr(),
|
||||||
|
self.output_hidden_states.data_ptr(),
|
||||||
]
|
]
|
||||||
data_lens = [
|
data_lens = [
|
||||||
self.output_ids.nbytes,
|
self.output_ids.nbytes,
|
||||||
self.output_hidden_states.nbytes,
|
|
||||||
self.output_token_logprobs_val.nbytes,
|
self.output_token_logprobs_val.nbytes,
|
||||||
self.output_token_logprobs_idx.nbytes,
|
self.output_token_logprobs_idx.nbytes,
|
||||||
self.output_top_logprobs_val.nbytes,
|
self.output_top_logprobs_val.nbytes,
|
||||||
self.output_top_logprobs_idx.nbytes,
|
self.output_top_logprobs_idx.nbytes,
|
||||||
|
self.output_hidden_states.nbytes,
|
||||||
]
|
]
|
||||||
item_lens = [
|
item_lens = [
|
||||||
self.output_ids[0].nbytes,
|
self.output_ids[0].nbytes,
|
||||||
self.output_hidden_states[0].nbytes,
|
|
||||||
self.output_token_logprobs_val[0].nbytes,
|
self.output_token_logprobs_val[0].nbytes,
|
||||||
self.output_token_logprobs_idx[0].nbytes,
|
self.output_token_logprobs_idx[0].nbytes,
|
||||||
self.output_top_logprobs_val[0].nbytes,
|
self.output_top_logprobs_val[0].nbytes,
|
||||||
self.output_top_logprobs_idx[0].nbytes,
|
self.output_top_logprobs_idx[0].nbytes,
|
||||||
|
self.output_hidden_states[0].nbytes,
|
||||||
]
|
]
|
||||||
return ptrs, data_lens, item_lens
|
return ptrs, data_lens, item_lens
|
||||||
|
|
||||||
def get_buf(self, idx: int):
|
def get_buf(self, idx: int):
|
||||||
return (
|
return (
|
||||||
self.output_ids[idx],
|
self.output_ids[idx],
|
||||||
self.output_hidden_states[idx],
|
|
||||||
self.output_token_logprobs_val[idx],
|
self.output_token_logprobs_val[idx],
|
||||||
self.output_token_logprobs_idx[idx],
|
self.output_token_logprobs_idx[idx],
|
||||||
self.output_top_logprobs_val[idx],
|
self.output_top_logprobs_val[idx],
|
||||||
self.output_top_logprobs_idx[idx],
|
self.output_top_logprobs_idx[idx],
|
||||||
|
self.output_hidden_states[idx],
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_buf(self, req: Req):
|
def set_buf(self, req: Req):
|
||||||
|
|
||||||
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
||||||
if req.hidden_states_tensor is not None:
|
|
||||||
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
|
||||||
req.hidden_states_tensor
|
|
||||||
)
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
if req.output_token_logprobs_val: # not none or empty list
|
if req.output_token_logprobs_val: # not none or empty list
|
||||||
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
||||||
@@ -189,6 +185,11 @@ class MetadataBuffers:
|
|||||||
] = torch.tensor(
|
] = torch.tensor(
|
||||||
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
||||||
)
|
)
|
||||||
|
# for PD + spec decode
|
||||||
|
if req.hidden_states_tensor is not None:
|
||||||
|
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
||||||
|
req.hidden_states_tensor
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
#########################
|
#########################
|
||||||
|
|||||||
Reference in New Issue
Block a user