[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#4176)

### What this PR does / why we need it?
Support encoder separation for Encode-Prefill-Decode (EPD) disaggregation: a dedicated encode instance runs the multimodal encoder and publishes its encoder caches through an EC (encoder-cache) connector, while prefill/decode instances load those caches from the connector instead of re-running the encoder.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: amy-why-3459 <wuhaiyan17@huawei.com>
Authored by amy-why-3459 on 2025-12-03 20:48:45 +08:00, committed by GitHub
parent 6ece6660ec
commit 26e8e58cea
4 changed files with 72 additions and 10 deletions


@@ -18,6 +18,7 @@ import os
 import vllm_ascend.patch.platform.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_distributed  # noqa
+import vllm_ascend.patch.platform.patch_ec_connector  # noqa
 import vllm_ascend.patch.platform.patch_mamba_config  # noqa
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
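The new `patch_ec_connector` entry follows the same convention as the other platform patches: importing the module is enough to swap an upstream vLLM class for an Ascend-specific subclass at startup (the actual swap is the module-level assignment in the new file below). A minimal, self-contained sketch of that patch-on-import pattern; `upstream`, `Widget`, and `AscendWidget` are illustrative stand-ins, not vLLM or vllm-ascend APIs:

```python
import types

# Stand-in for the upstream module whose class we want to replace.
upstream = types.ModuleType("upstream")


class Widget:
    def load(self) -> str:
        return "generic load"


class AscendWidget(Widget):
    # Override only the behaviour that differs on NPU.
    def load(self) -> str:
        return "ascend-specific load"


upstream.Widget = Widget

# The "patch" module's whole job is this one assignment; merely importing it
# makes every later `upstream.Widget` lookup resolve to the subclass.
upstream.Widget = AscendWidget

print(upstream.Widget().load())  # -> "ascend-specific load"
```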


@@ -0,0 +1,32 @@
+import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector
+from safetensors.torch import load_file
+from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
+from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import (
+    ECSharedStorageConnector, ECSharedStorageConnectorMetadata)
+from vllm.logger import logger
+
+
+class AscendECSharedStorageConnector(ECSharedStorageConnector):
+    def start_load_caches(self, encoder_cache, **kwargs) -> None:
+        metadata: ECConnectorMetadata = self._get_connector_metadata()
+        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
+        assert encoder_cache is not None
+        if metadata is None:
+            logger.warning((
+                "In connector.start_load_caches, ",
+                "but the connector metadata is None",
+            ))
+            return
+
+        # Load the EC for each mm data
+        for mm_data in metadata.mm_datas:
+            if mm_data.mm_hash in encoder_cache:
+                continue
+            filename = self._generate_filename_debug(mm_data.mm_hash)
+            ec_cache = load_file(filename)["ec_cache"].npu()
+            encoder_cache[mm_data.mm_hash] = ec_cache
+            logger.debug("Success load encoder cache for hash %s",
+                         mm_data.mm_hash)
+
+
+vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector
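For context on the format `start_load_caches` expects: each multimodal item is keyed by its `mm_hash`, and the corresponding safetensors file stores the encoder output under the key `"ec_cache"`, which the Ascend connector then moves onto the NPU with `.npu()`. Below is a hedged, runnable sketch of that handoff; `shared_dir` and the `<hash>.safetensors` naming are illustrative (the real connector derives paths through `_generate_filename_debug`), and the sketch stays on CPU so it does not require torch_npu:

```python
import os
import tempfile

import torch
from safetensors.torch import load_file, save_file

shared_dir = tempfile.mkdtemp()         # stand-in for the shared storage path
mm_hash = "0f3c"                        # stand-in for a real multimodal content hash
filename = os.path.join(shared_dir, f"{mm_hash}.safetensors")

# Producer side (encode instance): persist the encoder output for this mm item.
encoder_output = torch.randn(16, 1024)  # fake vision-encoder embedding
save_file({"ec_cache": encoder_output}, filename)

# Consumer side (prefill/decode instance): fill the local encoder cache without
# re-running the encoder. The patched connector calls .npu() at this point; the
# sketch keeps the tensor on CPU.
encoder_cache: dict[str, torch.Tensor] = {}
if mm_hash not in encoder_cache:
    encoder_cache[mm_hash] = load_file(filename)["ec_cache"]

assert torch.equal(encoder_cache[mm_hash], encoder_output)
```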


@@ -47,6 +47,7 @@ from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed import tensor_model_parallel_all_gather
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -91,13 +92,16 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         UniformTypeKVCacheSpecs)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
-                             PoolerOutput)
+                             PoolerOutput,
+                             make_empty_encoder_model_runner_output)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.ec_connector_model_runner_mixin import \
+    ECConnectorModelRunnerMixin
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
@@ -268,7 +272,7 @@ class ExecuteModelState(NamedTuple):
     positions: torch.Tensor


-class NPUModelRunner(LoRAModelRunnerMixin):
+class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
@@ -791,6 +795,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             req_ids_to_add.append(req_id)

+        # If this rank is an EC transfer producer,
+        # skip updating the states of KV cache blocks.
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return
+
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
@@ -1072,6 +1081,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 output,
                 is_embed=pos_info.is_embed,
             )
+            self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)

     def _batch_mm_kwargs_from_scheduler(
         self,
@@ -1597,15 +1607,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
-            # Run the multimodal encoder if any.
-            self._execute_mm_encoder(scheduler_output)
+            with self.maybe_get_ec_connector_output(
+                    scheduler_output,
+                    encoder_cache=self.encoder_cache,
+            ):
+                # Run the multimodal encoder if any.
+                self._execute_mm_encoder(scheduler_output)

-            # NOTE(woosuk): To unify token ids and soft tokens (vision
-            # embeddings), we always use embeddings (rather than token ids)
-            # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:total_num_scheduled_tokens]
-            mm_embeds, is_mm_embed = self._gather_mm_embeddings(
-                scheduler_output)
+                # NOTE(woosuk): To unify token ids and soft tokens (vision
+                # embeddings), we always use embeddings (rather than token ids)
+                # as input to the multimodal model, even when the input is text.
+                input_ids = self.input_ids[:total_num_scheduled_tokens]
+                mm_embeds, is_mm_embed = self._gather_mm_embeddings(
+                    scheduler_output)

             inputs_embeds = self.model.embed_input_ids(
                 input_ids,
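`maybe_get_ec_connector_output` comes from the `ECConnectorModelRunnerMixin` imported above and wraps encoder execution in a context manager so the EC connector can act before and after the encoder runs. The sketch below shows only the general shape such a mixin typically has; it is not the vLLM implementation, and `_SketchECMixin` is an illustrative stand-in:

```python
from contextlib import contextmanager


class _SketchECMixin:
    """Illustrative stand-in for an EC-connector model-runner mixin."""

    def __init__(self):
        self.saved: dict[str, object] = {}

    @contextmanager
    def maybe_get_ec_connector_output(self, scheduler_output, encoder_cache):
        # Entering the block: a connector could pre-load caches produced by a
        # remote encode instance into `encoder_cache` here.
        yield
        # Leaving the block: whatever the local encoder just produced is now in
        # `encoder_cache` and can be handed to the connector for saving.
        self.saved.update(encoder_cache)


runner = _SketchECMixin()
encoder_cache: dict[str, object] = {}
with runner.maybe_get_ec_connector_output(scheduler_output=None,
                                           encoder_cache=encoder_cache):
    encoder_cache["mm_hash_0"] = "fake encoder output"

assert "mm_hash_0" in runner.saved
```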
@@ -2248,6 +2262,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
+
+            if has_ec_transfer() and get_ec_transfer().is_producer:
+                with self.maybe_get_ec_connector_output(
+                        scheduler_output,
+                        encoder_cache=self.encoder_cache,
+                ):
+                    self._execute_mm_encoder(scheduler_output)
+                return make_empty_encoder_model_runner_output(
+                    scheduler_output)

             if not scheduler_output.total_num_scheduled_tokens:
                 if not has_kv_transfer_group():
                     logger.debug(
@@ -3741,6 +3764,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return {}
+
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
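The `has_ec_transfer() and get_ec_transfer().is_producer` guard appears three times in this file (skipping KV-block state updates in `_update_states`, returning an empty encoder output from `execute_model`, and reporting an empty KV cache spec): an encode-only rank never serves attention layers, so it never allocates or updates KV cache. A condensed, self-contained sketch of the pattern, using stand-in objects rather than vLLM's `ec_transfer` API:

```python
from typing import Optional


class _ECTransferStub:
    """Illustrative stand-in for the EC transfer group object."""

    def __init__(self, is_producer: bool):
        self.is_producer = is_producer


_ec_transfer: Optional[_ECTransferStub] = _ECTransferStub(is_producer=True)


def has_ec_transfer() -> bool:
    return _ec_transfer is not None


def get_ec_transfer() -> _ECTransferStub:
    assert _ec_transfer is not None
    return _ec_transfer


def get_kv_cache_spec() -> dict:
    # An encode-only (producer) rank serves no attention layers, so it reports
    # an empty spec and the engine allocates no KV-cache blocks for it.
    if has_ec_transfer() and get_ec_transfer().is_producer:
        return {}
    return {"model.layers.0.attn": "<AttentionSpec>"}  # placeholder


print(get_kv_cache_spec())  # -> {} on an encoder-producer rank
```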


@@ -30,6 +30,7 @@ from torch_npu.profiler import dynamic_profile as dp
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
@@ -417,6 +418,7 @@ class NPUWorker(WorkerBase):
             self.parallel_config.decode_context_parallel_size)
         init_ascend_model_parallel(self.parallel_config)
         ensure_kv_transfer_initialized(self.vllm_config)
+        ensure_ec_transfer_initialized(self.vllm_config)

     def _init_profiler(self):
         # Torch profiler. Enabled and configured through env vars: