[feature]Pooling Features and PCP Adaptation (#4143)
This PR let pooling kv connector support pcp feature - vLLM version: v0.11.2 --------- Signed-off-by: fjw <2270923832@qq.com> Signed-off-by: SlightwindSec <slightwindsec@gmail.com> Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
# Standard
|
||||
import math
|
||||
import threading
|
||||
from typing import Dict, Generator, Optional, Type
|
||||
|
||||
# Third Party
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (get_decode_context_model_parallel_rank,
|
||||
get_decode_context_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.utils import logger
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
@@ -20,6 +22,14 @@ from vllm_ascend.distributed.kvpool.config_data import (
|
||||
from vllm_ascend.distributed.kvpool.kv_transfer import (
|
||||
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
||||
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
||||
from vllm_ascend.utils import prefill_context_parallel_enable
|
||||
|
||||
if prefill_context_parallel_enable():
|
||||
# isort: off
|
||||
from vllm.distributed import (get_prefill_context_model_parallel_rank,
|
||||
get_prefill_context_model_parallel_world_size
|
||||
)
|
||||
# isort: on
|
||||
|
||||
backend_map: Dict[str, Type[Backend]] = {
|
||||
"mooncake": MooncakeBackend,
|
||||
@@ -44,17 +54,30 @@ class KVPoolWorker:
|
||||
and model_config.use_mla):
|
||||
self.use_mla = True
|
||||
self.use_layerwise = use_layerwize
|
||||
self.tp_rank = parallel_config.rank
|
||||
self.tp_size = parallel_config.tensor_parallel_size
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||
) if prefill_context_parallel_enable() else 1
|
||||
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||
) if self.pcp_size > 1 else 0
|
||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||
) if self.dcp_size > 1 else 0
|
||||
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"backend", "mooncake")
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
if self.pcp_size > 1:
|
||||
self.block_size *= self.pcp_size
|
||||
if self.dcp_size > 1:
|
||||
self.block_size *= self.dcp_size
|
||||
self.current_layer = 0
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
if self.use_mla:
|
||||
self.num_kv_head = 1
|
||||
@@ -69,8 +92,10 @@ class KVPoolWorker:
|
||||
self.put_step = 1
|
||||
|
||||
self.metadata = KeyMetadata(
|
||||
model_config.model,
|
||||
model_config.model.split('/')[-1],
|
||||
self.head_or_tp_rank,
|
||||
self.pcp_rank,
|
||||
self.dcp_rank,
|
||||
)
|
||||
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata,
|
||||
@@ -147,12 +172,13 @@ class KVPoolWorker:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
||||
self.m_store, self.token_database, self.tp_rank,
|
||||
self.put_step, ready_event_sending, self.num_layers)
|
||||
self.dcp_size, self.put_step, ready_event_sending,
|
||||
self.num_layers)
|
||||
self.kv_send_thread.start()
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
||||
self.m_store, self.token_database, self.tp_rank, ready_event,
|
||||
self.get_event)
|
||||
self.m_store, self.token_database, self.tp_rank, self.dcp_size,
|
||||
ready_event, self.get_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
else:
|
||||
@@ -160,13 +186,13 @@ class KVPoolWorker:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreSendingThread(
|
||||
self.m_store, self.token_database, self.tp_rank,
|
||||
self.put_step, ready_event_sending)
|
||||
self.dcp_size, self.put_step, ready_event_sending)
|
||||
self.kv_send_thread.start()
|
||||
if self.load_async:
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
||||
self.m_store, self.token_database, self.tp_rank,
|
||||
ready_event)
|
||||
self.dcp_size, ready_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user