[feat]dcp pcp support aclgraph (#3731)

### What this PR does / why we need it?
Add full ACL graph support for DCP (decode context parallel) and PCP (prefill context parallel), covering both the MLA path and attention_v1.

- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Author: weiguihua2
Date: 2025-10-27 09:58:23 +08:00 (committed by GitHub)
Parent: 8ab8111fde
Commit: 4312a92a4f
5 changed files with 414 additions and 68 deletions

@@ -110,10 +110,14 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                         AscendPrefillContextParallelMetadata)
# yapf: disable
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                               set_graph_params,
                                               update_attn_dcp_pcp_params,
                                               update_attn_params,
                                               update_mla_attn_dcp_pcp_params,
                                               update_mla_attn_params)
# yapf: enable
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
    D2DExpertWeightLoader
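
For background, a fully captured ACL graph replays with whatever tensor buffers it was captured against, so per-step attention parameters (sequence lengths, slot mappings, and so on) have to be patched into those captured buffers before every replay; that is the job of the `update_*_params` helpers imported above. The sketch below only illustrates this capture-once / patch-before-replay idea: `GraphParams` and `update_params_before_replay` are hypothetical names, not vllm-ascend APIs.

```python
from dataclasses import dataclass

import torch


@dataclass
class GraphParams:
    """Hypothetical stand-in for attention parameters kept alive by a captured graph."""
    seq_lens: torch.Tensor      # buffer the captured attention kernel reads
    slot_mapping: torch.Tensor  # buffer used for KV-cache writes


def update_params_before_replay(params: GraphParams,
                                seq_lens: torch.Tensor,
                                slot_mapping: torch.Tensor) -> None:
    # The captured graph holds pointers to these buffers, so they are
    # overwritten in place rather than re-allocated.
    params.seq_lens.copy_(seq_lens)
    params.slot_mapping.copy_(slot_mapping)


params = GraphParams(seq_lens=torch.zeros(4, dtype=torch.int32),
                     slot_mapping=torch.zeros(4, dtype=torch.int64))
# Per decode step: patch the captured buffers, then replay the graph.
update_params_before_replay(params,
                            torch.tensor([3, 5, 2, 7], dtype=torch.int32),
                            torch.tensor([10, 11, 12, 13]))
```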
@@ -1649,6 +1653,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
slot_mapping = blk_table.slot_mapping[:slot_mapping_size]
blk_table.slot_mapping[slot_mapping_size:].fill_(0)
if self.pcp_size > 1:
+   slot_mapping_for_pcp = blk_table.slot_mapping[:long_seq_metadata.num_actual_tokens_pcp_padded]
+   slot_mapping_for_pcp[slot_mapping_size:].fill_(-1)
    assert pcp_unpad_mask is not None
    pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:pcp_unpad_mask.shape[0]]
@@ -1657,10 +1666,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
    pcp_padded_slot_mapping.fill_(-1)
-   pcp_padded_slot_mapping[pcp_unpad_mask] = blk_table.slot_mapping[:slot_mapping_size]
-   blk_table.slot_mapping[:long_seq_metadata.num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
+   pcp_padded_slot_mapping[pcp_unpad_mask] = slot_mapping_for_pcp[:slot_mapping_size]
+   slot_mapping_for_pcp[:long_seq_metadata.num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
+   slot_mapping = slot_mapping_for_pcp
# Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata(
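
The slot-mapping change above exists because PCP pads the token stream so every prefill-context-parallel rank receives an equally sized shard; padded positions get slot `-1` (skipped when writing the KV cache) and the real slots are scattered into the padded layout through `pcp_unpad_mask`. A minimal torch sketch of that scatter, with made-up sizes and mask layout:

```python
import torch

# Toy sizes: 5 real tokens, padded to 8 so each of 2 PCP ranks gets 4.
num_actual_tokens = 5
num_tokens_pcp_padded = 8

# Slot mapping for the real tokens (KV-cache slot per token).
real_slots = torch.tensor([100, 101, 102, 103, 104])

# Mask marking which positions in the padded layout hold real tokens.
pcp_unpad_mask = torch.zeros(num_tokens_pcp_padded, dtype=torch.bool)
pcp_unpad_mask[[0, 1, 2, 4, 5]] = True  # hypothetical interleaving

# Padded slot mapping: -1 means "padding, do not write to the KV cache".
padded_slots = torch.full((num_tokens_pcp_padded,), -1, dtype=torch.int64)
padded_slots[pcp_unpad_mask] = real_slots

print(padded_slots)  # tensor([100, 101, 102,  -1, 103, 104,  -1,  -1])
```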
@@ -1749,13 +1759,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
    # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
    if self.vllm_config.model_config.use_mla:
-       # FIXME: Try using `auto_dispatch_capture=True`
-       update_mla_attn_params(self.update_stream, forward_context,
-                              maybe_padded_num_tokens,
-                              self.speculative_config)
+       if self.pcp_size * self.dcp_size > 1:
+           # FIXME: Try using `auto_dispatch_capture=True`
+           update_mla_attn_dcp_pcp_params(self.update_stream,
+                                          forward_context,
+                                          maybe_padded_num_tokens,
+                                          self.speculative_config)
+       else:
+           # FIXME: Try using `auto_dispatch_capture=True`
+           update_mla_attn_params(self.update_stream, forward_context,
+                                  maybe_padded_num_tokens,
+                                  self.speculative_config)
    else:
-       update_attn_params(self.update_stream, forward_context,
-                          maybe_padded_num_tokens)
+       if self.pcp_size * self.dcp_size > 1:
+           update_attn_dcp_pcp_params(self.update_stream,
+                                      forward_context,
+                                      maybe_padded_num_tokens)
+       else:
+           update_attn_params(self.update_stream, forward_context,
+                              maybe_padded_num_tokens)
if get_forward_context().sp_enabled:
    hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
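
Both branches follow the same dispatch rule: when decode or prefill context parallelism is active (`pcp_size * dcp_size > 1`), the DCP/PCP-aware update helper is chosen, otherwise the original one. A condensed sketch of that selection logic, with strings standing in for the real helpers from `vllm_ascend.compilation.acl_graph`:

```python
def select_update_fn(use_mla: bool, pcp_size: int, dcp_size: int) -> str:
    """Return which parameter-update helper a graph replay step would use."""
    cp_active = pcp_size * dcp_size > 1
    if use_mla:
        return ("update_mla_attn_dcp_pcp_params" if cp_active
                else "update_mla_attn_params")
    return "update_attn_dcp_pcp_params" if cp_active else "update_attn_params"


# DCP or PCP enabled -> the *_dcp_pcp_* variant; otherwise the original helper.
assert select_update_fn(use_mla=True, pcp_size=2, dcp_size=1) == "update_mla_attn_dcp_pcp_params"
assert select_update_fn(use_mla=False, pcp_size=1, dcp_size=1) == "update_attn_params"
```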
@@ -2488,6 +2510,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
    kv_cache_group_id].get_device_tensor()
slot_mapping = self.input_batch.block_table[
    kv_cache_group_id].slot_mapping
+self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
+                                     dtype=torch.int32,
+                                     device=self.device)
+long_seq_metadata = self._generate_pcp_metadata(num_tokens, self.seq_lens_cpu)
+if long_seq_metadata is not None:
+    pcp_world_size = get_pcp_group().world_size if prefill_context_parallel_enable() else 1
+    dcp_world_size = get_dcp_group().world_size
+    num_computed_tokens_of_pcp_dcp = [[
+        [0] * dcp_world_size for _ in range(pcp_world_size)
+    ] for _ in range(num_tokens)]
+    long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
common_attn_metadata = AscendCommonAttentionMetadata(
    query_start_loc=torch.tensor(
        [0] + self.actual_seq_lengths_q[:num_reqs],
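
During dummy/capture runs no tokens have actually been computed, so the PCP metadata only needs a correctly shaped placeholder: a zero-filled nested list indexed as `[token][pcp_rank][dcp_rank]`. A small shape check of that construction (the token count and world sizes below are illustrative, not taken from the PR):

```python
# One entry per token, then one row per PCP rank, one column per DCP rank.
num_tokens, pcp_world_size, dcp_world_size = 4, 2, 2

num_computed_tokens_of_pcp_dcp = [[
    [0] * dcp_world_size for _ in range(pcp_world_size)
] for _ in range(num_tokens)]

assert len(num_computed_tokens_of_pcp_dcp) == num_tokens
assert len(num_computed_tokens_of_pcp_dcp[0]) == pcp_world_size
assert len(num_computed_tokens_of_pcp_dcp[0][0]) == dcp_world_size
# All zeros: the capture-time graph only needs correctly shaped metadata.
assert num_computed_tokens_of_pcp_dcp[3][1][1] == 0
```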
@@ -2511,6 +2546,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
    decode_token_per_req=self.decode_token_per_req,
    cos=self.cos,
    sin=self.sin,
+   prefill_context_parallel_metadata=long_seq_metadata,
)
attn_state = AscendAttentionState.DecodeOnly
if self.speculative_config and \
@@ -2540,12 +2576,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        not forward_context.capturing:
    if self.vllm_config.model_config.use_mla:
-       # FIXME: Try using `auto_dispatch_capture=True`
-       update_mla_attn_params(self.update_stream, forward_context,
-                              positions.shape[0],
-                              self.speculative_config)
+       if self.pcp_size * self.dcp_size > 1:
+           # FIXME: Try using `auto_dispatch_capture=True`
+           update_mla_attn_dcp_pcp_params(self.update_stream,
+                                          forward_context,
+                                          positions.shape[0],
+                                          self.speculative_config)
+       else:
+           # FIXME: Try using `auto_dispatch_capture=True`
+           update_mla_attn_params(self.update_stream, forward_context,
+                                  positions.shape[0],
+                                  self.speculative_config)
    else:
-       update_attn_params(self.update_stream, forward_context,
-                          positions.shape[0])
+       if self.pcp_size * self.dcp_size > 1:
+           update_attn_dcp_pcp_params(self.update_stream,
+                                      forward_context,
+                                      positions.shape[0])
+       else:
+           update_attn_params(self.update_stream, forward_context,
+                              positions.shape[0])
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
hidden_states, _ = hidden_states