[PD] Transfer hidden states for mtp when disaggregation (#7242)
@@ -541,6 +541,7 @@ class DecodeTransferQueue:
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
+        self.spec_algorithm = scheduler.spec_algorithm

     def add(self, decode_req: DecodeRequest) -> None:
         self.queue.append(decode_req)
@@ -582,6 +583,7 @@ class DecodeTransferQueue:
             idx = decode_req.metadata_buffer_index
             (
                 output_id,
+                output_hidden_states,
                 output_token_logprobs_val,
                 output_token_logprobs_idx,
                 output_top_logprobs_val,
@@ -589,7 +591,8 @@ class DecodeTransferQueue:
             ) = self.metadata_buffers.get_buf(idx)

             decode_req.req.output_ids.append(output_id[0].item())
-
+            if not self.spec_algorithm.is_none():
+                decode_req.req.hidden_states_tensor = output_hidden_states
             if decode_req.req.return_logprob:
                 decode_req.req.output_token_logprobs_val.append(
                     output_token_logprobs_val[0].item()
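For reference, a minimal sketch of what the decode-side consumer now does with the extra slot; the tensors and hidden_size below are illustrative stand-ins, not the real MetadataBuffers objects:

    import torch

    hidden_size = 4096  # illustrative; the real value comes from the model config

    # Stand-ins for one row of each transferred buffer (the row at metadata_buffer_index).
    output_id = torch.tensor([42] + [0] * 15, dtype=torch.int32)  # padded to 16 ints
    output_hidden_states = torch.randn(hidden_size)

    # Mirrors the updated pop path: the first token id is always consumed; the
    # hidden state is attached only when a speculative algorithm (MTP/EAGLE) is active.
    spec_enabled = True  # i.e. not self.spec_algorithm.is_none()
    output_ids = [output_id[0].item()]
    hidden_states_tensor = output_hidden_states if spec_enabled else None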
@@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin:
         )
         topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

+        hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
+        hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
+
         # local import to avoid circular import
         from sglang.srt.speculative.eagle_utils import EagleDraftInput

         spec_info = EagleDraftInput(
             topk_p=topk_p,
             topk_index=topk_index,
-            hidden_states=torch.ones(
-                (b, model_config.hidden_size), device=self.device
-            ),
+            hidden_states=hidden_states,
             verified_id=self.output_ids,
         )
         spec_info.prepare_for_extend(self)
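The torch.ones placeholder is gone: the draft input's hidden_states is now built from the vectors each request carried over from prefill. A self-contained sketch of the stacking step, with illustrative sizes:

    import torch

    b, hidden_size = 3, 4096  # illustrative batch size and hidden width
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # One transferred vector per request, as stored on req.hidden_states_tensor.
    hidden_states_list = [torch.randn(hidden_size) for _ in range(b)]

    # Same [b, hidden_size] shape the old placeholder had, but now carrying the
    # real last-token hidden states produced by the prefill worker.
    hidden_states = torch.stack(hidden_states_list, dim=0).to(device)
    assert hidden_states.shape == (b, hidden_size)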
@@ -393,6 +393,8 @@ class SchedulerDisaggregationPrefillMixin:
             logits_output.input_token_logprobs = tuple(
                 logits_output.input_token_logprobs.tolist()
             )
+
+        hidden_state_offset = 0
         for i, (req, next_token_id) in enumerate(
             zip(batch.reqs, next_token_ids, strict=True)
         ):
@@ -402,6 +404,16 @@ class SchedulerDisaggregationPrefillMixin:
             req.output_ids.append(next_token_id)
             self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
             self.disagg_prefill_inflight_queue.append(req)
+            if logits_output.hidden_states is not None:
+                last_hidden_index = (
+                    hidden_state_offset + extend_input_len_per_req[i] - 1
+                )
+                req.hidden_states_tensor = (
+                    logits_output.hidden_states[last_hidden_index].cpu().clone()
+                )
+                hidden_state_offset += extend_input_len_per_req[i]
+            else:
+                req.hidden_states_tensor = None
             if req.return_logprob:
                 assert extend_logprob_start_len_per_req is not None
                 assert extend_input_len_per_req is not None
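The offset bookkeeping above indexes into a flattened per-token tensor. A toy worked example (values are illustrative) of how last_hidden_index selects each request's final input token:

    import torch

    # logits_output.hidden_states holds one row per input token, with all
    # requests of the batch concatenated along dim 0.
    extend_input_len_per_req = [3, 5]        # two requests, 8 input tokens total
    hidden = torch.arange(8.0).unsqueeze(1)  # [8, 1] stand-in for [8, hidden_size]

    hidden_state_offset = 0
    last_rows = []
    for length in extend_input_len_per_req:
        last_hidden_index = hidden_state_offset + length - 1
        last_rows.append(hidden[last_hidden_index].clone())  # .cpu().clone() in the real code
        hidden_state_offset += length

    # Request 0's last token is row 2; request 1's is row 7.
    assert [row.item() for row in last_rows] == [2.0, 7.0]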
@@ -88,6 +88,8 @@ class MetadataBuffers:
     def __init__(
         self,
         size: int,
+        hidden_size: int,
+        dtype: torch.dtype,
         max_top_logprobs_num: int = 128,
         custom_mem_pool: torch.cuda.MemPool = None,
     ):
@@ -104,6 +106,10 @@ class MetadataBuffers:
         # We transfer the metadata of first output token to decode
         # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
         self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+
+        self.output_hidden_states = torch.zeros(
+            (size, hidden_size), dtype=dtype, device=device
+        )
         self.output_token_logprobs_val = torch.zeros(
             (size, 16), dtype=torch.float32, device=device
         )
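The padding comment can be checked with quick arithmetic; a sketch, assuming an illustrative buffer size of 8 and a hidden_size of 4096 in bfloat16:

    import torch

    # Each per-request row must meet the 64-byte RDMA minimum, hence the padding
    # to 16 elements: 16 * 4 bytes = 64 bytes for the int32/float32 rows.
    output_ids = torch.zeros((8, 16), dtype=torch.int32)
    assert output_ids[0].nbytes == 64

    # The hidden-state row needs no padding: 4096 * 2 bytes = 8192 bytes per request.
    output_hidden_states = torch.zeros((8, 4096), dtype=torch.bfloat16)
    assert output_hidden_states[0].nbytes == 8192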
@@ -120,6 +126,7 @@ class MetadataBuffers:
     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
+            self.output_hidden_states.data_ptr(),  # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
@@ -127,6 +134,7 @@ class MetadataBuffers:
         ]
         data_lens = [
             self.output_ids.nbytes,
+            self.output_hidden_states.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
@@ -134,6 +142,7 @@ class MetadataBuffers:
         ]
         item_lens = [
             self.output_ids[0].nbytes,
+            self.output_hidden_states[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
@@ -144,6 +153,7 @@ class MetadataBuffers:
     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
+            self.output_hidden_states[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
@@ -153,6 +163,10 @@ class MetadataBuffers:
     def set_buf(self, req: Req):

         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
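Taken together, set_buf (prefill side) and get_buf (decode side) form a write/read round trip over the same buffer row. A condensed sketch with stand-in tensors (sizes and index are illustrative):

    import torch

    size, hidden_size = 4, 1024
    output_ids = torch.zeros((size, 16), dtype=torch.int32)
    output_hidden_states = torch.zeros((size, hidden_size))

    # Prefill side (set_buf): write the first output token and, when present,
    # the last-token hidden state into this request's row.
    idx = 2  # req.metadata_buffer_index
    first_token = 42
    hidden_states_tensor = torch.randn(hidden_size)  # req.hidden_states_tensor
    output_ids[idx][0] = first_token
    if hidden_states_tensor is not None:
        output_hidden_states[idx].copy_(hidden_states_tensor)

    # Decode side (get_buf): read back the same row after the transfer.
    assert output_ids[idx][0].item() == first_token
    assert torch.equal(output_hidden_states[idx], hidden_states_tensor)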
@@ -584,6 +584,7 @@ class Req:
             self.output_token_ids_logprobs_idx
         ) = None
         self.hidden_states: List[List[float]] = []
+        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP

         # Embedding (return values)
         self.embedding = None
@@ -627,6 +627,8 @@ class Scheduler(
         )
         self.disagg_metadata_buffers = MetadataBuffers(
             buffer_size,
+            hidden_size=self.model_config.hf_text_config.hidden_size,
+            dtype=self.model_config.dtype,
             custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
         )

@@ -677,6 +679,8 @@ class Scheduler(
         )
         self.disagg_metadata_buffers = MetadataBuffers(
             buffer_size,
+            hidden_size=self.model_config.hf_text_config.hidden_size,
+            dtype=self.model_config.dtype,
             custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
         )

@@ -1681,13 +1685,15 @@ class Scheduler(
         # These 2 values are needed for processing the output, but the values can be
         # modified by overlap schedule. So we have to copy them here so that
         # we can use the correct values in output processing.
-        if batch.return_logprob:
+        if batch.return_logprob or self.spec_algorithm.is_eagle():
             extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
+        else:
+            extend_input_len_per_req = None
+        if batch.return_logprob:
             extend_logprob_start_len_per_req = [
                 req.extend_logprob_start_len for req in batch.reqs
             ]
         else:
-            extend_input_len_per_req = None
             extend_logprob_start_len_per_req = None

         ret = GenerationBatchResult(
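The reordering decouples the two per-request lists: extend_input_len_per_req is now materialized whenever EAGLE is active, even with logprobs off, because the prefill path uses it to locate each request's last hidden state. A condensed restatement as a standalone helper (the function name and flat signature are illustrative, not part of the change):

    def copy_per_req_values(return_logprob: bool, is_eagle: bool, reqs):
        # Needed by logprob processing *or* by EAGLE/MTP hidden-state indexing.
        extend_input_len_per_req = (
            [r.extend_input_len for r in reqs]
            if return_logprob or is_eagle
            else None
        )
        # Still needed only when logprobs are requested.
        extend_logprob_start_len_per_req = (
            [r.extend_logprob_start_len for r in reqs]
            if return_logprob
            else None
        )
        return extend_input_len_per_req, extend_logprob_start_len_per_req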