Add support for transferring top-k metadata in PD (prefill-decode) disaggregation (#10616)
Signed-off-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -614,12 +614,16 @@ class DecodeTransferQueue:
|
|||||||
output_token_logprobs_idx,
|
output_token_logprobs_idx,
|
||||||
output_top_logprobs_val,
|
output_top_logprobs_val,
|
||||||
output_top_logprobs_idx,
|
output_top_logprobs_idx,
|
||||||
|
output_topk_p,
|
||||||
|
output_topk_index,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
) = self.metadata_buffers.get_buf(idx)
|
) = self.metadata_buffers.get_buf(idx)
|
||||||
|
|
||||||
decode_req.req.output_ids.append(output_id[0].item())
|
decode_req.req.output_ids.append(output_id[0].item())
|
||||||
decode_req.req.cached_tokens = cached_tokens[0].item()
|
decode_req.req.cached_tokens = cached_tokens[0].item()
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
|
decode_req.req.output_topk_p = output_topk_p
|
||||||
|
decode_req.req.output_topk_index = output_topk_index
|
||||||
decode_req.req.hidden_states_tensor = output_hidden_states
|
decode_req.req.hidden_states_tensor = output_hidden_states
|
||||||
if decode_req.req.return_logprob:
|
if decode_req.req.return_logprob:
|
||||||
decode_req.req.output_token_logprobs_val.append(
|
decode_req.req.output_token_logprobs_val.append(
|
||||||
|
|||||||
@@ -125,25 +125,33 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|||||||
req.grammar.finished = req.finished()
|
req.grammar.finished = req.finished()
|
||||||
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
||||||
|
|
||||||
# Simulate the eagle run. We add mock data to hidden states for the
|
# Simulate the eagle run.
|
||||||
# ease of implementation now meaning the first token will have acc rate
|
if self.spec_algorithm.is_eagle():
|
||||||
# of 0.
|
|
||||||
if not self.spec_algorithm.is_none():
|
|
||||||
|
|
||||||
b = len(self.reqs)
|
b = len(self.reqs)
|
||||||
topk_p = torch.arange(
|
topk = server_args.speculative_eagle_topk
|
||||||
b * server_args.speculative_eagle_topk,
|
topk_p = torch.stack(
|
||||||
0,
|
[
|
||||||
-1,
|
torch.as_tensor(
|
||||||
|
req.output_topk_p[:topk],
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
|
for req in self.reqs
|
||||||
topk_p /= b * server_args.speculative_eagle_topk
|
],
|
||||||
topk_index = torch.arange(
|
dim=0,
|
||||||
b * server_args.speculative_eagle_topk, device=self.device
|
)
|
||||||
|
topk_index = torch.stack(
|
||||||
|
[
|
||||||
|
torch.as_tensor(
|
||||||
|
req.output_topk_index[:topk],
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
for req in self.reqs
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
)
|
)
|
||||||
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
|
|
||||||
|
|
||||||
hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
|
hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
|
||||||
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
|
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
|
||||||
|
|||||||
@@ -421,6 +421,8 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
last_hidden_index = (
|
last_hidden_index = (
|
||||||
hidden_state_offset + extend_input_len_per_req[i] - 1
|
hidden_state_offset + extend_input_len_per_req[i] - 1
|
||||||
)
|
)
|
||||||
|
req.output_topk_p = batch.spec_info.topk_p[i]
|
||||||
|
req.output_topk_index = batch.spec_info.topk_index[i]
|
||||||
if self.spec_algorithm.is_eagle3():
|
if self.spec_algorithm.is_eagle3():
|
||||||
req.hidden_states_tensor = (
|
req.hidden_states_tensor = (
|
||||||
batch.spec_info.hidden_states[i].cpu().clone()
|
batch.spec_info.hidden_states[i].cpu().clone()
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class MetadataBuffers:
|
|||||||
self,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
dtype: torch.dtype,
|
hidden_states_dtype: torch.dtype,
|
||||||
max_top_logprobs_num: int = 128,
|
max_top_logprobs_num: int = 128,
|
||||||
custom_mem_pool: torch.cuda.MemPool = None,
|
custom_mem_pool: torch.cuda.MemPool = None,
|
||||||
):
|
):
|
||||||
@@ -122,8 +122,15 @@ class MetadataBuffers:
|
|||||||
self.output_top_logprobs_idx = torch.zeros(
|
self.output_top_logprobs_idx = torch.zeros(
|
||||||
(size, max_top_logprobs_num), dtype=torch.int32, device=device
|
(size, max_top_logprobs_num), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
# For PD + spec decode
|
||||||
|
self.output_topk_p = torch.zeros(
|
||||||
|
(size, 16), dtype=torch.float32, device=device
|
||||||
|
)
|
||||||
|
self.output_topk_index = torch.zeros(
|
||||||
|
(size, 16), dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
self.output_hidden_states = torch.zeros(
|
self.output_hidden_states = torch.zeros(
|
||||||
(size, hidden_size), dtype=dtype, device=device
|
(size, hidden_size), dtype=hidden_states_dtype, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_buf_infos(self):
|
def get_buf_infos(self):
|
||||||
@@ -134,6 +141,8 @@ class MetadataBuffers:
|
|||||||
self.output_token_logprobs_idx.data_ptr(),
|
self.output_token_logprobs_idx.data_ptr(),
|
||||||
self.output_top_logprobs_val.data_ptr(),
|
self.output_top_logprobs_val.data_ptr(),
|
||||||
self.output_top_logprobs_idx.data_ptr(),
|
self.output_top_logprobs_idx.data_ptr(),
|
||||||
|
self.output_topk_p.data_ptr(),
|
||||||
|
self.output_topk_index.data_ptr(),
|
||||||
self.output_hidden_states.data_ptr(),
|
self.output_hidden_states.data_ptr(),
|
||||||
]
|
]
|
||||||
data_lens = [
|
data_lens = [
|
||||||
@@ -143,6 +152,8 @@ class MetadataBuffers:
|
|||||||
self.output_token_logprobs_idx.nbytes,
|
self.output_token_logprobs_idx.nbytes,
|
||||||
self.output_top_logprobs_val.nbytes,
|
self.output_top_logprobs_val.nbytes,
|
||||||
self.output_top_logprobs_idx.nbytes,
|
self.output_top_logprobs_idx.nbytes,
|
||||||
|
self.output_topk_p.nbytes,
|
||||||
|
self.output_topk_index.nbytes,
|
||||||
self.output_hidden_states.nbytes,
|
self.output_hidden_states.nbytes,
|
||||||
]
|
]
|
||||||
item_lens = [
|
item_lens = [
|
||||||
@@ -152,6 +163,8 @@ class MetadataBuffers:
|
|||||||
self.output_token_logprobs_idx[0].nbytes,
|
self.output_token_logprobs_idx[0].nbytes,
|
||||||
self.output_top_logprobs_val[0].nbytes,
|
self.output_top_logprobs_val[0].nbytes,
|
||||||
self.output_top_logprobs_idx[0].nbytes,
|
self.output_top_logprobs_idx[0].nbytes,
|
||||||
|
self.output_topk_p[0].nbytes,
|
||||||
|
self.output_topk_index[0].nbytes,
|
||||||
self.output_hidden_states[0].nbytes,
|
self.output_hidden_states[0].nbytes,
|
||||||
]
|
]
|
||||||
return ptrs, data_lens, item_lens
|
return ptrs, data_lens, item_lens
|
||||||
@@ -164,6 +177,8 @@ class MetadataBuffers:
|
|||||||
self.output_token_logprobs_idx[idx],
|
self.output_token_logprobs_idx[idx],
|
||||||
self.output_top_logprobs_val[idx],
|
self.output_top_logprobs_val[idx],
|
||||||
self.output_top_logprobs_idx[idx],
|
self.output_top_logprobs_idx[idx],
|
||||||
|
self.output_topk_p[idx],
|
||||||
|
self.output_topk_index[idx],
|
||||||
self.output_hidden_states[idx],
|
self.output_hidden_states[idx],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -193,8 +208,17 @@ class MetadataBuffers:
|
|||||||
] = torch.tensor(
|
] = torch.tensor(
|
||||||
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
||||||
)
|
)
|
||||||
# for PD + spec decode
|
# For PD + spec decode
|
||||||
if req.hidden_states_tensor is not None:
|
if req.hidden_states_tensor is not None:
|
||||||
|
# speculative_eagle_topk should not be greater than 16 currently
|
||||||
|
topk = req.output_topk_p.size(0)
|
||||||
|
|
||||||
|
self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
|
||||||
|
req.output_topk_p
|
||||||
|
)
|
||||||
|
self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
|
||||||
|
req.output_topk_index
|
||||||
|
)
|
||||||
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
||||||
req.hidden_states_tensor
|
req.hidden_states_tensor
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -607,6 +607,8 @@ class Req:
|
|||||||
) = None
|
) = None
|
||||||
self.hidden_states: List[List[float]] = []
|
self.hidden_states: List[List[float]] = []
|
||||||
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
|
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
|
||||||
|
self.output_topk_p = None
|
||||||
|
self.output_topk_index = None
|
||||||
|
|
||||||
# Embedding (return values)
|
# Embedding (return values)
|
||||||
self.embedding = None
|
self.embedding = None
|
||||||
|
|||||||
@@ -806,7 +806,7 @@ class Scheduler(
|
|||||||
self.disagg_metadata_buffers = MetadataBuffers(
|
self.disagg_metadata_buffers = MetadataBuffers(
|
||||||
buffer_size,
|
buffer_size,
|
||||||
hidden_size=self.model_config.hf_text_config.hidden_size,
|
hidden_size=self.model_config.hf_text_config.hidden_size,
|
||||||
dtype=self.model_config.dtype,
|
hidden_states_dtype=self.model_config.dtype,
|
||||||
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -855,7 +855,7 @@ class Scheduler(
|
|||||||
self.disagg_metadata_buffers = MetadataBuffers(
|
self.disagg_metadata_buffers = MetadataBuffers(
|
||||||
buffer_size,
|
buffer_size,
|
||||||
hidden_size=self.model_config.hf_text_config.hidden_size,
|
hidden_size=self.model_config.hf_text_config.hidden_size,
|
||||||
dtype=self.model_config.dtype,
|
hidden_states_dtype=self.model_config.dtype,
|
||||||
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user