[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#4176)
### What this PR does / why we need it?

Support Encoder separation for Encode-Prefill-Decode Disaggregation.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: amy-why-3459 <wuhaiyan17@huawei.com>
```diff
@@ -18,6 +18,7 @@ import os
 
 import vllm_ascend.patch.platform.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_distributed  # noqa
+import vllm_ascend.patch.platform.patch_ec_connector  # noqa
 import vllm_ascend.patch.platform.patch_mamba_config  # noqa
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 
```
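The new import works purely through its side effect: loading `patch_ec_connector` rebinds the connector class inside vLLM's module namespace. A minimal, self-contained sketch of this import-time patching pattern (`somelib` is a stand-in module, not a real dependency):

```python
import types

# Stand-in for the module being patched (in the PR this is
# vllm.distributed.ec_transfer.ec_connector.shared_storage_connector).
somelib = types.ModuleType("somelib")

class Connector:
    def load(self) -> str:
        return "generic"

somelib.Connector = Connector

# --- body of a patch module, executed at import time ---
class AscendConnector(somelib.Connector):
    def load(self) -> str:
        return "ascend"

# Rebind the module attribute, exactly as patch_ec_connector.py does below.
somelib.Connector = AscendConnector
# --------------------------------------------------------

print(somelib.Connector().load())  # -> "ascend"
```

Note that rebinding the module attribute only affects call sites that look the class up through the module after the patch has run, which is why these patch imports are grouped in platform initialization, before vLLM instantiates the connector.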
vllm_ascend/patch/platform/patch_ec_connector.py (new file, 32 lines)
```diff
@@ -0,0 +1,32 @@
+import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector
+from safetensors.torch import load_file
+from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
+from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import (
+    ECSharedStorageConnector, ECSharedStorageConnectorMetadata)
+from vllm.logger import logger
+
+
+class AscendECSharedStorageConnector(ECSharedStorageConnector):
+
+    def start_load_caches(self, encoder_cache, **kwargs) -> None:
+        metadata: ECConnectorMetadata = self._get_connector_metadata()
+        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
+        assert encoder_cache is not None
+        if metadata is None:
+            logger.warning((
+                "In connector.start_load_caches, ",
+                "but the connector metadata is None",
+            ))
+            return
+        # Load the EC for each mm data
+        for mm_data in metadata.mm_datas:
+            if mm_data.mm_hash in encoder_cache:
+                continue
+            filename = self._generate_filename_debug(mm_data.mm_hash)
+            ec_cache = load_file(filename)["ec_cache"].npu()
+            encoder_cache[mm_data.mm_hash] = ec_cache
+            logger.debug("Success load encoder cache for hash %s",
+                         mm_data.mm_hash)
+
+
+vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector
```
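The patch only overrides the load path (so the loaded tensor lands on the NPU); saving is inherited from vLLM's `ECSharedStorageConnector`. Below is a self-contained toy of the shared-storage handoff this connector implements. The `"ec_cache"` tensor key and the hash-keyed dedup mirror the file above; the store directory, the hash value, and the use of `.cpu()` in place of `.npu()` (so the sketch runs without `torch_npu`) are illustrative assumptions.

```python
import os

import torch
from safetensors.torch import load_file, save_file

STORE = "/tmp/ec_store"  # illustrative; the real connector derives paths itself
os.makedirs(STORE, exist_ok=True)


def producer_save(mm_hash: str, embedding: torch.Tensor) -> None:
    # Encode instance: persist the encoder output keyed by the mm hash.
    save_file({"ec_cache": embedding.contiguous()},
              os.path.join(STORE, f"{mm_hash}.safetensors"))


def consumer_load(mm_hash: str, encoder_cache: dict) -> None:
    # Prefill instance: skip work if already cached, else load from disk.
    if mm_hash in encoder_cache:
        return
    tensors = load_file(os.path.join(STORE, f"{mm_hash}.safetensors"))
    # The patch moves this to the NPU with .npu(); .cpu() keeps the toy
    # runnable anywhere.
    encoder_cache[mm_hash] = tensors["ec_cache"].cpu()


producer_save("img-0001", torch.randn(16, 1024))
cache: dict[str, torch.Tensor] = {}
consumer_load("img-0001", cache)
print(cache["img-0001"].shape)  # torch.Size([16, 1024])
```

In the real connector the filename comes from `self._generate_filename_debug(mm_hash)`, so both sides agree on the path without coordination beyond the shared storage location.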
The following hunks are in the NPU model runner:

```diff
@@ -47,6 +47,7 @@ from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed import tensor_model_parallel_all_gather
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
```
```diff
@@ -91,13 +92,16 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         UniformTypeKVCacheSpecs)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
-                             PoolerOutput)
+                             PoolerOutput,
+                             make_empty_encoder_model_runner_output)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.ec_connector_model_runner_mixin import \
+    ECConnectorModelRunnerMixin
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
```
```diff
@@ -268,7 +272,7 @@ class ExecuteModelState(NamedTuple):
     positions: torch.Tensor
 
 
-class NPUModelRunner(LoRAModelRunnerMixin):
+class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
```
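`NPUModelRunner` now inherits the EC hooks (`maybe_get_ec_connector_output`, `maybe_save_ec_to_connector`) used in the hunks below. The sketch that follows shows the shape such a mixin plausibly has; the real implementation lives in `vllm.v1.worker.ec_connector_model_runner_mixin`, and the method bodies here are assumptions, not vLLM's code.

```python
from contextlib import contextmanager
from typing import Any, Iterator, Optional

from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer


class ECConnectorMixinSketch:
    """Assumed contract of ECConnectorModelRunnerMixin (illustrative only)."""

    @contextmanager
    def maybe_get_ec_connector_output(
            self, scheduler_output: Any, *,
            encoder_cache: dict[str, Any]) -> Iterator[Optional[Any]]:
        # With no EC transfer configured this is a transparent no-op, so the
        # runner can always wrap encoder execution in this `with` block.
        if not has_ec_transfer():
            yield None
            return
        connector = get_ec_transfer()
        # Consumer side: pre-populate encoder_cache with entries an encode
        # instance produced, so recomputation can be skipped.
        connector.start_load_caches(encoder_cache)
        yield connector

    def maybe_save_ec_to_connector(self, encoder_cache: dict[str, Any],
                                   mm_hash: str) -> None:
        # Producer side: publish a freshly computed cache entry; a no-op on
        # ranks that are not EC producers.
        if has_ec_transfer() and get_ec_transfer().is_producer:
            ...  # hand encoder_cache[mm_hash] to the connector
```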
```diff
@@ -791,6 +795,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
             req_ids_to_add.append(req_id)
 
+        # If this rank is an EC transfer producer,
+        # skip updating the states of KV cache blocks.
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return
+
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
```
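The producer check `has_ec_transfer() and get_ec_transfer().is_producer` recurs at three sites in this PR: skipping KV block state updates here, the encoder-only `execute_model` fast path, and returning an empty KV cache spec. A hypothetical helper (not part of the PR) that names the role:

```python
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer


def is_ec_producer() -> bool:
    # True only on encode-instance ranks, which compute encoder caches for
    # other instances and never touch KV cache state themselves.
    return has_ec_transfer() and get_ec_transfer().is_producer
```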
```diff
@@ -1072,6 +1081,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 output,
                 is_embed=pos_info.is_embed,
             )
+            self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
 
     def _batch_mm_kwargs_from_scheduler(
         self,
```
```diff
@@ -1597,15 +1607,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
-            # Run the multimodal encoder if any.
-            self._execute_mm_encoder(scheduler_output)
+            with self.maybe_get_ec_connector_output(
+                    scheduler_output,
+                    encoder_cache=self.encoder_cache,
+            ):
+                # Run the multimodal encoder if any.
+                self._execute_mm_encoder(scheduler_output)
 
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
             input_ids = self.input_ids[:total_num_scheduled_tokens]
             mm_embeds, is_mm_embed = self._gather_mm_embeddings(
                 scheduler_output)
 
             inputs_embeds = self.model.embed_input_ids(
                 input_ids,
```
```diff
@@ -2248,6 +2262,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
+            if has_ec_transfer() and get_ec_transfer().is_producer:
+                with self.maybe_get_ec_connector_output(
+                        scheduler_output,
+                        encoder_cache=self.encoder_cache,
+                ):
+                    self._execute_mm_encoder(scheduler_output)
+                return make_empty_encoder_model_runner_output(
+                    scheduler_output)
+
             if not scheduler_output.total_num_scheduled_tokens:
                 if not has_kv_transfer_group():
                     logger.debug(
```
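On an encode instance, `execute_model` now runs only the multimodal encoder and reports an output with no generated tokens, which is what `make_empty_encoder_model_runner_output` supplies. A condensed paraphrase of that fast path (runner internals elided; `runner` stands in for the `NPUModelRunner` instance, and this is not the actual method):

```python
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.v1.outputs import make_empty_encoder_model_runner_output


def execute_model_on_encode_rank(runner, scheduler_output):
    runner._update_states(scheduler_output)
    with runner.maybe_get_ec_connector_output(
            scheduler_output, encoder_cache=runner.encoder_cache):
        # Compute encoder caches; the connector publishes them for the
        # prefill instances to load.
        runner._execute_mm_encoder(scheduler_output)
    # No forward pass, sampling, or KV writes happen on this rank.
    return make_empty_encoder_model_runner_output(scheduler_output)
```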
```diff
@@ -3741,6 +3764,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """
+
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return {}
+
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
```
The final hunks are in the NPU worker:

```diff
@@ -30,6 +30,7 @@ from torch_npu.profiler import dynamic_profile as dp
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
```
```diff
@@ -417,6 +418,7 @@ class NPUWorker(WorkerBase):
             self.parallel_config.decode_context_parallel_size)
         init_ascend_model_parallel(self.parallel_config)
         ensure_kv_transfer_initialized(self.vllm_config)
+        ensure_ec_transfer_initialized(self.vllm_config)
 
     def _init_profiler(self):
         # Torch profiler. Enabled and configured through env vars:
```
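EC transfer is initialized right after KV transfer, once the distributed environment and model-parallel groups exist. Presumably, like `ensure_kv_transfer_initialized`, the EC variant is a no-op unless an EC transfer config is present; the sketch below captures that assumption (the `ec_transfer_config` attribute name is hypothetical, not confirmed by this diff):

```python
def ensure_ec_transfer_initialized_sketch(vllm_config) -> None:
    # Assumption: mirrors ensure_kv_transfer_initialized, so the worker can
    # call it unconditionally; deployments without EPD hit the early return.
    if getattr(vllm_config, "ec_transfer_config", None) is None:  # hypothetical field
        return
    # Otherwise build the connector and register it so has_ec_transfer()
    # and get_ec_transfer() resolve on this rank.
    ...
```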