[feat]dcp pcp support aclgraph (#3731)

### What this PR does / why we need it?
DCP (decode context parallel) and PCP (prefill context parallel) now support full ACL graph mode, including MLA attention_v1.
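At a high level, the PCP/DCP sequence masks and per-rank sequence lengths are now precomputed in the metadata builder and written into persistent, fixed-shape device buffers, so the captured decode graph always sees tensors at stable addresses and only their contents change between replays. The snippet below is a minimal CPU-only sketch of that pattern, not the exact code in this diff (buffer names and sizes are illustrative):

```python
import torch

# Allocate the mask buffer once, sized for the largest decode batch, so a
# captured graph can keep referencing the same storage across replays.
max_num_seqs, pcp_size = 8, 2
seq_mask_pcp_buf = torch.empty(max_num_seqs, pcp_size, dtype=torch.uint8)

def update_seq_mask(per_rank_tokens: torch.Tensor) -> torch.Tensor:
    """per_rank_tokens: [num_decodes, pcp_size] computed-token counts."""
    num_decodes = per_rank_tokens.shape[0]
    mask = (per_rank_tokens != 0).to(torch.uint8)
    # Refresh the buffer contents in place instead of handing the graph a
    # freshly allocated tensor.
    seq_mask_pcp_buf[:num_decodes].copy_(mask)
    return seq_mask_pcp_buf[:num_decodes]

print(update_seq_mask(torch.tensor([[3, 0], [0, 5]])))
```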

- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
weiguihua2
2025-10-27 09:58:23 +08:00
committed by GitHub
parent 8ab8111fde
commit 4312a92a4f
5 changed files with 414 additions and 68 deletions


@@ -140,6 +140,9 @@ class AscendMLADecodeMetadata:
    cos: torch.Tensor = None
    num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
        list[int]]]]]] = None
    seq_mask_pcp: torch.Tensor = None
    seq_mask_dcp: torch.Tensor = None
    cp_seq_len: torch.Tensor = None


@dataclass
@@ -259,6 +262,24 @@ class AscendMLAMetadataBuilder:
        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
        self.cos_cache = None
        self.sin_cache = None
        self.pcp_size = get_prefill_context_model_parallel_world_size(
        ) if prefill_context_parallel_enable() else 1
        self.cp_rank = get_prefill_context_model_parallel_rank(
        ) if self.pcp_size > 1 else 0
        self.dcp_size = get_decode_context_model_parallel_world_size()
        self.dcp_rank = get_decode_context_model_parallel_rank(
        ) if self.dcp_size > 1 else 0
        decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', 0)
        max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
        self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
                                            self.pcp_size,
                                            dtype=torch.uint8,
                                            device=device)
        self.seq_mask_dcp_buf = torch.empty(max_num_seqs,
                                            self.dcp_size,
                                            dtype=torch.uint8,
                                            device=device)

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
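The two buffers above are sized once for the largest decode batch the scheduler can produce (the larger of `max_num_seqs` and `decode_max_num_seqs`, when the latter is configured) and are later exposed to the attention metadata as row slices. A small sketch of that sizing and zero-copy slicing, with made-up numbers:

```python
import torch

# Made-up values; the real ones come from the scheduler config.
max_num_seqs, decode_max_num_seqs, pcp_size = 256, 64, 2

buf_rows = max(max_num_seqs, decode_max_num_seqs)
seq_mask_pcp_buf = torch.empty(buf_rows, pcp_size, dtype=torch.uint8)

# A step with only 3 decode requests gets a view of the first 3 rows;
# the underlying storage (and its address) never changes.
num_decodes = 3
seq_mask_view = seq_mask_pcp_buf[:num_decodes, :pcp_size]
assert seq_mask_view.data_ptr() == seq_mask_pcp_buf.data_ptr()
```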
@@ -463,6 +484,41 @@ class AscendMLAMetadataBuilder:
            block_table = block_table[:num_decodes, ...]
            seq_lens_list = seq_lens.tolist()

            if num_computed_tokens_of_pcp_dcp is not None:
                # [bs, pcp_size, dcp_size]
                num_computed_tokens_of_cp_dcp_array = np.array(
                    num_computed_tokens_of_pcp_dcp)[:num_decodes]
                seq_mask_pcp = torch.where(
                    torch.tensor(num_computed_tokens_of_cp_dcp_array.sum(2)) == 0,
                    0, 1).to(torch.uint8)
                self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0], :seq_mask_pcp.shape[1]].copy_(
                    seq_mask_pcp, non_blocking=True)
                seq_mask_pcp_shape = (seq_mask_pcp.shape[0], seq_mask_pcp.shape[1])
                seq_mask_dcp = torch.where(
                    torch.tensor(
                        num_computed_tokens_of_cp_dcp_array[:, self.cp_rank, :]) == 0,
                    0, 1).to(torch.uint8)
                self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0], :seq_mask_dcp.shape[1]].copy_(
                    seq_mask_dcp, non_blocking=True)
                seq_mask_dcp_shape = (seq_mask_dcp.shape[0], seq_mask_dcp.shape[1])
                cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.cp_rank,
                                                                 self.dcp_rank]
                cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
            else:
                seq_mask_pcp_shape = (0, 0)
                seq_mask_dcp_shape = (0, 0)
                cp_seq_len = None

            # TODO: After the fullgraph supports MTP, this if branch needs to be deleted
            assert self.cos_cache is not None
            assert self.sin_cache is not None
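For readers who want to see concretely what the branch above computes, here is a small self-contained example using a toy `num_computed_tokens_of_pcp_dcp` nested list (all values made up). It mirrors the derivation in the diff: per-request PCP masks from the sum over DCP ranks, per-request DCP masks for this PCP rank, and the local kv length with zeros bumped to 1 because the attention kernel rejects empty sequences (the mask later cancels their contribution):

```python
import numpy as np
import torch

# Toy scheduler output: 2 decode requests, pcp_size = 2, dcp_size = 2.
num_computed_tokens_of_pcp_dcp = [
    [[4, 0], [0, 0]],  # request 0: only pcp rank 0 / dcp rank 0 holds kv
    [[2, 3], [0, 1]],  # request 1: kv spread over several ranks
]
cp_rank, dcp_rank = 0, 1

arr = np.array(num_computed_tokens_of_pcp_dcp)  # [bs, pcp_size, dcp_size]

# A pcp rank participates for a request iff it holds any tokens at all.
seq_mask_pcp = torch.from_numpy((arr.sum(2) != 0).astype(np.uint8))
# Within this pcp rank, mask out dcp ranks that hold nothing.
seq_mask_dcp = torch.from_numpy((arr[:, cp_rank, :] != 0).astype(np.uint8))
# Local kv length for this (pcp, dcp) rank, with 0 replaced by 1.
cp_seq_len = torch.from_numpy(arr[:, cp_rank, dcp_rank]).to(torch.int32)
cp_seq_len = torch.where(cp_seq_len == 0, torch.ones_like(cp_seq_len), cp_seq_len)

print(seq_mask_pcp)  # tensor([[1, 0], [1, 1]], dtype=torch.uint8)
print(seq_mask_dcp)  # tensor([[1, 0], [1, 1]], dtype=torch.uint8)
print(cp_seq_len)    # tensor([1, 3], dtype=torch.int32)
```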
@@ -485,7 +541,14 @@ class AscendMLAMetadataBuilder:
                    sin=sin,
                    cos=cos,
                    num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp,
                    seq_mask_pcp=self.seq_mask_pcp_buf[:seq_mask_pcp_shape[0],
                                                       :seq_mask_pcp_shape[1]],
                    seq_mask_dcp=self.seq_mask_dcp_buf[:seq_mask_dcp_shape[0],
                                                       :seq_mask_dcp_shape[1]],
                    cp_seq_len=cp_seq_len)
            else:
                cos[:num_decode_tokens,
                    ...] = self.cos_cache[input_positions].unsqueeze(
@@ -505,7 +568,14 @@ class AscendMLAMetadataBuilder:
                    sin=sin[:num_decode_tokens, ...],
                    cos=cos[:num_decode_tokens, ...],
                    num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp,
                    seq_mask_pcp=self.seq_mask_pcp_buf[:seq_mask_pcp_shape[0],
                                                       :seq_mask_pcp_shape[1]],
                    seq_mask_dcp=self.seq_mask_dcp_buf[:seq_mask_dcp_shape[0],
                                                       :seq_mask_dcp_shape[1]],
                    cp_seq_len=cp_seq_len)

        return self.metadata_cls(  # type: ignore
            num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
@@ -1590,36 +1660,63 @@ class AscendMLAImpl(MLAAttentionImpl):
        q_nope = q_nope.view(num_tokens, num_heads, -1)
        q_pe = q_pe.view(num_tokens, num_heads, -1)
        # use pcp & dcp split computed-token counts from the scheduler to
        # compute the actual seq_len and seq_mask
        num_computed_tokens_of_pcp_dcp = np.array(
            decode_meta.num_computed_tokens_of_pcp_dcp
        )[:attn_metadata.num_decodes]  # [bs, pcp_size, dcp_size]
        seq_mask_pcp = torch.where(
            torch.tensor(num_computed_tokens_of_pcp_dcp.sum(2)) == 0, 0,
            1).to(torch.uint8).to(q_pe.device)
        seq_mask_dcp = torch.where(
            torch.tensor(
                num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, :]) == 0, 0,
            1).to(torch.uint8).to(q_pe.device)
        seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
                                                 self.dcp_rank]
        seq_len = torch.tensor(seq_len, dtype=torch.int32)
        # npu_multi_head_latent_attention does not support seq_len = 0, so
        # entries where seq_len == 0 are set to 1. This does not influence the
        # result, since seq_mask is used afterwards to update the lse.
        seq_len = torch.where(seq_len == 0, 1, seq_len)
        seq_mask_pcp = decode_meta.seq_mask_pcp
        seq_mask_dcp = decode_meta.seq_mask_dcp
        seq_len = decode_meta.cp_seq_len
        if torch.sum(seq_len).item() == 0:
            # No kv_cache has been stored on this rank, so the following
            # computation can be skipped.
            attn_output = torch.zeros(
                [num_tokens, num_heads, self.kv_lora_rank],
                dtype=q_nope.dtype,
                device=q_nope.device)
            softmax_lse = torch.full((num_tokens, num_heads, 1),
                                     float('-inf'),
                                     dtype=q_nope.dtype,
                                     device=q_nope.device)

        common_kwargs = {
            "return_lse": True,
            "calc_type": "calc_type_ring",
        }
        graph_params = get_graph_params()
        forward_context: ForwardContext = get_forward_context()
        if forward_context.capturing:
            stream = torch_npu.npu.current_stream()
            event = torch.npu.ExternalEvent()
            event.wait(stream)
            event.reset(stream)
            graph_params.events[num_tokens].append(event)

            # Reuse one workspace per captured batch size.
            workspace = graph_params.workspaces.get(num_tokens)
            if workspace is None:
                workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
                    q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
                    seq_len, num_heads, self.scale, self.num_kv_heads,
                    **common_kwargs)
                update_graph_params_workspaces(num_tokens,
                                               weak_ref_tensors(workspace))

            attn_output = torch.empty_like(q_nope)
            softmax_lse = torch.empty((num_tokens, num_heads, 1),
                                      dtype=q_nope.dtype,
                                      device=q_nope.device)
            # Record the call arguments (as weak refs) so they can be updated
            # when the captured graph is replayed.
            graph_params.attn_params[num_tokens].append(
                (weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
                 weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
                 decode_meta.block_table, seq_len, num_heads, self.scale,
                 self.num_kv_heads, weak_ref_tensors(attn_output),
                 weak_ref_tensors(softmax_lse)))

            torch.npu.graph_task_group_begin(stream)
            torch_npu.atb.npu_multi_head_latent_attention(
                q_nope,
                q_pe,
                k_nope,
                k_pe,
                decode_meta.block_table,
                seq_len,
                num_heads,
                self.scale,
                self.num_kv_heads,
                **common_kwargs,
                workspace=workspace,
                output=attn_output,
                lse=softmax_lse)
            handle = torch.npu.graph_task_group_end(stream)
            graph_params.handles[num_tokens].append(handle)
        else:
            attn_output, softmax_lse = torch_npu.atb.npu_multi_head_latent_attention(
            attn_output = torch.empty_like(q_nope)
            softmax_lse = torch.empty((num_tokens, num_heads, 1),
                                      dtype=q_nope.dtype,
                                      device=q_nope.device)
            torch_npu.atb.npu_multi_head_latent_attention(
                q_nope,
                q_pe,
                k_nope,
@@ -1630,7 +1727,9 @@ class AscendMLAImpl(MLAAttentionImpl):
                self.scale,
                self.num_kv_heads,
                return_lse=True,
                calc_type="calc_type_ring")
                calc_type="calc_type_ring",
                output=attn_output,
                lse=softmax_lse)
        if self.dcp_size > 1:
            # Concat out & lse: [bs, num_heads, v_head_dim] + [bs, num_heads, 1]
            # -> [bs, num_heads, v_head_dim + 1]
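The concatenation above packs each rank's partial output together with its log-sum-exp so that partial results computed over disjoint kv shards can later be merged across DCP ranks. The PR relies on the surrounding context-parallel code (and possibly dedicated NPU kernels) for that merge; for reference, a minimal sketch of the standard log-sum-exp merge of two partial attention results is shown below. Note how a shard that returned `lse = -inf` (the no-kv-cache early-out above) receives zero weight, which is why returning zeros there is safe:

```python
import torch

def merge_partial_attention(out_a, lse_a, out_b, lse_b):
    """Merge two partial attention results computed over disjoint kv shards.

    out_*: [tokens, heads, v_dim]; lse_*: [tokens, heads, 1], natural log.
    """
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)  # a shard with lse = -inf gets weight 0
    w_b = torch.exp(lse_b - lse_max)
    denom = w_a + w_b
    out = (out_a * w_a + out_b * w_b) / denom
    lse = lse_max + torch.log(denom)
    return out, lse

tokens, heads, v_dim = 2, 4, 8
out_a = torch.randn(tokens, heads, v_dim)
out_b = torch.randn(tokens, heads, v_dim)
lse_a = torch.randn(tokens, heads, 1)
lse_b = torch.full((tokens, heads, 1), float("-inf"))  # e.g. no kv on shard b
merged, _ = merge_partial_attention(out_a, lse_a, out_b, lse_b)
print(torch.allclose(merged, out_a))  # True: the empty shard contributes nothing
```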