[KVPOOL]decode save kvcache (#5168)

### What this PR does / why we need it?

KVPool: save the KV cache on the decode side.
Currently only MLA is supported.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: baxingpiaochong <771405853@qq.com>
Co-authored-by: Chao Lei <leichao139636@163.com>
This commit is contained in:
baxingpiaochong
2026-01-04 22:22:01 +08:00
committed by GitHub
parent 350b95efcf
commit 46c2fc6a3c
6 changed files with 156 additions and 31 deletions

View File

@@ -87,12 +87,14 @@ class LayerPoolKey(PoolKey):
class ChunkedTokenDatabase():
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool):
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
             partitions: Optional[List[int]]):
    """Initialize the chunk-hash-keyed view over the KV-cache pool.

    Args:
        metadata: key metadata folded into every pool key.
        block_size: number of tokens covered by one KV-cache block.
        use_mla: whether the attention uses MLA; only stored here —
            downstream semantics not visible in this chunk.
        partitions: per-pipeline-parallel-rank partition sizes (presumably
            layers per pp rank — see decode_adaptor_prefill_pp), or None
            when the cache is not split across pp ranks.
    """
    self.metadata = metadata
    self.block_size = block_size
    self.use_mla = use_mla
    # Start empty; populated elsewhere (not visible in this chunk) once the
    # actual cache tensors are registered — TODO confirm against caller.
    self.kv_caches_base_addr: list[int] = []
    self.block_len: list[int] = []
    self.partitions = partitions
def _make_key_by_hash(self,
chunk_hash: str,
@@ -188,6 +190,28 @@ class ChunkedTokenDatabase():
else:
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
def decode_adaptor_prefill_pp(self, key, addr, size):
    """Re-slice per-request address/size lists for pipeline-parallel prefill.

    Splits each request's flat ``addr``/``size`` lists into one slice per
    entry of ``self.partitions`` and emits a key per slice, rewriting the
    first ``@pp_rank:0`` marker in the request key to the slice's rank.

    Args:
        key: per-request keys, each containing an ``@pp_rank:0`` marker.
        addr: per-request lists of cache addresses.
        size: per-request lists of cache sizes (parallel to ``addr``).

    Returns:
        ``(keys, addrs, sizes)`` with one entry per (request, partition).
    """
    # Nothing to re-slice with a single (or absent) partition layout.
    if self.partitions is None or len(self.partitions) == 1:
        return key, addr, size

    keys_out: list = []
    addrs_out: list = []
    sizes_out: list = []
    last_idx = len(self.partitions) - 1
    for req_idx, (addr_list, size_list) in enumerate(zip(addr, size)):
        cursor = 0
        for rank, part in enumerate(self.partitions):
            if rank == last_idx:
                # Final partition absorbs whatever remains.
                stop = len(addr_list)
            else:
                # part * 2 because addr and size contain both k and v
                stop = cursor + part * 2
            ranked_key = key[req_idx].replace(  # type: ignore[attr-defined]
                "@pp_rank:0", f"@pp_rank:{rank}", 1)
            keys_out.append(ranked_key)
            addrs_out.append(addr_list[cursor:stop])
            sizes_out.append(size_list[cursor:stop])
            cursor = stop
    return keys_out, addrs_out, sizes_out
# Parameters related to the connector metadata
@dataclass
@@ -247,15 +271,12 @@ class RequestTracker:
def update(
self,
new_token_ids: list[int],
new_block_ids: Union[tuple[list[int], ...], list[int]],
) -> None:
"""Update the request tracker when a running request is
scheduled again
"""
self.token_len = self.token_len + len(new_token_ids)
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):
@@ -378,4 +399,4 @@ class LasyerMultiBlockReqMeta:
block_ids: list[int]
layer_id: int
is_last_chunk: Optional[bool] = True
current_event: Optional[torch.npu.Event] = None
current_event: Optional[torch.npu.Event] = None