[KVPOOL]decode save kvcache (#5168)
### What this PR does / why we need it?
kvpool decode save kvcache
now only support mla
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Chao Lei <leichao139636@163.com>
This commit is contained in:
@@ -34,6 +34,8 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
||||
|
||||
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"use_layerwise", False)
|
||||
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_put", False)
|
||||
|
||||
connector_name = vllm_config.kv_transfer_config.kv_connector
|
||||
if connector_name == "MooncakeConnectorStoreV1":
|
||||
@@ -121,7 +123,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
||||
self.connector_worker.save_kv_layer(self._get_connector_metadata())
|
||||
|
||||
def wait_for_save(self):
|
||||
if self.kv_role == "kv_consumer":
|
||||
if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
|
||||
@@ -135,7 +137,8 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
meta = self._get_connector_metadata()
|
||||
done_sending, done_recving = self.connector_worker.get_finished()
|
||||
done_sending, done_recving = self.connector_worker.get_finished(
|
||||
finished_req_ids)
|
||||
sended_and_finished: set[str] = set()
|
||||
for item in list(self.sended_but_unfinished_reqs):
|
||||
if item not in meta.unfinished_request_ids:
|
||||
|
||||
@@ -87,12 +87,14 @@ class LayerPoolKey(PoolKey):
|
||||
|
||||
class ChunkedTokenDatabase():
|
||||
|
||||
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool):
|
||||
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
|
||||
partitions: Optional[List[int]]):
|
||||
self.metadata = metadata
|
||||
self.block_size = block_size
|
||||
self.use_mla = use_mla
|
||||
self.kv_caches_base_addr: list[int] = []
|
||||
self.block_len: list[int] = []
|
||||
self.partitions = partitions
|
||||
|
||||
def _make_key_by_hash(self,
|
||||
chunk_hash: str,
|
||||
@@ -188,6 +190,28 @@ class ChunkedTokenDatabase():
|
||||
else:
|
||||
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
|
||||
|
||||
def decode_adaptor_prefill_pp(self, key, addr, size):
|
||||
if self.partitions is None or len(self.partitions) == 1:
|
||||
return key, addr, size
|
||||
|
||||
new_key = []
|
||||
new_addr = []
|
||||
new_size = []
|
||||
|
||||
for i, (addr_list, size_list) in enumerate(zip(addr, size)):
|
||||
start = 0
|
||||
for j, part in enumerate(self.partitions):
|
||||
# part * 2 because addr and size contain both k and v
|
||||
end = len(addr_list) if j == len(
|
||||
self.partitions) - 1 else start + part * 2
|
||||
new_str = key[i].replace( # type: ignore[attr-defined]
|
||||
"@pp_rank:0", f"@pp_rank:{j}", 1)
|
||||
new_key.append(new_str)
|
||||
new_addr.append(addr_list[start:end])
|
||||
new_size.append(size_list[start:end])
|
||||
start = end
|
||||
return new_key, new_addr, new_size
|
||||
|
||||
|
||||
#Parameters related to the connector metadata
|
||||
@dataclass
|
||||
@@ -247,15 +271,12 @@ class RequestTracker:
|
||||
|
||||
def update(
|
||||
self,
|
||||
new_token_ids: list[int],
|
||||
new_block_ids: Union[tuple[list[int], ...], list[int]],
|
||||
) -> None:
|
||||
"""Update the request tracker when a running request is
|
||||
scheduled again
|
||||
"""
|
||||
|
||||
self.token_len = self.token_len + len(new_token_ids)
|
||||
|
||||
if len(new_block_ids) == 0:
|
||||
new_block_ids = []
|
||||
elif isinstance(new_block_ids, tuple):
|
||||
@@ -378,4 +399,4 @@ class LasyerMultiBlockReqMeta:
|
||||
block_ids: list[int]
|
||||
layer_id: int
|
||||
is_last_chunk: Optional[bool] = True
|
||||
current_event: Optional[torch.npu.Event] = None
|
||||
current_event: Optional[torch.npu.Event] = None
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import queue
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
@@ -99,7 +100,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
|
||||
ready_event: threading.Event):
|
||||
kv_role: str, ready_event: threading.Event):
|
||||
super().__init__(m_store,
|
||||
token_database,
|
||||
block_size,
|
||||
@@ -108,6 +109,17 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
ready_event,
|
||||
name="KVCacheSendingThread")
|
||||
self.put_step = put_step
|
||||
self.kv_role = kv_role
|
||||
self.stored_requests = defaultdict[str, int](int)
|
||||
|
||||
def add_stored_request(self, req_id: str):
|
||||
with self.done_task_lock:
|
||||
self.stored_requests[req_id] += 1
|
||||
|
||||
def delete_finished_stored_request(self, req_id: str):
|
||||
with self.done_task_lock:
|
||||
if req_id in self.stored_requests:
|
||||
del self.stored_requests[req_id]
|
||||
|
||||
def _handle_request(self, req_meta: ReqMeta):
|
||||
token_len = req_meta.token_len_chunk
|
||||
@@ -154,13 +166,6 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
req_id,
|
||||
)
|
||||
|
||||
addrs = []
|
||||
sizes = []
|
||||
for index, start in enumerate(starts):
|
||||
addr, size, _ = self.token_database.prepare_value(
|
||||
start, ends[index], block_ids)
|
||||
addrs.append(addr)
|
||||
sizes.append(size)
|
||||
if keys:
|
||||
"""
|
||||
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
|
||||
@@ -168,12 +173,24 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
|
||||
to resolve this issue before the 8.5.RC1 release.
|
||||
"""
|
||||
addrs = []
|
||||
sizes = []
|
||||
for index, start in enumerate(starts):
|
||||
addr, size, _ = self.token_database.prepare_value(
|
||||
start, ends[index], block_ids)
|
||||
addrs.append(addr)
|
||||
sizes.append(size)
|
||||
|
||||
if self.kv_role == "kv_consumer":
|
||||
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(
|
||||
keys, addrs, sizes)
|
||||
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
self.m_store.put(keys, addrs, sizes)
|
||||
|
||||
if is_last_chunk:
|
||||
self.set_finished_request(req_id)
|
||||
with self.done_task_lock:
|
||||
self.stored_requests[req_id] -= 1
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ class KVPoolScheduler:
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_load", False)
|
||||
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_put", False)
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
self.client = LookupKeyClient(vllm_config)
|
||||
@@ -149,7 +151,8 @@ class KVPoolScheduler:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
|
||||
force_skip_save = self.kv_role == "kv_consumer"
|
||||
force_skip_save = (self.kv_role == "kv_consumer"
|
||||
and not self.consumer_is_to_put)
|
||||
|
||||
for finished_req_id in scheduler_output.finished_req_ids:
|
||||
self._request_trackers.pop(finished_req_id, None)
|
||||
@@ -197,6 +200,7 @@ class KVPoolScheduler:
|
||||
num_current_tokens = request_tracker.token_len
|
||||
new_token_ids = request.all_token_ids[
|
||||
num_current_tokens:num_current_tokens + num_new_tokens]
|
||||
request_tracker.token_len += len(new_token_ids)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Request {req_id} is not in _unfinished_requests, "
|
||||
@@ -204,10 +208,7 @@ class KVPoolScheduler:
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
if not new_block_ids:
|
||||
continue
|
||||
request_tracker.update(new_token_ids, new_block_ids)
|
||||
# decode not save
|
||||
if request_tracker.token_len > len(request.prompt_token_ids):
|
||||
continue
|
||||
request_tracker.update(new_block_ids)
|
||||
|
||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
@@ -270,7 +271,7 @@ class KVPoolScheduler:
|
||||
Once a request is finished, determine whether request blocks
|
||||
should be freed now or will be sent asynchronously and freed later.
|
||||
"""
|
||||
if self.kv_role == "kv_consumer":
|
||||
if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
|
||||
return False, None
|
||||
tracker = self._request_trackers.get(request.request_id)
|
||||
if tracker is not None and tracker.num_saved_tokens <= 0:
|
||||
|
||||
@@ -61,6 +61,8 @@ class KVPoolWorker:
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_put", False)
|
||||
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"backend", "mooncake")
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
@@ -92,9 +94,44 @@ class KVPoolWorker:
|
||||
self.pp_rank,
|
||||
)
|
||||
|
||||
partitions = None
|
||||
if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
|
||||
num_hidden_layers = model_config.hf_config.num_hidden_layers
|
||||
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"prefill_pp_layer_partition", None)
|
||||
prefill_pp_size = int(
|
||||
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"prefill_pp_size", 1))
|
||||
|
||||
if partition_list_str is not None:
|
||||
try:
|
||||
partitions = [
|
||||
int(layer) for layer in partition_list_str.split(",")
|
||||
]
|
||||
except ValueError as err:
|
||||
raise ValueError("Invalid partition string: {}".format(
|
||||
partition_list_str)) from err
|
||||
if len(partitions) != prefill_pp_size:
|
||||
raise ValueError(
|
||||
f"{len(partitions)=} does not match {prefill_pp_size=}."
|
||||
)
|
||||
if sum(partitions) != num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"{sum(partitions)=} does not match {num_hidden_layers=}."
|
||||
)
|
||||
else:
|
||||
layers_per_partition = num_hidden_layers // prefill_pp_size
|
||||
partitions = [
|
||||
layers_per_partition for _ in range(prefill_pp_size)
|
||||
]
|
||||
|
||||
if remaining_layers := num_hidden_layers % prefill_pp_size:
|
||||
for i in range(2, remaining_layers + 2):
|
||||
partitions[-i] += 1
|
||||
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata,
|
||||
self.block_size,
|
||||
self.use_mla)
|
||||
self.use_mla, partitions)
|
||||
|
||||
real_backend = backend_map.get(self.backend.lower())
|
||||
self.m_store = real_backend( # type: ignore[misc]
|
||||
@@ -103,6 +140,8 @@ class KVPoolWorker:
|
||||
self.kv_send_thread: Optional[KVTransferThread] = None
|
||||
self.kv_recv_thread: Optional[KVTransferThread] = None
|
||||
|
||||
self.finished_store_req: set[str] = set()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
@@ -176,11 +215,12 @@ class KVPoolWorker:
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
else:
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
if self.kv_role in ['kv_producer', 'kv_both'
|
||||
] or self.consumer_is_to_put:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreSendingThread(
|
||||
self.m_store, self.token_database, self.block_size,
|
||||
self.tp_rank, self.dcp_size, self.put_step,
|
||||
self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
|
||||
ready_event_sending)
|
||||
self.kv_send_thread.start()
|
||||
if self.load_async:
|
||||
@@ -289,6 +329,8 @@ class KVPoolWorker:
|
||||
continue
|
||||
|
||||
request.current_event = current_event
|
||||
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
|
||||
request.req_id)
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr]
|
||||
request, )
|
||||
|
||||
@@ -413,11 +455,13 @@ class KVPoolWorker:
|
||||
for layer_id in range(self.num_layers):
|
||||
yield
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
done_sending = (
|
||||
self.kv_send_thread.
|
||||
get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
|
||||
self.get_and_clear_finished_requests(
|
||||
finished_req_ids # type: ignore[union-attr]
|
||||
) if self.kv_role in ['kv_producer', 'kv_both']
|
||||
or self.consumer_is_to_put else set())
|
||||
|
||||
done_recving = (
|
||||
self.kv_recv_thread.
|
||||
@@ -430,6 +474,29 @@ class KVPoolWorker:
|
||||
self.tp_rank)
|
||||
return done_sending, done_recving
|
||||
|
||||
def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]:
|
||||
finished_sending = set()
|
||||
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
|
||||
):
|
||||
if self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
|
||||
req_id] == 0 and req_id in self.finished_store_req:
|
||||
self.finished_store_req.remove(req_id)
|
||||
finished_sending.add(req_id)
|
||||
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
|
||||
for req_id in finished_req_ids:
|
||||
req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr]
|
||||
req_id)
|
||||
if req_remain_jobs == 0:
|
||||
finished_sending.add(req_id)
|
||||
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
elif req_remain_jobs is not None:
|
||||
self.finished_store_req.add(req_id)
|
||||
|
||||
return finished_sending
|
||||
|
||||
def lookup(
|
||||
self,
|
||||
token_len: int,
|
||||
|
||||
Reference in New Issue
Block a user