Add support for top-k metadata transfer for PD (#10616)
Signed-off-by: Shangming Cai <csmthu@gmail.com>
@@ -614,12 +614,16 @@ class DecodeTransferQueue:
                 output_token_logprobs_idx,
                 output_top_logprobs_val,
                 output_top_logprobs_idx,
+                output_topk_p,
+                output_topk_index,
                 output_hidden_states,
             ) = self.metadata_buffers.get_buf(idx)
 
             decode_req.req.output_ids.append(output_id[0].item())
             decode_req.req.cached_tokens = cached_tokens[0].item()
             if not self.spec_algorithm.is_none():
+                decode_req.req.output_topk_p = output_topk_p
+                decode_req.req.output_topk_index = output_topk_index
                 decode_req.req.hidden_states_tensor = output_hidden_states
             if decode_req.req.return_logprob:
                 decode_req.req.output_token_logprobs_val.append(
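The decode side now unpacks two extra tensors, output_topk_p and output_topk_index, from the fixed-width metadata buffers and attaches them to the request whenever speculative decoding is active. A minimal sketch (made-up sizes, not the real buffer code) of the padded-row round trip this relies on:

    import torch

    # A fixed-width metadata row is filled for topk=4 and padded to the
    # buffer width of 16, mirroring how the real pools are sized.
    BUF_WIDTH = 16
    row = torch.zeros(BUF_WIDTH, dtype=torch.float32)
    topk_p = torch.tensor([0.5, 0.25, 0.125, 0.125])  # from the prefill side
    row[: topk_p.numel()].copy_(topk_p)               # write the row's prefix

    # The decode side reads the padded row back; the trailing zeros are
    # harmless because only the first `speculative_eagle_topk` entries
    # are ever consumed.
    recovered = row[:4]
    assert torch.equal(recovered, topk_p)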
@@ -125,25 +125,33 @@ class ScheduleBatchDisaggregationDecodeMixin:
                 req.grammar.finished = req.finished()
         self.output_ids = torch.tensor(self.output_ids, device=self.device)
 
-        # Simulate the eagle run. We add mock data to hidden states for the
-        # ease of implementation now meaning the first token will have acc rate
-        # of 0.
-        if not self.spec_algorithm.is_none():
+        # Simulate the eagle run.
+        if self.spec_algorithm.is_eagle():
 
             b = len(self.reqs)
-            topk_p = torch.arange(
-                b * server_args.speculative_eagle_topk,
-                0,
-                -1,
-                device=self.device,
-                dtype=torch.float32,
+            topk = server_args.speculative_eagle_topk
+            topk_p = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_p[:topk],
+                        device=self.device,
+                        dtype=torch.float32,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
             )
-            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
-            topk_p /= b * server_args.speculative_eagle_topk
-            topk_index = torch.arange(
-                b * server_args.speculative_eagle_topk, device=self.device
+            topk_index = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_index[:topk],
+                        device=self.device,
+                        dtype=torch.int64,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
            )
-            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
 
             hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
             hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
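This is the core change on the decode side: the old code fabricated monotonically decreasing mock probabilities with torch.arange (so the first speculative token always missed), while the new code stacks the real per-request top-k tensors transferred from prefill into a (b, topk) batch. A toy illustration of the stacking step, with invented request data:

    import torch

    # Each "request" carries a 1-D top-k tensor produced by prefill;
    # torch.stack joins them into a (b, topk) batch tensor.
    topk = 2
    per_req_topk_p = [torch.tensor([0.7, 0.3]), torch.tensor([0.6, 0.4])]
    per_req_topk_index = [torch.tensor([11, 42]), torch.tensor([7, 99])]

    topk_p = torch.stack(
        [torch.as_tensor(p[:topk], dtype=torch.float32) for p in per_req_topk_p],
        dim=0,
    )
    topk_index = torch.stack(
        [torch.as_tensor(i[:topk], dtype=torch.int64) for i in per_req_topk_index],
        dim=0,
    )
    assert topk_p.shape == (2, 2) and topk_index.dtype == torch.int64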
@@ -421,6 +421,8 @@ class SchedulerDisaggregationPrefillMixin:
                 last_hidden_index = (
                     hidden_state_offset + extend_input_len_per_req[i] - 1
                 )
+                req.output_topk_p = batch.spec_info.topk_p[i]
+                req.output_topk_index = batch.spec_info.topk_index[i]
                 if self.spec_algorithm.is_eagle3():
                     req.hidden_states_tensor = (
                         batch.spec_info.hidden_states[i].cpu().clone()
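On the prefill side, the batched speculative outputs are sliced row by row onto each request; note that hidden states go through .cpu().clone() so the row survives after the batch's GPU buffers are recycled. A sketch of why the clone matters (CPU tensors stand in for the GPU source here):

    import torch

    # Batched speculative outputs, indexed per request.
    b, topk, hidden = 3, 2, 8
    spec_topk_p = torch.rand(b, topk)
    spec_hidden = torch.ones(b, hidden)

    for i in range(b):
        row_p = spec_topk_p[i]                # a view into the batch tensor
        row_h = spec_hidden[i].cpu().clone()  # an independent copy

    spec_hidden.zero_()                       # simulate batch buffer reuse
    assert row_h.sum().item() == hidden       # the clone is unaffected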
@@ -85,7 +85,7 @@ class MetadataBuffers:
         self,
         size: int,
         hidden_size: int,
-        dtype: torch.dtype,
+        hidden_states_dtype: torch.dtype,
         max_top_logprobs_num: int = 128,
         custom_mem_pool: torch.cuda.MemPool = None,
     ):
@@ -122,8 +122,15 @@ class MetadataBuffers:
         self.output_top_logprobs_idx = torch.zeros(
             (size, max_top_logprobs_num), dtype=torch.int32, device=device
         )
+        # For PD + spec decode
+        self.output_topk_p = torch.zeros(
+            (size, 16), dtype=torch.float32, device=device
+        )
+        self.output_topk_index = torch.zeros(
+            (size, 16), dtype=torch.int64, device=device
+        )
         self.output_hidden_states = torch.zeros(
-            (size, hidden_size), dtype=dtype, device=device
+            (size, hidden_size), dtype=hidden_states_dtype, device=device
         )
 
     def get_buf_infos(self):
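The pools are preallocated with a fixed width of 16 top-k slots per request (matching the cap noted later in set_buf), and the hidden-states pool now takes its own hidden_states_dtype instead of sharing the logprob buffers' dtype. A rough sketch of the resulting per-slot byte layout, with hypothetical sizes:

    import torch

    size, hidden_size = 4, 8
    hidden_states_dtype = torch.bfloat16  # decoupled from the logprob dtype

    output_topk_p = torch.zeros((size, 16), dtype=torch.float32)
    output_topk_index = torch.zeros((size, 16), dtype=torch.int64)
    output_hidden_states = torch.zeros((size, hidden_size), dtype=hidden_states_dtype)

    # One request's slot in each pool:
    assert output_topk_p[0].nbytes == 16 * 4        # float32
    assert output_topk_index[0].nbytes == 16 * 8    # int64
    assert output_hidden_states[0].nbytes == hidden_size * 2  # bfloat16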
@@ -134,6 +141,8 @@ class MetadataBuffers:
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_topk_p.data_ptr(),
+            self.output_topk_index.data_ptr(),
             self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
@@ -143,6 +152,8 @@ class MetadataBuffers:
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_topk_p.nbytes,
+            self.output_topk_index.nbytes,
             self.output_hidden_states.nbytes,
         ]
         item_lens = [
@@ -152,6 +163,8 @@ class MetadataBuffers:
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_topk_p[0].nbytes,
+            self.output_topk_index[0].nbytes,
             self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens
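These three hunks register the new pools with get_buf_infos, which hands the transfer backend each pool's base pointer, total byte count, and per-slot stride, so one request's metadata can be addressed as ptr + index * item_len. A toy version of that triple (illustrative only, not the real registration path):

    import torch

    pools = {
        "topk_p": torch.zeros((4, 16), dtype=torch.float32),
        "topk_index": torch.zeros((4, 16), dtype=torch.int64),
    }
    ptrs = [t.data_ptr() for t in pools.values()]       # base addresses
    data_lens = [t.nbytes for t in pools.values()]      # whole-pool sizes
    item_lens = [t[0].nbytes for t in pools.values()]   # one slot's size

    # `size` slots of equal width tile the whole pool.
    assert data_lens[0] == 4 * item_lens[0]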
@@ -164,6 +177,8 @@ class MetadataBuffers:
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
+            self.output_topk_p[idx],
+            self.output_topk_index[idx],
             self.output_hidden_states[idx],
         )
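get_buf returns indexed rows of the pools; indexing a tensor with an integer yields a view that aliases the pool's storage, which is what keeps this handoff copy-free. A two-line demonstration:

    import torch

    pool = torch.zeros((4, 16))
    row = pool[1]                 # a view of slot 1, not a copy
    pool[1, 0] = 3.0
    assert row[0].item() == 3.0   # the view observes the pool's update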
@@ -193,8 +208,17 @@ class MetadataBuffers:
             ] = torch.tensor(
                 req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
             )
-        # for PD + spec decode
+        # For PD + spec decode
         if req.hidden_states_tensor is not None:
+            # speculative_eagle_topk should not be greater than 16 currently
+            topk = req.output_topk_p.size(0)
+
+            self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
+                req.output_topk_p
+            )
+            self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
+                req.output_topk_index
+            )
             self.output_hidden_states[req.metadata_buffer_index].copy_(
                 req.hidden_states_tensor
             )
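set_buf is the write half: copy_ fills only the first topk columns of the request's row, so any top-k up to the 16-wide slot fits without resizing the pool. A sketch with exactly representable float values:

    import torch

    output_topk_p = torch.zeros((4, 16), dtype=torch.float32)
    metadata_buffer_index = 2
    req_topk_p = torch.tensor([0.75, 0.25])  # topk == 2 for this request

    topk = req_topk_p.size(0)
    output_topk_p[metadata_buffer_index, :topk].copy_(req_topk_p)

    # Only the row's prefix is written; the padding stays zero.
    assert output_topk_p[2, 0].item() == 0.75
    assert output_topk_p[2, 5].item() == 0.0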
@@ -607,6 +607,8 @@ class Req:
         ) = None
         self.hidden_states: List[List[float]] = []
         self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
+        self.output_topk_p = None
+        self.output_topk_index = None
 
         # Embedding (return values)
         self.embedding = None
@@ -806,7 +806,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-                dtype=self.model_config.dtype,
+                hidden_states_dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )
@@ -855,7 +855,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-                dtype=self.model_config.dtype,
+                hidden_states_dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )
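Taken together, the hunks form a write-transfer-read pipeline: prefill writes each request's top-k tensors into its metadata slot (set_buf), the bytes move to the decode instance as-is, and decode reads the slot back and trims the padding (get_buf plus the DecodeTransferQueue change). A condensed, self-contained imitation of that flow; all names here are hypothetical stand-ins, and the single CPU tensor plays the role of both sides' buffers:

    import torch

    SLOT_WIDTH = 16  # matches the 16-wide top-k cap in MetadataBuffers

    def write_slot(pool: torch.Tensor, idx: int, values: torch.Tensor) -> None:
        # Prefill side: fill the prefix of slot `idx`.
        pool[idx, : values.size(0)].copy_(values)

    def read_slot(pool: torch.Tensor, idx: int, topk: int) -> torch.Tensor:
        # Decode side: trim the padded slot back to the real top-k width.
        return pool[idx, :topk]

    pool = torch.zeros((8, SLOT_WIDTH), dtype=torch.float32)
    write_slot(pool, 3, torch.tensor([0.5, 0.25, 0.25]))
    assert torch.equal(read_slot(pool, 3, 3), torch.tensor([0.5, 0.25, 0.25]))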