[KVPool] Support pp (#4761)
### What this PR does / why we need it?
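Support pipeline parallelism (pp) for the KV pool: the pipeline parallel rank (`pp_rank`) is added to the key metadata and to the pool key string, and key lookups also check the keys of the other pp ranks.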
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
@@ -21,6 +21,8 @@ class KeyMetadata:
|
|||||||
pcp_rank: int
|
pcp_rank: int
|
||||||
""" Initialize the current decode context model parallel rank """
|
""" Initialize the current decode context model parallel rank """
|
||||||
dcp_rank: int
|
dcp_rank: int
|
||||||
|
""" Initialize the current pipeline parallel rank """
|
||||||
|
pp_rank: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass(order=True)
|
@dataclass(order=True)
|
||||||
@@ -34,6 +36,7 @@ class PoolKey:
|
|||||||
self.key_metadata.head_or_tp_rank,
|
self.key_metadata.head_or_tp_rank,
|
||||||
self.key_metadata.pcp_rank,
|
self.key_metadata.pcp_rank,
|
||||||
self.key_metadata.dcp_rank,
|
self.key_metadata.dcp_rank,
|
||||||
|
self.key_metadata.pp_rank,
|
||||||
self.chunk_hash,
|
self.chunk_hash,
|
||||||
))
|
))
|
||||||
|
|
||||||
@@ -41,8 +44,8 @@ class PoolKey:
|
|||||||
return (
|
return (
|
||||||
f"{self.key_metadata.model_name}"
|
f"{self.key_metadata.model_name}"
|
||||||
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
|
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
|
||||||
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
|
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
|
||||||
)
|
f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}")
|
||||||
|
|
||||||
def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
|
def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
|
||||||
"""Split the key into multiple keys for each layer"""
|
"""Split the key into multiple keys for each layer"""
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ class KVPoolWorker:
|
|||||||
self.use_layerwise = use_layerwize
|
self.use_layerwise = use_layerwize
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.pp_size = parallel_config.pipeline_parallel_size
|
||||||
|
self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size
|
||||||
|
|
||||||
self.pcp_size = get_pcp_group().world_size
|
self.pcp_size = get_pcp_group().world_size
|
||||||
self.pcp_rank = get_pcp_group(
|
self.pcp_rank = get_pcp_group(
|
||||||
@@ -87,6 +89,7 @@ class KVPoolWorker:
|
|||||||
self.head_or_tp_rank,
|
self.head_or_tp_rank,
|
||||||
self.pcp_rank,
|
self.pcp_rank,
|
||||||
self.dcp_rank,
|
self.dcp_rank,
|
||||||
|
self.pp_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.token_database = ChunkedTokenDatabase(self.metadata,
|
self.token_database = ChunkedTokenDatabase(self.metadata,
|
||||||
@@ -555,6 +558,12 @@ class KVPoolWorker:
|
|||||||
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
|
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
|
||||||
multi_tp_keys.append(new_str)
|
multi_tp_keys.append(new_str)
|
||||||
|
|
||||||
|
for i in range(1, self.pp_size):
|
||||||
|
for item in keys:
|
||||||
|
new_str = item.replace( # type: ignore[attr-defined]
|
||||||
|
"@pp_rank:0", f"@pp_rank:{i}", 1)
|
||||||
|
multi_tp_keys.append(new_str)
|
||||||
|
|
||||||
res = self.m_store.exists(
|
res = self.m_store.exists(
|
||||||
multi_tp_keys) # type: ignore[assignment]
|
multi_tp_keys) # type: ignore[assignment]
|
||||||
num_block = len(keys)
|
num_block = len(keys)
|
||||||
|
|||||||
@@ -2450,6 +2450,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
attn_metadata, self.with_prefill, maybe_padded_num_tokens,
|
attn_metadata, self.with_prefill, maybe_padded_num_tokens,
|
||||||
input_ids, positions, intermediate_tensors, inputs_embeds)
|
input_ids, positions, intermediate_tensors, inputs_embeds)
|
||||||
|
|
||||||
|
self.maybe_wait_for_kv_save()
|
||||||
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
||||||
scheduler_output)
|
scheduler_output)
|
||||||
|
|
||||||
@@ -2711,7 +2712,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
|||||||
# ngram and other speculative decoding methods use the sampled
|
# ngram and other speculative decoding methods use the sampled
|
||||||
# tokens on the CPU, so they are run after bookkeeping.
|
# tokens on the CPU, so they are run after bookkeeping.
|
||||||
propose_draft_token_ids(valid_sampled_token_ids)
|
propose_draft_token_ids(valid_sampled_token_ids)
|
||||||
self.maybe_wait_for_kv_save()
|
|
||||||
if has_kv_transfer_group():
|
if has_kv_transfer_group():
|
||||||
get_kv_transfer_group().clear_connector_metadata()
|
get_kv_transfer_group().clear_connector_metadata()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user