[KVPOOL] Support saving KV cache to the KV pool during decode (#5168)
### What this PR does / why we need it?
Adds support for saving the KV cache to the KV pool during the decode phase.
Currently only MLA models are supported.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Chao Lei <leichao139636@163.com>
@@ -61,6 +61,8 @@ class KVPoolWorker:
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "load_async", False)
+        self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+            "consumer_is_to_put", False)
         self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "backend", "mooncake")
         self.block_size = vllm_config.cache_config.block_size
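The options read above all live in `kv_connector_extra_config`. A hypothetical sketch of a config enabling the new behavior (only the keys and defaults come from this diff; the concrete values and surrounding structure are illustrative assumptions, not a documented interface):

```python
# Hypothetical extra config consumed by KVPoolWorker. Keys and defaults
# match what the code above reads; values are illustrative assumptions.
kv_connector_extra_config = {
    "load_async": False,          # existing: load KV cache asynchronously
    "consumer_is_to_put": True,   # new: allow a kv_consumer to put (save) KV cache
    "backend": "mooncake",        # existing: KV pool backend (default "mooncake")
    # Read later in __init__ when consumer_is_to_put is set (see next hunk):
    "prefill_pp_size": 4,                     # prefill-side pipeline-parallel size
    "prefill_pp_layer_partition": "7,8,8,7",  # optional explicit per-stage layer split
}
```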
@@ -92,9 +94,44 @@ class KVPoolWorker:
             self.pp_rank,
         )
 
+        partitions = None
+        if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
+            num_hidden_layers = model_config.hf_config.num_hidden_layers
+            partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+                "prefill_pp_layer_partition", None)
+            prefill_pp_size = int(
+                vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+                    "prefill_pp_size", 1))
+
+            if partition_list_str is not None:
+                try:
+                    partitions = [
+                        int(layer) for layer in partition_list_str.split(",")
+                    ]
+                except ValueError as err:
+                    raise ValueError("Invalid partition string: {}".format(
+                        partition_list_str)) from err
+                if len(partitions) != prefill_pp_size:
+                    raise ValueError(
+                        f"{len(partitions)=} does not match {prefill_pp_size=}."
+                    )
+                if sum(partitions) != num_hidden_layers:
+                    raise ValueError(
+                        f"{sum(partitions)=} does not match {num_hidden_layers=}."
+                    )
+            else:
+                layers_per_partition = num_hidden_layers // prefill_pp_size
+                partitions = [
+                    layers_per_partition for _ in range(prefill_pp_size)
+                ]
+
+                if remaining_layers := num_hidden_layers % prefill_pp_size:
+                    for i in range(2, remaining_layers + 2):
+                        partitions[-i] += 1
+
         self.token_database = ChunkedTokenDatabase(self.metadata,
                                                    self.block_size,
-                                                   self.use_mla)
+                                                   self.use_mla, partitions)
 
         real_backend = backend_map.get(self.backend.lower())
         self.m_store = real_backend(  # type: ignore[misc]
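When no explicit `prefill_pp_layer_partition` is given, the fallback above splits the layers evenly and hands each leftover layer to the stages just before the last one. A minimal standalone sketch of that default (the function name is mine), with a worked example:

```python
def default_layer_partition(num_hidden_layers: int,
                            prefill_pp_size: int) -> list[int]:
    """Replicates the fallback partitioning above: an even split, with each
    remaining layer given to the second-to-last stage, then the third-to-last,
    and so on; the last stage keeps the base count."""
    base = num_hidden_layers // prefill_pp_size
    partitions = [base] * prefill_pp_size
    if remaining := num_hidden_layers % prefill_pp_size:
        for i in range(2, remaining + 2):
            partitions[-i] += 1  # skips partitions[-1], starts at partitions[-2]
    return partitions

# 30 layers over 4 prefill stages: base 7 each, remainder 2 lands on the
# third- and second-to-last stages.
assert default_layer_partition(30, 4) == [7, 8, 8, 7]
assert sum(default_layer_partition(61, 4)) == 61
```

Why the last stage is deliberately skipped when distributing the remainder is not stated in the diff; presumably it mirrors how the prefill side balances its own pipeline stages.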
@@ -103,6 +140,8 @@ class KVPoolWorker:
         self.kv_send_thread: Optional[KVTransferThread] = None
         self.kv_recv_thread: Optional[KVTransferThread] = None
 
+        self.finished_store_req: set[str] = set()
+
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         _, first_kv_cache_tuple = next(iter(kv_caches.items()))
         first_kv_cache = first_kv_cache_tuple[0]
@@ -176,11 +215,12 @@ class KVPoolWorker:
                 self.kv_recv_thread.start()
                 ready_event.wait()
         else:
-            if self.kv_role in ['kv_producer', 'kv_both']:
+            if self.kv_role in ['kv_producer', 'kv_both'
+                                ] or self.consumer_is_to_put:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreSendingThread(
                     self.m_store, self.token_database, self.block_size,
-                    self.tp_rank, self.dcp_size, self.put_step,
+                    self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
                     ready_event_sending)
                 self.kv_send_thread.start()
             if self.load_async:
@@ -289,6 +329,8 @@ class KVPoolWorker:
                 continue
 
             request.current_event = current_event
+            self.kv_send_thread.add_stored_request(  # type: ignore[union-attr]
+                request.req_id)
             self.kv_send_thread.add_request(  # type: ignore[union-attr]
                 request, )
 
@@ -413,11 +455,13 @@ class KVPoolWorker:
         for layer_id in range(self.num_layers):
             yield
 
-    def get_finished(self) -> tuple[set[str], set[str]]:
+    def get_finished(self,
+                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
         done_sending = (
-            self.kv_send_thread.
-            get_and_clear_finished_requests(  # type: ignore[union-attr]
-            ) if self.kv_role in ['kv_producer', 'kv_both'] else set())
+            self.get_and_clear_finished_requests(
+                finished_req_ids  # type: ignore[union-attr]
+            ) if self.kv_role in ['kv_producer', 'kv_both']
+            or self.consumer_is_to_put else set())
 
         done_recving = (
             self.kv_recv_thread.
@@ -430,6 +474,29 @@ class KVPoolWorker:
                 self.tp_rank)
         return done_sending, done_recving
 
+    def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]:
+        finished_sending = set()
+        for req_id in self.kv_send_thread.stored_requests.copy(  # type: ignore[union-attr]
+        ):
+            if self.kv_send_thread.stored_requests[  # type: ignore[union-attr]
+                    req_id] == 0 and req_id in self.finished_store_req:
+                self.finished_store_req.remove(req_id)
+                finished_sending.add(req_id)
+                self.kv_send_thread.delete_finished_stored_request(  # type: ignore[union-attr]
+                    req_id)
+
+        for req_id in finished_req_ids:
+            req_remain_jobs = self.kv_send_thread.stored_requests.get(  # type: ignore[union-attr]
+                req_id)
+            if req_remain_jobs == 0:
+                finished_sending.add(req_id)
+                self.kv_send_thread.delete_finished_stored_request(  # type: ignore[union-attr]
+                    req_id)
+            elif req_remain_jobs is not None:
+                self.finished_store_req.add(req_id)
+
+        return finished_sending
+
     def lookup(
         self,
         token_len: int,
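The new `get_and_clear_finished_requests` combines two signals before reporting a request as done sending: the per-request count of outstanding store jobs kept by the sending thread (`stored_requests`) and the scheduler's `finished_req_ids`. A simplified, self-contained model of that handshake (the dict stand-ins for thread state are assumptions; only the two-pass logic mirrors the diff):

```python
# Simplified stand-in for KVCacheStoreSendingThread state: req_id -> number
# of outstanding store jobs. In the real code this lives on the thread and
# is decremented as puts complete.
stored_requests: dict[str, int] = {"req-a": 0, "req-b": 2}
finished_store_req: set[str] = set()  # finished by scheduler, jobs still pending

def get_and_clear_finished_requests(finished_req_ids: set[str]) -> set[str]:
    finished_sending: set[str] = set()
    # Pass 1: previously deferred requests whose store jobs have since drained.
    for req_id in list(stored_requests):
        if stored_requests[req_id] == 0 and req_id in finished_store_req:
            finished_store_req.remove(req_id)
            finished_sending.add(req_id)
            del stored_requests[req_id]  # models delete_finished_stored_request
    # Pass 2: requests the scheduler just finished; defer any that still
    # have store jobs in flight.
    for req_id in finished_req_ids:
        remaining = stored_requests.get(req_id)
        if remaining == 0:
            finished_sending.add(req_id)
            del stored_requests[req_id]
        elif remaining is not None:
            finished_store_req.add(req_id)
    return finished_sending

assert get_and_clear_finished_requests({"req-a", "req-b"}) == {"req-a"}
stored_requests["req-b"] = 0  # store jobs drain some time later
assert get_and_clear_finished_requests(set()) == {"req-b"}
```

A request that is finished by the scheduler while stores are still in flight is parked in `finished_store_req` and reported on a later call, so the KV pool never loses blocks that are mid-transfer.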