From 26e8e58cea8ba5cc5edaef19e72d0ddc4e9f1c1c Mon Sep 17 00:00:00 2001
From: amy-why-3459
Date: Wed, 3 Dec 2025 20:48:45 +0800
Subject: [PATCH] [Core] Encoder separation for Encode-Prefill-Decode
 Disaggregation (#4176)

### What this PR does / why we need it?
Support Encoder separation for Encode-Prefill-Decode Disaggregation

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: amy-why-3459
---
 vllm_ascend/patch/platform/__init__.py        |  1 +
 .../patch/platform/patch_ec_connector.py      | 32 +++++++++++++
 vllm_ascend/worker/model_runner_v1.py         | 47 +++++++++++++++----
 vllm_ascend/worker/worker_v1.py               |  2 +
 4 files changed, 72 insertions(+), 10 deletions(-)
 create mode 100644 vllm_ascend/patch/platform/patch_ec_connector.py

diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py
index ca24083f..60a54e51 100644
--- a/vllm_ascend/patch/platform/__init__.py
+++ b/vllm_ascend/patch/platform/__init__.py
@@ -18,6 +18,7 @@ import os
 
 import vllm_ascend.patch.platform.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_distributed  # noqa
+import vllm_ascend.patch.platform.patch_ec_connector  # noqa
 import vllm_ascend.patch.platform.patch_mamba_config  # noqa
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 
diff --git a/vllm_ascend/patch/platform/patch_ec_connector.py b/vllm_ascend/patch/platform/patch_ec_connector.py
new file mode 100644
index 00000000..f0464b75
--- /dev/null
+++ b/vllm_ascend/patch/platform/patch_ec_connector.py
@@ -0,0 +1,32 @@
+import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector
+from safetensors.torch import load_file
+from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
+from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import (
+    ECSharedStorageConnector, ECSharedStorageConnectorMetadata)
+from vllm.logger import logger
+
+
+class AscendECSharedStorageConnector(ECSharedStorageConnector):
+
+    def start_load_caches(self, encoder_cache, **kwargs) -> None:
+        metadata: ECConnectorMetadata = self._get_connector_metadata()
+        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
+        assert encoder_cache is not None
+        if metadata is None:
+            logger.warning((
+                "In connector.start_load_caches, ",
+                "but the connector metadata is None",
+            ))
+            return
+        # Load the EC for each mm data
+        for mm_data in metadata.mm_datas:
+            if mm_data.mm_hash in encoder_cache:
+                continue
+            filename = self._generate_filename_debug(mm_data.mm_hash)
+            ec_cache = load_file(filename)["ec_cache"].npu()
+            encoder_cache[mm_data.mm_hash] = ec_cache
+            logger.debug("Success load encoder cache for hash %s",
+                         mm_data.mm_hash)
+
+
+vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 9e28e117..8aff73f1 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -47,6 +47,7 @@ from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed import tensor_model_parallel_all_gather
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -91,13 +92,16 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         UniformTypeKVCacheSpecs)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
-                             PoolerOutput)
+                             PoolerOutput,
+                             make_empty_encoder_model_runner_output)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.ec_connector_model_runner_mixin import \
+    ECConnectorModelRunnerMixin
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
@@ -268,7 +272,7 @@ class ExecuteModelState(NamedTuple):
     positions: torch.Tensor
 
 
-class NPUModelRunner(LoRAModelRunnerMixin):
+class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
@@ -791,6 +795,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
             req_ids_to_add.append(req_id)
 
+        # If this rank is an EC transfer producer,
+        # skip updating the states of KV cache blocks.
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return
+
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
@@ -1072,6 +1081,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 output,
                 is_embed=pos_info.is_embed,
             )
+            self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
 
     def _batch_mm_kwargs_from_scheduler(
         self,
@@ -1597,15 +1607,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
-            # Run the multimodal encoder if any.
-            self._execute_mm_encoder(scheduler_output)
+            with self.maybe_get_ec_connector_output(
+                    scheduler_output,
+                    encoder_cache=self.encoder_cache,
+            ):
+                # Run the multimodal encoder if any.
+                self._execute_mm_encoder(scheduler_output)
 
-            # NOTE(woosuk): To unify token ids and soft tokens (vision
-            # embeddings), we always use embeddings (rather than token ids)
-            # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:total_num_scheduled_tokens]
-            mm_embeds, is_mm_embed = self._gather_mm_embeddings(
-                scheduler_output)
+                # NOTE(woosuk): To unify token ids and soft tokens (vision
+                # embeddings), we always use embeddings (rather than token ids)
+                # as input to the multimodal model, even when the input is text.
+                input_ids = self.input_ids[:total_num_scheduled_tokens]
+                mm_embeds, is_mm_embed = self._gather_mm_embeddings(
+                    scheduler_output)
 
             inputs_embeds = self.model.embed_input_ids(
                 input_ids,
@@ -2248,6 +2262,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
+
+            if has_ec_transfer() and get_ec_transfer().is_producer:
+                with self.maybe_get_ec_connector_output(
+                        scheduler_output,
+                        encoder_cache=self.encoder_cache,
+                ):
+                    self._execute_mm_encoder(scheduler_output)
+                return make_empty_encoder_model_runner_output(
+                    scheduler_output)
             if not scheduler_output.total_num_scheduled_tokens:
                 if not has_kv_transfer_group():
                     logger.debug(
@@ -3741,6 +3764,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """
+
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return {}
+
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 849e2654..41b6abb9 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -30,6 +30,7 @@ from torch_npu.profiler import dynamic_profile as dp
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
@@ -417,6 +418,7 @@ class NPUWorker(WorkerBase):
             self.parallel_config.decode_context_parallel_size)
         init_ascend_model_parallel(self.parallel_config)
         ensure_kv_transfer_initialized(self.vllm_config)
+        ensure_ec_transfer_initialized(self.vllm_config)
 
     def _init_profiler(self):
         # Torch profiler. Enabled and configured through env vars:
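A note on the patching mechanism: `patch_ec_connector.py` takes effect by rebinding the module-level `ECSharedStorageConnector` attribute on vLLM's `shared_storage_connector` module. Only code that resolves the class through the module object picks up the Ascend subclass, and only after the patch module has run, which is why it is imported eagerly in `vllm_ascend/patch/platform/__init__.py`. Below is a minimal sketch of how the rebinding could be sanity-checked; it assumes a vLLM build that ships `vllm.distributed.ec_transfer` (v0.11.2 per this patch) and is illustrative only, not part of the change.

```python
# Sketch: verify the monkey-patch rebinds the connector class.
# Assumes vllm provides vllm.distributed.ec_transfer (v0.11.2 per this patch).
import vllm_ascend.patch.platform.patch_ec_connector  # noqa: F401  applies the patch on import

import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector as ssc

# Lookups that go through the module attribute now resolve to the Ascend
# subclass; bindings created earlier via `from ... import ECSharedStorageConnector`
# are unaffected, hence the early import in vllm_ascend/patch/platform/__init__.py.
assert ssc.ECSharedStorageConnector.__name__ == "AscendECSharedStorageConnector"
```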