[feature]Pooling Features and PCP Adaptation (#4143)

This PR let pooling kv connector support pcp feature

- vLLM version: v0.11.2

---------

Signed-off-by: fjw <2270923832@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
DreamerLeader
2025-11-29 22:07:45 +08:00
committed by GitHub
parent 1eb5295a1b
commit 4dbe4fd123
5 changed files with 89 additions and 29 deletions

View File

@@ -1,11 +1,13 @@
# Standard
import math
import threading
from typing import Dict, Generator, Optional, Type
# Third Party
import torch
from vllm.config import VllmConfig
from vllm.distributed import (get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.utils import logger
from vllm.v1.core.kv_cache_utils import BlockHash
@@ -20,6 +22,14 @@ from vllm_ascend.distributed.kvpool.config_data import (
from vllm_ascend.distributed.kvpool.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_ascend.utils import prefill_context_parallel_enable
if prefill_context_parallel_enable():
# isort: off
from vllm.distributed import (get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort: on
backend_map: Dict[str, Type[Backend]] = {
"mooncake": MooncakeBackend,
@@ -44,17 +54,30 @@ class KVPoolWorker:
and model_config.use_mla):
self.use_mla = True
self.use_layerwise = use_layerwize
self.tp_rank = parallel_config.rank
self.tp_size = parallel_config.tensor_parallel_size
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1:
self.block_size *= self.pcp_size
if self.dcp_size > 1:
self.block_size *= self.dcp_size
self.current_layer = 0
self.num_layers = model_config.get_num_layers(parallel_config)
self.block_size = vllm_config.cache_config.block_size
if self.use_mla:
self.num_kv_head = 1
@@ -69,8 +92,10 @@ class KVPoolWorker:
self.put_step = 1
self.metadata = KeyMetadata(
model_config.model,
model_config.model.split('/')[-1],
self.head_or_tp_rank,
self.pcp_rank,
self.dcp_rank,
)
self.token_database = ChunkedTokenDatabase(self.metadata,
@@ -147,12 +172,13 @@ class KVPoolWorker:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.m_store, self.token_database, self.tp_rank,
self.put_step, ready_event_sending, self.num_layers)
self.dcp_size, self.put_step, ready_event_sending,
self.num_layers)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.m_store, self.token_database, self.tp_rank, ready_event,
self.get_event)
self.m_store, self.token_database, self.tp_rank, self.dcp_size,
ready_event, self.get_event)
self.kv_recv_thread.start()
ready_event.wait()
else:
@@ -160,13 +186,13 @@ class KVPoolWorker:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.tp_rank,
self.put_step, ready_event_sending)
self.dcp_size, self.put_step, ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.m_store, self.token_database, self.tp_rank,
ready_event)
self.dcp_size, ready_event)
self.kv_recv_thread.start()
ready_event.wait()