[PD] Transfer hidden states for mtp when disaggregation (#7242)
This commit is contained in:
@@ -627,6 +627,8 @@ class Scheduler(
|
||||
)
|
||||
self.disagg_metadata_buffers = MetadataBuffers(
|
||||
buffer_size,
|
||||
+ hidden_size=self.model_config.hf_text_config.hidden_size,
|
||||
+ dtype=self.model_config.dtype,
|
||||
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
||||
)
|
||||
|
||||
@@ -677,6 +679,8 @@ class Scheduler(
|
||||
)
|
||||
self.disagg_metadata_buffers = MetadataBuffers(
|
||||
buffer_size,
|
||||
+ hidden_size=self.model_config.hf_text_config.hidden_size,
|
||||
+ dtype=self.model_config.dtype,
|
||||
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
||||
)
|
||||
|
||||
@@ -1681,13 +1685,15 @@ class Scheduler(
|
||||
# These 2 values are needed for processing the output, but the values can be
|
||||
# modified by overlap schedule. So we have to copy them here so that
|
||||
# we can use the correct values in output processing.
|
||||
- if batch.return_logprob:
|
||||
+ if batch.return_logprob or self.spec_algorithm.is_eagle():
|
||||
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
|
||||
+ else:
|
||||
+ extend_input_len_per_req = None
|
||||
+ if batch.return_logprob:
|
||||
extend_logprob_start_len_per_req = [
|
||||
req.extend_logprob_start_len for req in batch.reqs
|
||||
]
|
||||
else:
|
||||
- extend_input_len_per_req = None
|
||||
extend_logprob_start_len_per_req = None
|
||||
|
||||
ret = GenerationBatchResult(
|
||||
|
||||
Reference in New Issue
Block a user