[KVPOOL]decode save kvcache (#5168)
### What this PR does / why we need it?
Adds support for the Decode node saving its KV cache into the KV pool (kvpool decode save kvcache).
Currently only MLA models are supported.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Chao Lei <leichao139636@163.com>
@@ -257,7 +257,23 @@ python3 -m vllm.entrypoints.openai.api_server \
 }'
 ```
 
-#### 2.Start proxy_server.
+Currently, the key-value pool in PD disaggregation stores only the KV cache generated by the Prefill node by default. For models using MLA, the Decode node can now also store its KV cache for reuse by the Prefill node; enable this by adding `"consumer_is_to_put": true` to the AscendStoreConnector configuration. If the Prefill node uses pipeline parallelism, `prefill_pp_size` or `prefill_pp_layer_partition` must also be set. For example:
+
+```
+{
+    "kv_connector": "AscendStoreConnector",
+    "kv_role": "kv_consumer",
+    "kv_connector_extra_config": {
+        "lookup_rpc_port": "0",
+        "backend": "mooncake",
+        "consumer_is_to_put": true,
+        "prefill_pp_size": 2,
+        "prefill_pp_layer_partition": "30,31"
+    }
+}
+```
+
+#### 2. Start proxy_server.
 
 ```
 python vllm-ascend/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py \
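For readers of the example above, here is a minimal standalone sketch (the helper name `derive_prefill_partitions` is hypothetical, and the 61-layer model is just an illustration) of the partition rules this PR enforces in the worker: an explicit `prefill_pp_layer_partition` must have exactly `prefill_pp_size` entries and sum to the model's `num_hidden_layers`; without it, layers are split evenly with any remainder given to the stages just before the last one.

```python
from typing import List, Optional

def derive_prefill_partitions(num_hidden_layers: int,
                              prefill_pp_size: int,
                              partition_list_str: Optional[str]) -> List[int]:
    """Sketch of the partition rules enforced by this PR's worker code."""
    if partition_list_str is not None:
        # Explicit list: must match prefill_pp_size and cover all layers.
        partitions = [int(layer) for layer in partition_list_str.split(",")]
        assert len(partitions) == prefill_pp_size
        assert sum(partitions) == num_hidden_layers
        return partitions
    # Even split; leftover layers go to the stages just before the last,
    # mirroring `for i in range(2, remaining_layers + 2)` in the diff.
    partitions = [num_hidden_layers // prefill_pp_size] * prefill_pp_size
    if remaining := num_hidden_layers % prefill_pp_size:
        for i in range(2, remaining + 2):
            partitions[-i] += 1
    return partitions

print(derive_prefill_partitions(61, 2, "30,31"))  # [30, 31]
print(derive_prefill_partitions(61, 2, None))     # [31, 30]
```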
@@ -34,6 +34,8 @@ class AscendStoreConnector(KVConnectorBase_V1):
 
         self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "use_layerwise", False)
+        self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+            "consumer_is_to_put", False)
 
         connector_name = vllm_config.kv_transfer_config.kv_connector
         if connector_name == "MooncakeConnectorStoreV1":
@@ -121,7 +123,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
             self.connector_worker.save_kv_layer(self._get_connector_metadata())
 
     def wait_for_save(self):
-        if self.kv_role == "kv_consumer":
+        if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
             # Don't do save if the role is kv_consumer
             return
 
@@ -135,7 +137,8 @@ class AscendStoreConnector(KVConnectorBase_V1):
         """Get the finished recving and sending requests."""
         assert self.connector_worker is not None
         meta = self._get_connector_metadata()
-        done_sending, done_recving = self.connector_worker.get_finished()
+        done_sending, done_recving = self.connector_worker.get_finished(
+            finished_req_ids)
         sended_and_finished: set[str] = set()
         for item in list(self.sended_but_unfinished_reqs):
             if item not in meta.unfinished_request_ids:
@@ -87,12 +87,14 @@ class LayerPoolKey(PoolKey):
 
 class ChunkedTokenDatabase():
 
-    def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool):
+    def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
+                 partitions: Optional[List[int]]):
         self.metadata = metadata
         self.block_size = block_size
         self.use_mla = use_mla
         self.kv_caches_base_addr: list[int] = []
         self.block_len: list[int] = []
+        self.partitions = partitions
 
     def _make_key_by_hash(self,
                           chunk_hash: str,
@@ -188,6 +190,28 @@ class ChunkedTokenDatabase():
         else:
             yield start_idx, end_idx, self._make_key_by_hash(hash_val)
 
+    def decode_adaptor_prefill_pp(self, key, addr, size):
+        if self.partitions is None or len(self.partitions) == 1:
+            return key, addr, size
+
+        new_key = []
+        new_addr = []
+        new_size = []
+
+        for i, (addr_list, size_list) in enumerate(zip(addr, size)):
+            start = 0
+            for j, part in enumerate(self.partitions):
+                # part * 2 because addr and size contain both k and v
+                end = len(addr_list) if j == len(
+                    self.partitions) - 1 else start + part * 2
+                new_str = key[i].replace(  # type: ignore[attr-defined]
+                    "@pp_rank:0", f"@pp_rank:{j}", 1)
+                new_key.append(new_str)
+                new_addr.append(addr_list[start:end])
+                new_size.append(size_list[start:end])
+                start = end
+        return new_key, new_addr, new_size
+
 
 #Parameters related to the connector metadata
 @dataclass
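A toy walk-through of the new `decode_adaptor_prefill_pp` above, with made-up keys, addresses, and sizes: each Prefill PP stage holding `part` layers consumes `part * 2` entries of the flat per-request address/size lists (one pair of k and v entries per layer), and the `@pp_rank:` tag in each key is rewritten per stage. This is a self-contained illustration of the slicing, not the connector's actual data.

```python
partitions = [2, 1]          # Prefill PP stages hold 2 and 1 layers
key = ["req0@pp_rank:0"]     # one key per request, tagged with pp_rank 0
addr = [[0, 1, 2, 3, 4, 5]]  # 2 entries (k and v) per layer -> 3 layers
size = [[10, 10, 20, 20, 30, 30]]

new_key, new_addr, new_size = [], [], []
for i, (addr_list, size_list) in enumerate(zip(addr, size)):
    start = 0
    for j, part in enumerate(partitions):
        # part * 2 because addr and size carry both k and v per layer
        end = len(addr_list) if j == len(partitions) - 1 else start + part * 2
        new_key.append(key[i].replace("@pp_rank:0", f"@pp_rank:{j}", 1))
        new_addr.append(addr_list[start:end])
        new_size.append(size_list[start:end])
        start = end

print(new_key)   # ['req0@pp_rank:0', 'req0@pp_rank:1']
print(new_addr)  # [[0, 1, 2, 3], [4, 5]]
```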
@@ -247,15 +271,12 @@ class RequestTracker:
 
     def update(
         self,
-        new_token_ids: list[int],
         new_block_ids: Union[tuple[list[int], ...], list[int]],
     ) -> None:
         """Update the request tracker when a running request is
         scheduled again
         """
 
-        self.token_len = self.token_len + len(new_token_ids)
-
         if len(new_block_ids) == 0:
             new_block_ids = []
         elif isinstance(new_block_ids, tuple):
@@ -378,4 +399,4 @@ class LasyerMultiBlockReqMeta:
     block_ids: list[int]
     layer_id: int
     is_last_chunk: Optional[bool] = True
     current_event: Optional[torch.npu.Event] = None
@@ -1,5 +1,6 @@
 import queue
 import threading
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
 
@@ -99,7 +100,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
 
     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
                  block_size: int, tp_rank: int, dcp_size: int, put_step: int,
-                 ready_event: threading.Event):
+                 kv_role: str, ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          block_size,
@@ -108,6 +109,17 @@ class KVCacheStoreSendingThread(KVTransferThread):
                          ready_event,
                          name="KVCacheSendingThread")
         self.put_step = put_step
+        self.kv_role = kv_role
+        self.stored_requests = defaultdict[str, int](int)
+
+    def add_stored_request(self, req_id: str):
+        with self.done_task_lock:
+            self.stored_requests[req_id] += 1
+
+    def delete_finished_stored_request(self, req_id: str):
+        with self.done_task_lock:
+            if req_id in self.stored_requests:
+                del self.stored_requests[req_id]
 
     def _handle_request(self, req_meta: ReqMeta):
         token_len = req_meta.token_len_chunk
@@ -154,13 +166,6 @@ class KVCacheStoreSendingThread(KVTransferThread):
             req_id,
         )
 
-        addrs = []
-        sizes = []
-        for index, start in enumerate(starts):
-            addr, size, _ = self.token_database.prepare_value(
-                start, ends[index], block_ids)
-            addrs.append(addr)
-            sizes.append(size)
         if keys:
             """
             Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
@@ -168,12 +173,24 @@ class KVCacheStoreSendingThread(KVTransferThread):
             You can manually build the master branch of the project at https://gitcode.com/cann/hixl
             to resolve this issue before the 8.5.RC1 release.
             """
+            addrs = []
+            sizes = []
+            for index, start in enumerate(starts):
+                addr, size, _ = self.token_database.prepare_value(
+                    start, ends[index], block_ids)
+                addrs.append(addr)
+                sizes.append(size)
+
+            if self.kv_role == "kv_consumer":
+                keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(
+                    keys, addrs, sizes)
+
             if current_event is not None:
                 current_event.synchronize()
             self.m_store.put(keys, addrs, sizes)
 
-        if is_last_chunk:
-            self.set_finished_request(req_id)
+        with self.done_task_lock:
+            self.stored_requests[req_id] -= 1
         self.request_queue.task_done()
 
@@ -24,6 +24,8 @@ class KVPoolScheduler:
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "consumer_is_to_load", False)
+        self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+            "consumer_is_to_put", False)
         self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "load_async", False)
         self.client = LookupKeyClient(vllm_config)
@@ -149,7 +151,8 @@ class KVPoolScheduler:
             scheduler_output (SchedulerOutput): the scheduler output object.
         """
 
-        force_skip_save = self.kv_role == "kv_consumer"
+        force_skip_save = (self.kv_role == "kv_consumer"
+                           and not self.consumer_is_to_put)
 
         for finished_req_id in scheduler_output.finished_req_ids:
             self._request_trackers.pop(finished_req_id, None)
@@ -197,6 +200,7 @@ class KVPoolScheduler:
                 num_current_tokens = request_tracker.token_len
                 new_token_ids = request.all_token_ids[
                     num_current_tokens:num_current_tokens + num_new_tokens]
+                request_tracker.token_len += len(new_token_ids)
             else:
                 raise ValueError(
                     f"Request {req_id} is not in _unfinished_requests, "
@@ -204,10 +208,7 @@ class KVPoolScheduler:
             new_block_ids = cached_reqs.new_block_ids[i]
             if not new_block_ids:
                 continue
-            request_tracker.update(new_token_ids, new_block_ids)
-            # decode not save
-            if request_tracker.token_len > len(request.prompt_token_ids):
-                continue
+            request_tracker.update(new_block_ids)
 
             last_chunk_tokens_num = ((len(request.prompt_token_ids) //
                                       self._block_size * self._block_size)
@@ -270,7 +271,7 @@ class KVPoolScheduler:
         Once a request is finished, determine whether request blocks
         should be freed now or will be sent asynchronously and freed later.
         """
-        if self.kv_role == "kv_consumer":
+        if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
             return False, None
         tracker = self._request_trackers.get(request.request_id)
         if tracker is not None and tracker.num_saved_tokens <= 0:
@@ -61,6 +61,8 @@ class KVPoolWorker:
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "load_async", False)
+        self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+            "consumer_is_to_put", False)
         self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "backend", "mooncake")
         self.block_size = vllm_config.cache_config.block_size
@@ -92,9 +94,44 @@ class KVPoolWorker:
             self.pp_rank,
         )
 
+        partitions = None
+        if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
+            num_hidden_layers = model_config.hf_config.num_hidden_layers
+            partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+                "prefill_pp_layer_partition", None)
+            prefill_pp_size = int(
+                vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+                    "prefill_pp_size", 1))
+
+            if partition_list_str is not None:
+                try:
+                    partitions = [
+                        int(layer) for layer in partition_list_str.split(",")
+                    ]
+                except ValueError as err:
+                    raise ValueError("Invalid partition string: {}".format(
+                        partition_list_str)) from err
+                if len(partitions) != prefill_pp_size:
+                    raise ValueError(
+                        f"{len(partitions)=} does not match {prefill_pp_size=}."
+                    )
+                if sum(partitions) != num_hidden_layers:
+                    raise ValueError(
+                        f"{sum(partitions)=} does not match {num_hidden_layers=}."
+                    )
+            else:
+                layers_per_partition = num_hidden_layers // prefill_pp_size
+                partitions = [
+                    layers_per_partition for _ in range(prefill_pp_size)
+                ]
+
+                if remaining_layers := num_hidden_layers % prefill_pp_size:
+                    for i in range(2, remaining_layers + 2):
+                        partitions[-i] += 1
+
         self.token_database = ChunkedTokenDatabase(self.metadata,
                                                    self.block_size,
-                                                   self.use_mla)
+                                                   self.use_mla, partitions)
 
         real_backend = backend_map.get(self.backend.lower())
         self.m_store = real_backend(  # type: ignore[misc]
@@ -103,6 +140,8 @@ class KVPoolWorker:
         self.kv_send_thread: Optional[KVTransferThread] = None
         self.kv_recv_thread: Optional[KVTransferThread] = None
 
+        self.finished_store_req: set[str] = set()
+
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         _, first_kv_cache_tuple = next(iter(kv_caches.items()))
         first_kv_cache = first_kv_cache_tuple[0]
@@ -176,11 +215,12 @@ class KVPoolWorker:
             self.kv_recv_thread.start()
             ready_event.wait()
         else:
-            if self.kv_role in ['kv_producer', 'kv_both']:
+            if self.kv_role in ['kv_producer', 'kv_both'
+                                ] or self.consumer_is_to_put:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreSendingThread(
                     self.m_store, self.token_database, self.block_size,
-                    self.tp_rank, self.dcp_size, self.put_step,
+                    self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
                     ready_event_sending)
                 self.kv_send_thread.start()
             if self.load_async:
@@ -289,6 +329,8 @@ class KVPoolWorker:
                 continue
 
             request.current_event = current_event
+            self.kv_send_thread.add_stored_request(  # type: ignore[union-attr]
+                request.req_id)
             self.kv_send_thread.add_request(  # type: ignore[union-attr]
                 request, )
 
@@ -413,11 +455,13 @@ class KVPoolWorker:
             for layer_id in range(self.num_layers):
                 yield
 
-    def get_finished(self) -> tuple[set[str], set[str]]:
+    def get_finished(self,
+                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
         done_sending = (
-            self.kv_send_thread.
-            get_and_clear_finished_requests(  # type: ignore[union-attr]
-            ) if self.kv_role in ['kv_producer', 'kv_both'] else set())
+            self.get_and_clear_finished_requests(
+                finished_req_ids  # type: ignore[union-attr]
+            ) if self.kv_role in ['kv_producer', 'kv_both']
+            or self.consumer_is_to_put else set())
 
         done_recving = (
             self.kv_recv_thread.
@@ -430,6 +474,29 @@ class KVPoolWorker:
                 self.tp_rank)
         return done_sending, done_recving
 
+    def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]:
+        finished_sending = set()
+        for req_id in self.kv_send_thread.stored_requests.copy(  # type: ignore[union-attr]
+        ):
+            if self.kv_send_thread.stored_requests[  # type: ignore[union-attr]
+                    req_id] == 0 and req_id in self.finished_store_req:
+                self.finished_store_req.remove(req_id)
+                finished_sending.add(req_id)
+                self.kv_send_thread.delete_finished_stored_request(  # type: ignore[union-attr]
+                    req_id)
+
+        for req_id in finished_req_ids:
+            req_remain_jobs = self.kv_send_thread.stored_requests.get(  # type: ignore[union-attr]
+                req_id)
+            if req_remain_jobs == 0:
+                finished_sending.add(req_id)
+                self.kv_send_thread.delete_finished_stored_request(  # type: ignore[union-attr]
+                    req_id)
+            elif req_remain_jobs is not None:
+                self.finished_store_req.add(req_id)
+
+        return finished_sending
+
     def lookup(
         self,
         token_len: int,
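To summarize the accounting the two hunks above introduce across `KVCacheStoreSendingThread` and `KVPoolWorker.get_and_clear_finished_requests`, here is a simplified, single-threaded model with hypothetical helper names: the worker increments a per-request counter when it queues a store job, the sending thread decrements it after each `put()`, and a request is reported as finished sending only once its counter is back to zero and the scheduler has marked it finished.

```python
from collections import defaultdict

stored_requests = defaultdict(int)   # req_id -> pending store jobs
finished_store_req: set[str] = set() # finished reqs still draining jobs

def add_stored_request(req_id):      # worker queued one store job
    stored_requests[req_id] += 1

def job_done(req_id):                # sending thread completed one put()
    stored_requests[req_id] -= 1

def get_and_clear(finished_req_ids):
    finished = set()
    # Requests that finished earlier and have now drained all jobs.
    for req_id in list(stored_requests):
        if stored_requests[req_id] == 0 and req_id in finished_store_req:
            finished_store_req.remove(req_id)
            finished.add(req_id)
            del stored_requests[req_id]
    # Newly finished requests: report now if drained, else defer.
    for req_id in finished_req_ids:
        jobs = stored_requests.get(req_id)
        if jobs == 0:
            finished.add(req_id)
            del stored_requests[req_id]
        elif jobs is not None:
            finished_store_req.add(req_id)  # report on a later poll
    return finished

add_stored_request("r1")
print(get_and_clear({"r1"}))  # set(): r1 still has a pending job
job_done("r1")
print(get_and_clear(set()))   # {'r1'}: drained, reported on a later poll
```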