[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#4176)

### What this PR does / why we need it?
Adds encoder separation support for Encode-Prefill-Decode (EPD) disaggregation. With this change, a rank acting as an encoder-cache (EC) producer runs only the multimodal encoder, persists encoder caches through the EC connector, and skips both KV-cache state updates and KV-cache allocation; prefill/decode ranks load the saved caches instead of recomputing them. A simplified sketch of the producer fast path is shown after the version notes below.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2
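
At a high level, an EC producer rank short-circuits `execute_model` once the multimodal encoder has run. A simplified sketch of the control flow this PR adds (condensed from the `NPUModelRunner` diff below; not the literal implementation):

```python
# Producer fast path, condensed from the NPUModelRunner changes below.
def execute_model(self, scheduler_output):
    self._update_states(scheduler_output)
    if has_ec_transfer() and get_ec_transfer().is_producer:
        # Encoder-only rank: run the multimodal encoder, let the EC
        # connector persist the caches, and skip prefill/decode entirely.
        with self.maybe_get_ec_connector_output(
                scheduler_output, encoder_cache=self.encoder_cache):
            self._execute_mm_encoder(scheduler_output)
        return make_empty_encoder_model_runner_output(scheduler_output)
    # ...the normal prefill/decode path continues below...
```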

Signed-off-by: amy-why-3459 <wuhaiyan17@huawei.com>
amy-why-3459 committed on 2025-12-03 20:48:45 +08:00 (committed by GitHub)
parent 6ece6660ec · commit 26e8e58cea
4 changed files with 72 additions and 10 deletions

View File

@@ -18,6 +18,7 @@ import os
import vllm_ascend.patch.platform.patch_config  # noqa
import vllm_ascend.patch.platform.patch_distributed  # noqa
import vllm_ascend.patch.platform.patch_ec_connector  # noqa
import vllm_ascend.patch.platform.patch_mamba_config  # noqa
import vllm_ascend.patch.platform.patch_sched_yield  # noqa

View File

@@ -0,0 +1,32 @@
import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector
from safetensors.torch import load_file
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import (
    ECSharedStorageConnector, ECSharedStorageConnectorMetadata)
from vllm.logger import logger


class AscendECSharedStorageConnector(ECSharedStorageConnector):

    def start_load_caches(self, encoder_cache, **kwargs) -> None:
        metadata: ECConnectorMetadata = self._get_connector_metadata()
        if metadata is None:
            logger.warning("In connector.start_load_caches, "
                           "but the connector metadata is None")
            return
        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
        assert encoder_cache is not None

        # Load the encoder cache (EC) for each multimodal item.
        for mm_data in metadata.mm_datas:
            if mm_data.mm_hash in encoder_cache:
                continue
            filename = self._generate_filename_debug(mm_data.mm_hash)
            # The safetensors file holds a single "ec_cache" tensor;
            # move it onto the NPU device.
            ec_cache = load_file(filename)["ec_cache"].npu()
            encoder_cache[mm_data.mm_hash] = ec_cache
            logger.debug("Successfully loaded encoder cache for hash %s",
                         mm_data.mm_hash)


# Register the Ascend variant in place of vLLM's shared-storage EC connector.
vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector
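
For context, the producer-side counterpart of this loader is sketched below. The `"ec_cache"` key and the save-before-load contract are taken from the loader above; the helper itself (`save_encoder_cache`) is hypothetical and not part of this PR:

```python
import torch
from safetensors.torch import save_file


def save_encoder_cache(filename: str, ec_cache: torch.Tensor) -> None:
    # safetensors serializes contiguous CPU tensors; "ec_cache" is the key
    # that start_load_caches() above reads back before moving the tensor
    # to the NPU.
    save_file({"ec_cache": ec_cache.detach().cpu().contiguous()}, filename)
```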

View File

@@ -47,6 +47,7 @@ from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -91,13 +92,16 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                      UniformTypeKVCacheSpecs)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
                             PoolerOutput,
                             make_empty_encoder_model_runner_output)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.ec_connector_model_runner_mixin import \
    ECConnectorModelRunnerMixin
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
@@ -268,7 +272,7 @@ class ExecuteModelState(NamedTuple):
    positions: torch.Tensor


class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        self.vllm_config = vllm_config
@@ -791,6 +795,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            req_ids_to_add.append(req_id)

        # If this rank is an EC transfer producer,
        # skip updating the states of KV cache blocks.
        if has_ec_transfer() and get_ec_transfer().is_producer:
            return

        # Update the states of the running/resumed requests.
        is_last_rank = get_pp_group().is_last_rank
        req_data = scheduler_output.scheduled_cached_reqs
@@ -1072,6 +1081,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                    output,
                    is_embed=pos_info.is_embed,
                )
                self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)

    def _batch_mm_kwargs_from_scheduler(
        self,
@@ -1597,15 +1607,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        # _prepare_inputs may reorder the batch, so we must gather
        # multi-modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
            with self.maybe_get_ec_connector_output(
                    scheduler_output,
                    encoder_cache=self.encoder_cache,
            ):
                # Run the multimodal encoder if any.
                self._execute_mm_encoder(scheduler_output)

            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:total_num_scheduled_tokens]
            mm_embeds, is_mm_embed = self._gather_mm_embeddings(
                scheduler_output)

            inputs_embeds = self.model.embed_input_ids(
                input_ids,
@@ -2248,6 +2262,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        with ProfileExecuteDuration().capture_async("prepare input"):
            self._update_states(scheduler_output)
            if has_ec_transfer() and get_ec_transfer().is_producer:
                with self.maybe_get_ec_connector_output(
                        scheduler_output,
                        encoder_cache=self.encoder_cache,
                ):
                    self._execute_mm_encoder(scheduler_output)
                return make_empty_encoder_model_runner_output(
                    scheduler_output)

            if not scheduler_output.total_num_scheduled_tokens:
                if not has_kv_transfer_group():
                    logger.debug(
@@ -3741,6 +3764,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.
        """
        if has_ec_transfer() and get_ec_transfer().is_producer:
            return {}

        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}
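
A consequence of the `get_kv_cache_spec` change: an EC producer rank advertises no KV-cache layers, so the engine allocates no KV blocks for it. A hypothetical test-style check (illustrative only, not part of this PR):

```python
def test_producer_reports_empty_kv_cache_spec(runner):
    # `runner` is assumed to be an NPUModelRunner on an EC producer rank.
    assert has_ec_transfer() and get_ec_transfer().is_producer
    assert runner.get_kv_cache_spec() == {}
```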

View File

@@ -30,6 +30,7 @@ from torch_npu.profiler import dynamic_profile as dp
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
@@ -417,6 +418,7 @@ class NPUWorker(WorkerBase):
            self.parallel_config.decode_context_parallel_size)
        init_ascend_model_parallel(self.parallel_config)
        ensure_kv_transfer_initialized(self.vllm_config)
        ensure_ec_transfer_initialized(self.vllm_config)

    def _init_profiler(self):
        # Torch profiler. Enabled and configured through env vars:
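
Putting the four files together, the end-to-end flow reads roughly as follows (a summary sketch using only names from this diff; the consumer-side cache-miss behavior is inferred, not spelled out in the diff):

```python
# Worker init, all ranks:
#     ensure_kv_transfer_initialized(vllm_config)
#     ensure_ec_transfer_initialized(vllm_config)
#
# EC producer rank, each step:
#     _execute_mm_encoder(scheduler_output)               # run the encoder
#     maybe_save_ec_to_connector(encoder_cache, mm_hash)  # persist the EC
#     return make_empty_encoder_model_runner_output(scheduler_output)
#
# Prefill/decode ranks, each step:
#     with maybe_get_ec_connector_output(...):  # triggers start_load_caches
#         _execute_mm_encoder(scheduler_output) # only for anything not
#                                               # supplied by the connector
#     # normal prefill/decode continues with the loaded encoder caches
```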