[KVPOOL]decode save kvcache (#5168)
### What this PR does / why we need it?
kvpool decode save kvcache
now only support mla
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Chao Lei <leichao139636@163.com>
This commit is contained in:
@@ -87,12 +87,14 @@ class LayerPoolKey(PoolKey):
|
||||
|
||||
class ChunkedTokenDatabase():
|
||||
|
||||
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool):
|
||||
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
|
||||
partitions: Optional[List[int]]):
|
||||
self.metadata = metadata
|
||||
self.block_size = block_size
|
||||
self.use_mla = use_mla
|
||||
self.kv_caches_base_addr: list[int] = []
|
||||
self.block_len: list[int] = []
|
||||
self.partitions = partitions
|
||||
|
||||
def _make_key_by_hash(self,
|
||||
chunk_hash: str,
|
||||
@@ -188,6 +190,28 @@ class ChunkedTokenDatabase():
|
||||
else:
|
||||
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
|
||||
|
||||
def decode_adaptor_prefill_pp(self, key, addr, size):
|
||||
if self.partitions is None or len(self.partitions) == 1:
|
||||
return key, addr, size
|
||||
|
||||
new_key = []
|
||||
new_addr = []
|
||||
new_size = []
|
||||
|
||||
for i, (addr_list, size_list) in enumerate(zip(addr, size)):
|
||||
start = 0
|
||||
for j, part in enumerate(self.partitions):
|
||||
# part * 2 because addr and size contain both k and v
|
||||
end = len(addr_list) if j == len(
|
||||
self.partitions) - 1 else start + part * 2
|
||||
new_str = key[i].replace( # type: ignore[attr-defined]
|
||||
"@pp_rank:0", f"@pp_rank:{j}", 1)
|
||||
new_key.append(new_str)
|
||||
new_addr.append(addr_list[start:end])
|
||||
new_size.append(size_list[start:end])
|
||||
start = end
|
||||
return new_key, new_addr, new_size
|
||||
|
||||
|
||||
#Parameters related to the connector metadata
|
||||
@dataclass
|
||||
@@ -247,15 +271,12 @@ class RequestTracker:
|
||||
|
||||
def update(
|
||||
self,
|
||||
new_token_ids: list[int],
|
||||
new_block_ids: Union[tuple[list[int], ...], list[int]],
|
||||
) -> None:
|
||||
"""Update the request tracker when a running request is
|
||||
scheduled again
|
||||
"""
|
||||
|
||||
self.token_len = self.token_len + len(new_token_ids)
|
||||
|
||||
if len(new_block_ids) == 0:
|
||||
new_block_ids = []
|
||||
elif isinstance(new_block_ids, tuple):
|
||||
@@ -378,4 +399,4 @@ class LasyerMultiBlockReqMeta:
|
||||
block_ids: list[int]
|
||||
layer_id: int
|
||||
is_last_chunk: Optional[bool] = True
|
||||
current_event: Optional[torch.npu.Event] = None
|
||||
current_event: Optional[torch.npu.Event] = None
|
||||
|
||||
Reference in New Issue
Block a user