diff --git a/vllm_ascend/distributed/kvpool/config_data.py b/vllm_ascend/distributed/kvpool/config_data.py
index 8b45b291..6828e826 100644
--- a/vllm_ascend/distributed/kvpool/config_data.py
+++ b/vllm_ascend/distributed/kvpool/config_data.py
@@ -21,6 +21,8 @@ class KeyMetadata:
     pcp_rank: int
     """ Initialize the current decode context model parallel rank """
     dcp_rank: int
+    """ Initialize the current pipeline parallel rank """
+    pp_rank: int
 
 
 @dataclass(order=True)
@@ -34,6 +36,7 @@ class PoolKey:
             self.key_metadata.head_or_tp_rank,
             self.key_metadata.pcp_rank,
             self.key_metadata.dcp_rank,
+            self.key_metadata.pp_rank,
             self.chunk_hash,
         ))
 
@@ -41,8 +44,8 @@ class PoolKey:
         return (
             f"{self.key_metadata.model_name}"
             f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
-            f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
-        )
+            f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
+            f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}")
 
     def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
         """Split the key into multiple keys for each layer"""
diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py
index 09cf94be..a1e0a900 100644
--- a/vllm_ascend/distributed/kvpool/pool_worker.py
+++ b/vllm_ascend/distributed/kvpool/pool_worker.py
@@ -48,6 +48,8 @@ class KVPoolWorker:
         self.use_layerwise = use_layerwize
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.pp_size = parallel_config.pipeline_parallel_size
+        self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size
         self.pcp_size = get_pcp_group().world_size
         self.pcp_rank = get_pcp_group(
@@ -87,6 +89,7 @@ class KVPoolWorker:
             self.head_or_tp_rank,
             self.pcp_rank,
             self.dcp_rank,
+            self.pp_rank,
         )
 
         self.token_database = ChunkedTokenDatabase(self.metadata,
@@ -555,6 +558,12 @@ class KVPoolWorker:
                         "@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
                     multi_tp_keys.append(new_str)
+            for i in range(1, self.pp_size):
+                for item in keys:
+                    new_str = item.replace(  # type: ignore[attr-defined]
+                        "@pp_rank:0", f"@pp_rank:{i}", 1)
+                    multi_tp_keys.append(new_str)
+
             res = self.m_store.exists(
                 multi_tp_keys)  # type: ignore[assignment]
             num_block = len(keys)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 7c096f16..56cef32c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2450,6 +2450,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                 attn_metadata, self.with_prefill, maybe_padded_num_tokens,
                 input_ids, positions, intermediate_tensors, inputs_embeds)
 
+        self.maybe_wait_for_kv_save()
         finished_sending, finished_recving = self.get_finished_kv_transfer(
             scheduler_output)
 
@@ -2711,7 +2712,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         # ngram and other speculative decoding methods use the sampled
        # tokens on the CPU, so they are run after bookkeeping.
         propose_draft_token_ids(valid_sampled_token_ids)
-        self.maybe_wait_for_kv_save()
+
         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()
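
Note on the pp_rank derivation added in pool_worker.py: a minimal standalone sketch of the arithmetic, assuming vLLM's usual rank layout with TP as the innermost dimension. The function name and arguments below are illustrative stand-ins, not the module's API.

    # Hypothetical check of the formula; global_rank, tp_size, and pp_size
    # stand in for parallel_config.rank, the TP world size, and
    # parallel_config.pipeline_parallel_size.
    def derive_pp_rank(global_rank: int, tp_size: int, pp_size: int) -> int:
        # With TP innermost, the tp_size ranks of one PP stage are
        # contiguous; stripping the TP dimension and wrapping by pp_size
        # yields the stage index even when an outer DP dimension exists.
        return (global_rank // tp_size) % pp_size

    # tp_size=2, pp_size=2: ranks 0-1 -> stage 0, ranks 2-3 -> stage 1,
    # and rank 4 (next DP replica) wraps back to stage 0.
    assert [derive_pp_rank(r, 2, 2) for r in range(5)] == [0, 0, 1, 1, 0]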
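
For reference, the key layout PoolKey.to_string now produces and the rank fan-out exists() performs on it — a sketch assuming rank-0 keys serve as the query template; all values below are made up.

    # Illustrative key; model name, ranks, and chunk hash are fake.
    key = "model-name@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@abc123"

    # Mirrors the loops added in pool_worker.py: expand the rank-0
    # template to every other TP rank and every extra PP stage via
    # targeted, single-occurrence string replacement.
    tp_size, pp_size = 4, 2
    multi_keys = [key]
    for i in range(1, tp_size):
        multi_keys.append(
            key.replace("@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1))
    for i in range(1, pp_size):
        multi_keys.append(key.replace("@pp_rank:0", f"@pp_rank:{i}", 1))
    # -> 5 keys covering every rank that must hold this chunk's KV data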