[feature]Pooling Features and PCP Adaptation (#4143)
This PR let pooling kv connector support pcp feature - vLLM version: v0.11.2 --------- Signed-off-by: fjw <2270923832@qq.com> Signed-off-by: SlightwindSec <slightwindsec@gmail.com> Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
@@ -43,8 +43,6 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
|||||||
|
|
||||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
self._block_size = vllm_config.cache_config.block_size
|
|
||||||
|
|
||||||
self.sended_but_unfinished_reqs: set[str] = set()
|
self.sended_but_unfinished_reqs: set[str] = set()
|
||||||
|
|
||||||
if role == KVConnectorRole.SCHEDULER:
|
if role == KVConnectorRole.SCHEDULER:
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ class KeyMetadata:
|
|||||||
model_name: str
|
model_name: str
|
||||||
""" worker id when running under a distributed setting """
|
""" worker id when running under a distributed setting """
|
||||||
head_or_tp_rank: int
|
head_or_tp_rank: int
|
||||||
|
""" Initialize the current prefill context model parallel rank """
|
||||||
|
pcp_rank: int
|
||||||
|
""" Initialize the current decode context model parallel rank """
|
||||||
|
dcp_rank: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass(order=True)
|
@dataclass(order=True)
|
||||||
@@ -28,12 +32,15 @@ class PoolKey:
|
|||||||
return hash((
|
return hash((
|
||||||
self.key_metadata.model_name,
|
self.key_metadata.model_name,
|
||||||
self.key_metadata.head_or_tp_rank,
|
self.key_metadata.head_or_tp_rank,
|
||||||
|
self.key_metadata.pcp_rank,
|
||||||
|
self.key_metadata.dcp_rank,
|
||||||
self.chunk_hash,
|
self.chunk_hash,
|
||||||
))
|
))
|
||||||
|
|
||||||
def to_string(self):
|
def to_string(self):
|
||||||
return (
|
return (
|
||||||
f"{self.key_metadata.model_name}"
|
f"{self.key_metadata.model_name}"
|
||||||
|
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
|
||||||
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
|
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -60,6 +67,8 @@ class LayerPoolKey(PoolKey):
|
|||||||
return hash((
|
return hash((
|
||||||
self.key_metadata.model_name,
|
self.key_metadata.model_name,
|
||||||
self.key_metadata.head_or_tp_rank,
|
self.key_metadata.head_or_tp_rank,
|
||||||
|
self.key_metadata.pcp_rank,
|
||||||
|
self.key_metadata.dcp_rank,
|
||||||
self.chunk_hash,
|
self.chunk_hash,
|
||||||
self.layer_id,
|
self.layer_id,
|
||||||
))
|
))
|
||||||
@@ -67,6 +76,7 @@ class LayerPoolKey(PoolKey):
|
|||||||
def to_string(self):
|
def to_string(self):
|
||||||
return (
|
return (
|
||||||
f"{self.key_metadata.model_name}"
|
f"{self.key_metadata.model_name}"
|
||||||
|
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
|
||||||
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}"
|
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase,
|
|||||||
class KVTransferThread(threading.Thread):
|
class KVTransferThread(threading.Thread):
|
||||||
|
|
||||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||||
tp_rank: int, ready_event: threading.Event, name: str):
|
tp_rank: int, dcp_size: int, ready_event: threading.Event,
|
||||||
|
name: str):
|
||||||
super().__init__(daemon=True, name=name)
|
super().__init__(daemon=True, name=name)
|
||||||
self.m_store = m_store
|
self.m_store = m_store
|
||||||
self.ready_event = ready_event
|
self.ready_event = ready_event
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
|
self.dcp_size = dcp_size
|
||||||
self.token_database = token_database
|
self.token_database = token_database
|
||||||
self.done_task_lock = threading.Lock()
|
self.done_task_lock = threading.Lock()
|
||||||
self.request_queue: queue.Queue[Any] = queue.Queue()
|
self.request_queue: queue.Queue[Any] = queue.Queue()
|
||||||
@@ -87,10 +89,12 @@ class KVTransferThread(threading.Thread):
|
|||||||
class KVCacheStoreSendingThread(KVTransferThread):
|
class KVCacheStoreSendingThread(KVTransferThread):
|
||||||
|
|
||||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||||
tp_rank: int, put_step: int, ready_event: threading.Event):
|
tp_rank: int, dcp_size: int, put_step: int,
|
||||||
|
ready_event: threading.Event):
|
||||||
super().__init__(m_store,
|
super().__init__(m_store,
|
||||||
token_database,
|
token_database,
|
||||||
tp_rank,
|
tp_rank,
|
||||||
|
dcp_size,
|
||||||
ready_event,
|
ready_event,
|
||||||
name="KVCacheSendingThread")
|
name="KVCacheSendingThread")
|
||||||
self.put_step = put_step
|
self.put_step = put_step
|
||||||
@@ -112,12 +116,18 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
|||||||
key_list.append(key.to_string())
|
key_list.append(key.to_string())
|
||||||
addr_list.append(addr)
|
addr_list.append(addr)
|
||||||
size_list.append(size)
|
size_list.append(size)
|
||||||
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
|
if self.dcp_size > 1:
|
||||||
addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
|
|
||||||
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
|
|
||||||
if key_list_tp:
|
|
||||||
torch.npu.current_stream().synchronize()
|
torch.npu.current_stream().synchronize()
|
||||||
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
|
self.m_store.put(key_list, addr_list, size_list)
|
||||||
|
else:
|
||||||
|
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
|
||||||
|
addr_list_tp = addr_list[self.tp_rank %
|
||||||
|
self.put_step::self.put_step]
|
||||||
|
size_list_tp = size_list[self.tp_rank %
|
||||||
|
self.put_step::self.put_step]
|
||||||
|
if key_list_tp:
|
||||||
|
torch.npu.current_stream().synchronize()
|
||||||
|
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
|
||||||
if is_last_chunk:
|
if is_last_chunk:
|
||||||
self.set_finished_request(req_id)
|
self.set_finished_request(req_id)
|
||||||
self.request_queue.task_done()
|
self.request_queue.task_done()
|
||||||
@@ -126,10 +136,11 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
|||||||
class KVCacheStoreRecvingThread(KVTransferThread):
|
class KVCacheStoreRecvingThread(KVTransferThread):
|
||||||
|
|
||||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||||
tp_rank: int, ready_event: threading.Event):
|
tp_rank: int, dcp_size: int, ready_event: threading.Event):
|
||||||
super().__init__(m_store,
|
super().__init__(m_store,
|
||||||
token_database,
|
token_database,
|
||||||
tp_rank,
|
tp_rank,
|
||||||
|
dcp_size,
|
||||||
ready_event,
|
ready_event,
|
||||||
name="KVCacheStoreRecvingThread")
|
name="KVCacheStoreRecvingThread")
|
||||||
|
|
||||||
@@ -166,11 +177,12 @@ class KVCacheStoreRecvingThread(KVTransferThread):
|
|||||||
class KVCacheStoreLayerSendingThread(KVTransferThread):
|
class KVCacheStoreLayerSendingThread(KVTransferThread):
|
||||||
|
|
||||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||||
tp_rank: int, put_step: int, ready_event: threading.Event,
|
tp_rank: int, dcp_size: int, put_step: int,
|
||||||
num_layers: int):
|
ready_event: threading.Event, num_layers: int):
|
||||||
super().__init__(m_store,
|
super().__init__(m_store,
|
||||||
token_database,
|
token_database,
|
||||||
tp_rank,
|
tp_rank,
|
||||||
|
dcp_size,
|
||||||
ready_event,
|
ready_event,
|
||||||
name="KVCacheStoreLayerSendingThread")
|
name="KVCacheStoreLayerSendingThread")
|
||||||
self.final_layer_id = num_layers - 1
|
self.final_layer_id = num_layers - 1
|
||||||
@@ -192,12 +204,18 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
|
|||||||
key_list.append(key.to_string())
|
key_list.append(key.to_string())
|
||||||
addr_list.append(addr)
|
addr_list.append(addr)
|
||||||
size_list.append(size)
|
size_list.append(size)
|
||||||
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
|
if self.dcp_size > 1:
|
||||||
addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
|
|
||||||
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
|
|
||||||
if key_list_tp:
|
|
||||||
torch.npu.current_stream().synchronize()
|
torch.npu.current_stream().synchronize()
|
||||||
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
|
self.m_store.put(key_list, addr_list, size_list)
|
||||||
|
else:
|
||||||
|
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
|
||||||
|
addr_list_tp = addr_list[self.tp_rank %
|
||||||
|
self.put_step::self.put_step]
|
||||||
|
size_list_tp = size_list[self.tp_rank %
|
||||||
|
self.put_step::self.put_step]
|
||||||
|
if key_list_tp:
|
||||||
|
torch.npu.current_stream().synchronize()
|
||||||
|
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
|
||||||
if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
|
if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
|
||||||
self.set_finished_request(req_meta.req_id)
|
self.set_finished_request(req_meta.req_id)
|
||||||
self.request_queue.task_done()
|
self.request_queue.task_done()
|
||||||
@@ -206,11 +224,12 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
|
|||||||
class KVCacheStoreLayerRecvingThread(KVTransferThread):
|
class KVCacheStoreLayerRecvingThread(KVTransferThread):
|
||||||
|
|
||||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||||
tp_rank: int, ready_event: threading.Event,
|
tp_rank: int, dcp_size: int, ready_event: threading.Event,
|
||||||
get_event: threading.Event):
|
get_event: threading.Event):
|
||||||
super().__init__(m_store,
|
super().__init__(m_store,
|
||||||
token_database,
|
token_database,
|
||||||
tp_rank,
|
tp_rank,
|
||||||
|
dcp_size,
|
||||||
ready_event,
|
ready_event,
|
||||||
name="KVCacheStoreLayerRecvingThread")
|
name="KVCacheStoreLayerRecvingThread")
|
||||||
self.get_event = get_event
|
self.get_event = get_event
|
||||||
|
|||||||
@@ -29,7 +29,14 @@ class KVPoolScheduler:
|
|||||||
"load_async", False)
|
"load_async", False)
|
||||||
# request_id -> (vllm cached tokes, kvpool cached tokens)
|
# request_id -> (vllm cached tokes, kvpool cached tokens)
|
||||||
self.load_specs: dict[str, LoadSpec] = {}
|
self.load_specs: dict[str, LoadSpec] = {}
|
||||||
|
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||||
|
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||||
|
|
||||||
self._block_size = vllm_config.cache_config.block_size
|
self._block_size = vllm_config.cache_config.block_size
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
self._block_size *= self.pcp_size
|
||||||
|
if self.dcp_size > 1:
|
||||||
|
self._block_size *= self.dcp_size
|
||||||
# request_id -> full_token_ids
|
# request_id -> full_token_ids
|
||||||
self._request_trackers: dict[str, RequestTracker] = {}
|
self._request_trackers: dict[str, RequestTracker] = {}
|
||||||
# Whether to discard partial chunks
|
# Whether to discard partial chunks
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
# Standard
|
|
||||||
import math
|
import math
|
||||||
import threading
|
import threading
|
||||||
from typing import Dict, Generator, Optional, Type
|
from typing import Dict, Generator, Optional, Type
|
||||||
|
|
||||||
# Third Party
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed import (get_decode_context_model_parallel_rank,
|
||||||
|
get_decode_context_model_parallel_world_size,
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.utils import logger
|
from vllm.utils import logger
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||||
|
|
||||||
@@ -20,6 +22,14 @@ from vllm_ascend.distributed.kvpool.config_data import (
|
|||||||
from vllm_ascend.distributed.kvpool.kv_transfer import (
|
from vllm_ascend.distributed.kvpool.kv_transfer import (
|
||||||
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
||||||
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
||||||
|
from vllm_ascend.utils import prefill_context_parallel_enable
|
||||||
|
|
||||||
|
if prefill_context_parallel_enable():
|
||||||
|
# isort: off
|
||||||
|
from vllm.distributed import (get_prefill_context_model_parallel_rank,
|
||||||
|
get_prefill_context_model_parallel_world_size
|
||||||
|
)
|
||||||
|
# isort: on
|
||||||
|
|
||||||
backend_map: Dict[str, Type[Backend]] = {
|
backend_map: Dict[str, Type[Backend]] = {
|
||||||
"mooncake": MooncakeBackend,
|
"mooncake": MooncakeBackend,
|
||||||
@@ -44,17 +54,30 @@ class KVPoolWorker:
|
|||||||
and model_config.use_mla):
|
and model_config.use_mla):
|
||||||
self.use_mla = True
|
self.use_mla = True
|
||||||
self.use_layerwise = use_layerwize
|
self.use_layerwise = use_layerwize
|
||||||
self.tp_rank = parallel_config.rank
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.tp_size = parallel_config.tensor_parallel_size
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||||
|
) if prefill_context_parallel_enable() else 1
|
||||||
|
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||||
|
) if self.pcp_size > 1 else 0
|
||||||
|
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||||
|
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||||
|
) if self.dcp_size > 1 else 0
|
||||||
|
|
||||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||||
"load_async", False)
|
"load_async", False)
|
||||||
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||||
"backend", "mooncake")
|
"backend", "mooncake")
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
|
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
self.block_size *= self.pcp_size
|
||||||
|
if self.dcp_size > 1:
|
||||||
|
self.block_size *= self.dcp_size
|
||||||
self.current_layer = 0
|
self.current_layer = 0
|
||||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
|
||||||
|
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
self.num_kv_head = 1
|
self.num_kv_head = 1
|
||||||
@@ -69,8 +92,10 @@ class KVPoolWorker:
|
|||||||
self.put_step = 1
|
self.put_step = 1
|
||||||
|
|
||||||
self.metadata = KeyMetadata(
|
self.metadata = KeyMetadata(
|
||||||
model_config.model,
|
model_config.model.split('/')[-1],
|
||||||
self.head_or_tp_rank,
|
self.head_or_tp_rank,
|
||||||
|
self.pcp_rank,
|
||||||
|
self.dcp_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.token_database = ChunkedTokenDatabase(self.metadata,
|
self.token_database = ChunkedTokenDatabase(self.metadata,
|
||||||
@@ -147,12 +172,13 @@ class KVPoolWorker:
|
|||||||
ready_event_sending = threading.Event()
|
ready_event_sending = threading.Event()
|
||||||
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
||||||
self.m_store, self.token_database, self.tp_rank,
|
self.m_store, self.token_database, self.tp_rank,
|
||||||
self.put_step, ready_event_sending, self.num_layers)
|
self.dcp_size, self.put_step, ready_event_sending,
|
||||||
|
self.num_layers)
|
||||||
self.kv_send_thread.start()
|
self.kv_send_thread.start()
|
||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
||||||
self.m_store, self.token_database, self.tp_rank, ready_event,
|
self.m_store, self.token_database, self.tp_rank, self.dcp_size,
|
||||||
self.get_event)
|
ready_event, self.get_event)
|
||||||
self.kv_recv_thread.start()
|
self.kv_recv_thread.start()
|
||||||
ready_event.wait()
|
ready_event.wait()
|
||||||
else:
|
else:
|
||||||
@@ -160,13 +186,13 @@ class KVPoolWorker:
|
|||||||
ready_event_sending = threading.Event()
|
ready_event_sending = threading.Event()
|
||||||
self.kv_send_thread = KVCacheStoreSendingThread(
|
self.kv_send_thread = KVCacheStoreSendingThread(
|
||||||
self.m_store, self.token_database, self.tp_rank,
|
self.m_store, self.token_database, self.tp_rank,
|
||||||
self.put_step, ready_event_sending)
|
self.dcp_size, self.put_step, ready_event_sending)
|
||||||
self.kv_send_thread.start()
|
self.kv_send_thread.start()
|
||||||
if self.load_async:
|
if self.load_async:
|
||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
||||||
self.m_store, self.token_database, self.tp_rank,
|
self.m_store, self.token_database, self.tp_rank,
|
||||||
ready_event)
|
self.dcp_size, ready_event)
|
||||||
self.kv_recv_thread.start()
|
self.kv_recv_thread.start()
|
||||||
ready_event.wait()
|
ready_event.wait()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user