[KVPOOl]Support pp (#4761)

### What this PR does / why we need it?
Support pp for kv pool

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
baxingpiaochong
2025-12-09 16:15:26 +08:00
committed by GitHub
parent 9038865261
commit dda027e680
3 changed files with 16 additions and 3 deletions

View File

@@ -48,6 +48,8 @@ class KVPoolWorker:
self.use_layerwise = use_layerwize
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.pp_size = parallel_config.pipeline_parallel_size
self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
@@ -87,6 +89,7 @@ class KVPoolWorker:
self.head_or_tp_rank,
self.pcp_rank,
self.dcp_rank,
self.pp_rank,
)
self.token_database = ChunkedTokenDatabase(self.metadata,
@@ -555,6 +558,12 @@ class KVPoolWorker:
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
multi_tp_keys.append(new_str)
for i in range(1, self.pp_size):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{i}", 1)
multi_tp_keys.append(new_str)
res = self.m_store.exists(
multi_tp_keys) # type: ignore[assignment]
num_block = len(keys)