[feature]Pooling Features and PCP Adaptation (#4143)

This PR lets the pooling KV connector support the PCP (prefill context parallel) feature.

- vLLM version: v0.11.2

---------

Signed-off-by: fjw <2270923832@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
DreamerLeader
2025-11-29 22:07:45 +08:00
committed by GitHub
parent 1eb5295a1b
commit 4dbe4fd123
5 changed files with 89 additions and 29 deletions

View File

@@ -43,8 +43,6 @@ class AscendStoreConnector(KVConnectorBase_V1):
self.kv_caches: dict[str, torch.Tensor] = {} self.kv_caches: dict[str, torch.Tensor] = {}
self._block_size = vllm_config.cache_config.block_size
self.sended_but_unfinished_reqs: set[str] = set() self.sended_but_unfinished_reqs: set[str] = set()
if role == KVConnectorRole.SCHEDULER: if role == KVConnectorRole.SCHEDULER:

View File

@@ -17,6 +17,10 @@ class KeyMetadata:
model_name: str model_name: str
""" worker id when running under a distributed setting """ """ worker id when running under a distributed setting """
head_or_tp_rank: int head_or_tp_rank: int
""" Initialize the current prefill context model parallel rank """
pcp_rank: int
""" Initialize the current decode context model parallel rank """
dcp_rank: int
@dataclass(order=True) @dataclass(order=True)
@@ -28,12 +32,15 @@ class PoolKey:
return hash(( return hash((
self.key_metadata.model_name, self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank, self.key_metadata.head_or_tp_rank,
self.key_metadata.pcp_rank,
self.key_metadata.dcp_rank,
self.chunk_hash, self.chunk_hash,
)) ))
def to_string(self): def to_string(self):
return ( return (
f"{self.key_metadata.model_name}" f"{self.key_metadata.model_name}"
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}" f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
) )
@@ -60,6 +67,8 @@ class LayerPoolKey(PoolKey):
return hash(( return hash((
self.key_metadata.model_name, self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank, self.key_metadata.head_or_tp_rank,
self.key_metadata.pcp_rank,
self.key_metadata.dcp_rank,
self.chunk_hash, self.chunk_hash,
self.layer_id, self.layer_id,
)) ))
@@ -67,6 +76,7 @@ class LayerPoolKey(PoolKey):
def to_string(self): def to_string(self):
return ( return (
f"{self.key_metadata.model_name}" f"{self.key_metadata.model_name}"
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}" f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}"
) )

View File

@@ -19,11 +19,13 @@ from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase,
class KVTransferThread(threading.Thread): class KVTransferThread(threading.Thread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, ready_event: threading.Event, name: str): tp_rank: int, dcp_size: int, ready_event: threading.Event,
name: str):
super().__init__(daemon=True, name=name) super().__init__(daemon=True, name=name)
self.m_store = m_store self.m_store = m_store
self.ready_event = ready_event self.ready_event = ready_event
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.dcp_size = dcp_size
self.token_database = token_database self.token_database = token_database
self.done_task_lock = threading.Lock() self.done_task_lock = threading.Lock()
self.request_queue: queue.Queue[Any] = queue.Queue() self.request_queue: queue.Queue[Any] = queue.Queue()
@@ -87,10 +89,12 @@ class KVTransferThread(threading.Thread):
class KVCacheStoreSendingThread(KVTransferThread): class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, put_step: int, ready_event: threading.Event): tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event):
super().__init__(m_store, super().__init__(m_store,
token_database, token_database,
tp_rank, tp_rank,
dcp_size,
ready_event, ready_event,
name="KVCacheSendingThread") name="KVCacheSendingThread")
self.put_step = put_step self.put_step = put_step
@@ -112,9 +116,15 @@ class KVCacheStoreSendingThread(KVTransferThread):
key_list.append(key.to_string()) key_list.append(key.to_string())
addr_list.append(addr) addr_list.append(addr)
size_list.append(size) size_list.append(size)
if self.dcp_size > 1:
torch.npu.current_stream().synchronize()
self.m_store.put(key_list, addr_list, size_list)
else:
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] addr_list_tp = addr_list[self.tp_rank %
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank %
self.put_step::self.put_step]
if key_list_tp: if key_list_tp:
torch.npu.current_stream().synchronize() torch.npu.current_stream().synchronize()
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
@@ -126,10 +136,11 @@ class KVCacheStoreSendingThread(KVTransferThread):
class KVCacheStoreRecvingThread(KVTransferThread): class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, ready_event: threading.Event): tp_rank: int, dcp_size: int, ready_event: threading.Event):
super().__init__(m_store, super().__init__(m_store,
token_database, token_database,
tp_rank, tp_rank,
dcp_size,
ready_event, ready_event,
name="KVCacheStoreRecvingThread") name="KVCacheStoreRecvingThread")
@@ -166,11 +177,12 @@ class KVCacheStoreRecvingThread(KVTransferThread):
class KVCacheStoreLayerSendingThread(KVTransferThread): class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, put_step: int, ready_event: threading.Event, tp_rank: int, dcp_size: int, put_step: int,
num_layers: int): ready_event: threading.Event, num_layers: int):
super().__init__(m_store, super().__init__(m_store,
token_database, token_database,
tp_rank, tp_rank,
dcp_size,
ready_event, ready_event,
name="KVCacheStoreLayerSendingThread") name="KVCacheStoreLayerSendingThread")
self.final_layer_id = num_layers - 1 self.final_layer_id = num_layers - 1
@@ -192,9 +204,15 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
key_list.append(key.to_string()) key_list.append(key.to_string())
addr_list.append(addr) addr_list.append(addr)
size_list.append(size) size_list.append(size)
if self.dcp_size > 1:
torch.npu.current_stream().synchronize()
self.m_store.put(key_list, addr_list, size_list)
else:
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] addr_list_tp = addr_list[self.tp_rank %
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank %
self.put_step::self.put_step]
if key_list_tp: if key_list_tp:
torch.npu.current_stream().synchronize() torch.npu.current_stream().synchronize()
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
@@ -206,11 +224,12 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
class KVCacheStoreLayerRecvingThread(KVTransferThread): class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, ready_event: threading.Event, tp_rank: int, dcp_size: int, ready_event: threading.Event,
get_event: threading.Event): get_event: threading.Event):
super().__init__(m_store, super().__init__(m_store,
token_database, token_database,
tp_rank, tp_rank,
dcp_size,
ready_event, ready_event,
name="KVCacheStoreLayerRecvingThread") name="KVCacheStoreLayerRecvingThread")
self.get_event = get_event self.get_event = get_event

View File

@@ -29,7 +29,14 @@ class KVPoolScheduler:
"load_async", False) "load_async", False)
# request_id -> (vllm cached tokes, kvpool cached tokens) # request_id -> (vllm cached tokes, kvpool cached tokens)
self.load_specs: dict[str, LoadSpec] = {} self.load_specs: dict[str, LoadSpec] = {}
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self._block_size = vllm_config.cache_config.block_size self._block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1:
self._block_size *= self.pcp_size
if self.dcp_size > 1:
self._block_size *= self.dcp_size
# request_id -> full_token_ids # request_id -> full_token_ids
self._request_trackers: dict[str, RequestTracker] = {} self._request_trackers: dict[str, RequestTracker] = {}
# Whether to discard partial chunks # Whether to discard partial chunks

View File

@@ -1,11 +1,13 @@
# Standard
import math import math
import threading import threading
from typing import Dict, Generator, Optional, Type from typing import Dict, Generator, Optional, Type
# Third Party
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.utils import logger from vllm.utils import logger
from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.kv_cache_utils import BlockHash
@@ -20,6 +22,14 @@ from vllm_ascend.distributed.kvpool.config_data import (
from vllm_ascend.distributed.kvpool.kv_transfer import ( from vllm_ascend.distributed.kvpool.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_ascend.utils import prefill_context_parallel_enable
if prefill_context_parallel_enable():
# isort: off
from vllm.distributed import (get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort: on
backend_map: Dict[str, Type[Backend]] = { backend_map: Dict[str, Type[Backend]] = {
"mooncake": MooncakeBackend, "mooncake": MooncakeBackend,
@@ -44,17 +54,30 @@ class KVPoolWorker:
and model_config.use_mla): and model_config.use_mla):
self.use_mla = True self.use_mla = True
self.use_layerwise = use_layerwize self.use_layerwise = use_layerwize
self.tp_rank = parallel_config.rank self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = parallel_config.tensor_parallel_size self.tp_size = get_tensor_model_parallel_world_size()
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.kv_role = vllm_config.kv_transfer_config.kv_role self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False) "load_async", False)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get( self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"backend", "mooncake") "backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1:
self.block_size *= self.pcp_size
if self.dcp_size > 1:
self.block_size *= self.dcp_size
self.current_layer = 0 self.current_layer = 0
self.num_layers = model_config.get_num_layers(parallel_config) self.num_layers = model_config.get_num_layers(parallel_config)
self.block_size = vllm_config.cache_config.block_size
if self.use_mla: if self.use_mla:
self.num_kv_head = 1 self.num_kv_head = 1
@@ -69,8 +92,10 @@ class KVPoolWorker:
self.put_step = 1 self.put_step = 1
self.metadata = KeyMetadata( self.metadata = KeyMetadata(
model_config.model, model_config.model.split('/')[-1],
self.head_or_tp_rank, self.head_or_tp_rank,
self.pcp_rank,
self.dcp_rank,
) )
self.token_database = ChunkedTokenDatabase(self.metadata, self.token_database = ChunkedTokenDatabase(self.metadata,
@@ -147,12 +172,13 @@ class KVPoolWorker:
ready_event_sending = threading.Event() ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread( self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.m_store, self.token_database, self.tp_rank, self.m_store, self.token_database, self.tp_rank,
self.put_step, ready_event_sending, self.num_layers) self.dcp_size, self.put_step, ready_event_sending,
self.num_layers)
self.kv_send_thread.start() self.kv_send_thread.start()
ready_event = threading.Event() ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread( self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.m_store, self.token_database, self.tp_rank, ready_event, self.m_store, self.token_database, self.tp_rank, self.dcp_size,
self.get_event) ready_event, self.get_event)
self.kv_recv_thread.start() self.kv_recv_thread.start()
ready_event.wait() ready_event.wait()
else: else:
@@ -160,13 +186,13 @@ class KVPoolWorker:
ready_event_sending = threading.Event() ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread( self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.tp_rank, self.m_store, self.token_database, self.tp_rank,
self.put_step, ready_event_sending) self.dcp_size, self.put_step, ready_event_sending)
self.kv_send_thread.start() self.kv_send_thread.start()
if self.load_async: if self.load_async:
ready_event = threading.Event() ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread( self.kv_recv_thread = KVCacheStoreRecvingThread(
self.m_store, self.token_database, self.tp_rank, self.m_store, self.token_database, self.tp_rank,
ready_event) self.dcp_size, ready_event)
self.kv_recv_thread.start() self.kv_recv_thread.start()
ready_event.wait() ready_event.wait()