[feat]dcp pcp support aclgraph (#3731)
### What this PR does / why we need it?
Add full aclgraph support for DCP (decode context parallel) and PCP (prefill context parallel), including MLA and attention_v1.
- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
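
The core of the change is a runtime dispatch: when `pcp_size * dcp_size > 1`, graph-captured attention kernels get their parameters refreshed through the new DCP/PCP-aware helpers instead of the plain ones. A minimal sketch of that selection logic (helper names are taken from this diff; the stub bodies are placeholders so the sketch runs standalone):

```python
# Sketch of the dispatch this PR adds in NPUModelRunner. The real helpers
# live in vllm_ascend.compilation.acl_graph; these stubs only stand in
# for them so the sketch is self-contained.
def update_attn_params(*args): ...
def update_attn_dcp_pcp_params(*args): ...
def update_mla_attn_params(*args): ...
def update_mla_attn_dcp_pcp_params(*args): ...

def select_update_fn(use_mla: bool, pcp_size: int, dcp_size: int):
    # DCP/PCP-aware updates are needed only when context parallelism
    # spans more than one rank in either dimension.
    cp_active = pcp_size * dcp_size > 1
    if use_mla:
        return (update_mla_attn_dcp_pcp_params
                if cp_active else update_mla_attn_params)
    return update_attn_dcp_pcp_params if cp_active else update_attn_params

assert select_update_fn(True, 2, 1) is update_mla_attn_dcp_pcp_params
assert select_update_fn(False, 1, 1) is update_attn_params
```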
```diff
@@ -110,10 +110,14 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          AscendPrefillContextParallelMetadata)
+# yapf: disable
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
+                                               update_attn_dcp_pcp_params,
                                                update_attn_params,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
+# yapf: enable
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader
```
```diff
@@ -1649,6 +1653,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         slot_mapping = blk_table.slot_mapping[:slot_mapping_size]
         blk_table.slot_mapping[slot_mapping_size:].fill_(0)
         if self.pcp_size > 1:
+            slot_mapping_for_pcp = blk_table.slot_mapping[:long_seq_metadata.
+                                                          num_actual_tokens_pcp_padded]
+            slot_mapping_for_pcp[slot_mapping_size:].fill_(-1)
             assert pcp_unpad_mask is not None
             pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:
                                                                    pcp_unpad_mask
@@ -1657,10 +1666,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                                                    .shape[0]]
             pcp_padded_slot_mapping.fill_(-1)
-            pcp_padded_slot_mapping[
-                pcp_unpad_mask] = blk_table.slot_mapping[:slot_mapping_size]
-            blk_table.slot_mapping[:long_seq_metadata.
-                                   num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
+            pcp_padded_slot_mapping[
+                pcp_unpad_mask] = slot_mapping_for_pcp[:slot_mapping_size]
+            slot_mapping_for_pcp[:long_seq_metadata.
+                                 num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
+            slot_mapping = slot_mapping_for_pcp

         # Make AscendCommonAttentionMetadata
         common_attn_metadata = AscendCommonAttentionMetadata(
```
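
In the PCP path above, the slot mapping is padded out to `num_actual_tokens_pcp_padded`, padding slots are filled with -1, and the real entries are scattered back to their unpadded positions through `pcp_unpad_mask`. A standalone sketch of that pad-and-scatter pattern (the sizes and mask below are illustrative, not the runner's real buffers):

```python
import torch

# Illustrative sizes: 6 valid tokens, padded to 8 for PCP.
slot_mapping_size = 6
num_actual_tokens_pcp_padded = 8

slot_mapping = torch.arange(num_actual_tokens_pcp_padded, dtype=torch.int64)
slot_mapping[slot_mapping_size:] = -1  # padded tokens map to no slot

# pcp_unpad_mask marks which padded positions hold real tokens
# (exactly slot_mapping_size entries are True).
pcp_unpad_mask = torch.tensor([1, 1, 0, 1, 1, 0, 1, 1], dtype=torch.bool)

# Scatter the valid slots to their padded positions, -1 elsewhere.
padded = torch.full((num_actual_tokens_pcp_padded,), -1, dtype=torch.int64)
padded[pcp_unpad_mask] = slot_mapping[:slot_mapping_size]
slot_mapping[:num_actual_tokens_pcp_padded] = padded
print(slot_mapping)  # tensor([ 0,  1, -1,  2,  3, -1,  4,  5])
```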
```diff
@@ -1749,13 +1759,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
             # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
             if self.vllm_config.model_config.use_mla:
-                # FIXME: Try using `auto_dispatch_capture=True`
-                update_mla_attn_params(self.update_stream, forward_context,
-                                       maybe_padded_num_tokens,
-                                       self.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_dcp_pcp_params(self.update_stream,
+                                                   forward_context,
+                                                   maybe_padded_num_tokens,
+                                                   self.speculative_config)
+                else:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_params(self.update_stream, forward_context,
+                                           maybe_padded_num_tokens,
+                                           self.speculative_config)
             else:
-                update_attn_params(self.update_stream, forward_context,
-                                   maybe_padded_num_tokens)
+                if self.pcp_size * self.dcp_size > 1:
+                    update_attn_dcp_pcp_params(self.update_stream,
+                                               forward_context,
+                                               maybe_padded_num_tokens)
+                else:
+                    update_attn_params(self.update_stream, forward_context,
+                                       maybe_padded_num_tokens)

         if get_forward_context().sp_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
```
```diff
@@ -2488,6 +2510,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 kv_cache_group_id].get_device_tensor()
             slot_mapping = self.input_batch.block_table[
                 kv_cache_group_id].slot_mapping
+            self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
+                                                 dtype=torch.int32,
+                                                 device=self.device)
+            long_seq_metadata = self._generate_pcp_metadata(
+                num_tokens, self.seq_lens_cpu)
+            if long_seq_metadata is not None:
+                pcp_world_size = get_pcp_group(
+                ).world_size if prefill_context_parallel_enable() else 1
+                dcp_world_size = get_dcp_group().world_size
+                num_computed_tokens_of_pcp_dcp = [[
+                    [0] * dcp_world_size for _ in range(pcp_world_size)
+                ] for _ in range(num_tokens)]
+                long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
             common_attn_metadata = AscendCommonAttentionMetadata(
                 query_start_loc=torch.tensor(
                     [0] + self.actual_seq_lengths_q[:num_reqs],
```
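
During dummy runs for graph capture there is no real prefill progress to record, so the metadata is seeded with an all-zero `num_computed_tokens_of_pcp_dcp` table of shape `[num_tokens][pcp_world_size][dcp_world_size]`. A sketch of that construction with illustrative world sizes:

```python
# Illustrative world sizes; the runner reads them from the PCP/DCP groups.
num_tokens, pcp_world_size, dcp_world_size = 4, 2, 2

# One [pcp][dcp] zero matrix per token: no tokens computed anywhere yet.
num_computed_tokens_of_pcp_dcp = [[
    [0] * dcp_world_size for _ in range(pcp_world_size)
] for _ in range(num_tokens)]

assert len(num_computed_tokens_of_pcp_dcp) == num_tokens
assert len(num_computed_tokens_of_pcp_dcp[0]) == pcp_world_size
assert len(num_computed_tokens_of_pcp_dcp[0][0]) == dcp_world_size
```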
```diff
@@ -2511,6 +2546,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 decode_token_per_req=self.decode_token_per_req,
                 cos=self.cos,
                 sin=self.sin,
+                prefill_context_parallel_metadata=long_seq_metadata,
             )
             attn_state = AscendAttentionState.DecodeOnly
             if self.speculative_config and \
```
```diff
@@ -2540,12 +2576,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 not forward_context.capturing:
             if self.vllm_config.model_config.use_mla:
-                # FIXME: Try using `auto_dispatch_capture=True`
-                update_mla_attn_params(self.update_stream, forward_context,
-                                       positions.shape[0],
-                                       self.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_dcp_pcp_params(self.update_stream,
+                                                   forward_context,
+                                                   positions.shape[0],
+                                                   self.speculative_config)
+                else:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_params(self.update_stream, forward_context,
+                                           positions.shape[0],
+                                           self.speculative_config)
             else:
-                update_attn_params(self.update_stream, forward_context,
-                                   positions.shape[0])
+                if self.pcp_size * self.dcp_size > 1:
+                    update_attn_dcp_pcp_params(self.update_stream,
+                                               forward_context,
+                                               positions.shape[0])
+                else:
+                    update_attn_params(self.update_stream, forward_context,
+                                       positions.shape[0])

         if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
             hidden_states, _ = hidden_states
```