[PD] Transfer hidden states for MTP when disaggregation is enabled (#7242)

This commit is contained in:
Atream
2025-06-20 02:22:47 +08:00
committed by GitHub
parent d20a073bc3
commit 4f838c09cd
6 changed files with 43 additions and 6 deletions

View File

@@ -627,6 +627,8 @@ class Scheduler(
)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)
@@ -677,6 +679,8 @@ class Scheduler(
)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)
@@ -1681,13 +1685,15 @@ class Scheduler(
# These two values are needed for processing the output, but they can be
# modified by the overlap schedule. Copy them here so that output
# processing uses the correct values.
if batch.return_logprob:
if batch.return_logprob or self.spec_algorithm.is_eagle():
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
else:
extend_input_len_per_req = None
if batch.return_logprob:
extend_logprob_start_len_per_req = [
req.extend_logprob_start_len for req in batch.reqs
]
else:
extend_input_len_per_req = None
extend_logprob_start_len_per_req = None
ret = GenerationBatchResult(