[Feat] update op for mla (#4000)

### What this PR does / why we need it?
1. In the mla_v1 module, add the `torch_npu.npu_attention_update` op to merge attention outputs when PCP (prefill context parallel) and DCP (decode context parallel) are enabled.
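
For context, the new op fuses the per-chunk merge that the removed `_update_out_and_lse` helper performed in Python: each PCP/DCP rank produces a partial attention output plus its softmax log-sum-exp (LSE), and the partials are combined pairwise. A minimal sketch of that merge, with illustrative names and shapes (not the fused op's actual interface):

```python
import torch
import torch.nn.functional as F

def merge_partial_attention(out, lse, block_out, block_lse):
    """Combine one partial result (block_out, block_lse) into the running
    (out, lse); mirrors the formula in the removed _update_out_and_lse.
    Assumed shapes: out/block_out [bs, heads, dim], lse/block_lse [bs, heads, 1]."""
    if out is None:
        return block_out.to(torch.float32), block_lse
    block_out = block_out.to(torch.float32)
    # Weight the running output toward the new block by its relative softmax
    # mass, sigmoid(block_lse - lse) ...
    out = out - F.sigmoid(block_lse - lse) * (out - block_out)
    # ... and accumulate the LSE: lse_new = log(exp(lse) + exp(block_lse)).
    lse = lse - F.logsigmoid(lse - block_lse)
    return out, lse
```

In this PR the Python loop over chunks is replaced by a single call, `torch_npu.npu_attention_update(lse_list, out_list, 0)`, applied to the flattened per-rank outputs and LSEs.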

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

---------

Signed-off-by: LookAround <lixushi@huawei.com>
LookAround0301
2025-11-07 09:48:39 +08:00
committed by GitHub
parent f8610b7d67
commit 79e536d939
2 changed files with 88 additions and 113 deletions


@@ -1,11 +1,10 @@
from dataclasses import dataclass
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
TypeVar)
from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
Type, TypeVar)
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import (AttentionBackend,
@@ -140,10 +139,8 @@ class AscendMLADecodeMetadata:
attn_mask: Optional[torch.Tensor] = None
sin: torch.Tensor = None
cos: torch.Tensor = None
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
seq_mask_pcp: torch.Tensor = None
seq_mask_dcp: torch.Tensor = None
cp_seq_len: torch.Tensor = None
batch_seq_mask: torch.Tensor = None
@dataclass
@@ -263,9 +260,10 @@ class AscendMLAMetadataBuilder:
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.cp_rank = get_prefill_context_model_parallel_rank(
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
@@ -273,6 +271,9 @@ class AscendMLAMetadataBuilder:
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
self.batch_seq_mask_buf = torch.empty(max_num_seqs,
dtype=torch.uint8,
device=device)
self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
self.pcp_size,
dtype=torch.uint8,
@@ -489,36 +490,19 @@ class AscendMLAMetadataBuilder:
num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp
)[:num_decodes] # [bs, pcp_size, dcp_size]
seq_mask_pcp = torch.where(
torch.tensor(
num_computed_tokens_of_cp_dcp_array.sum(2)) == 0, 0,
1).to(torch.uint8)
self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0], :seq_mask_pcp.
shape[1]].copy_(seq_mask_pcp,
non_blocking=True)
seq_mask_pcp_shape = (seq_mask_pcp.shape[0],
seq_mask_pcp.shape[1])
seq_mask_dcp = torch.where(
torch.tensor(
num_computed_tokens_of_cp_dcp_array[:,
self.cp_rank, :])
== 0, 0, 1).to(torch.uint8)
self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0], :seq_mask_dcp.
shape[1]].copy_(seq_mask_dcp,
non_blocking=True)
seq_mask_dcp_shape = (seq_mask_dcp.shape[0],
seq_mask_dcp.shape[1])
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
self.cp_rank,
self.pcp_rank,
self.dcp_rank]
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
batch_seq_mask = (cp_seq_len == 0)
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.
shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
else:
seq_mask_pcp_shape = (0, 0)
seq_mask_dcp_shape = (0, 0)
cp_seq_len = None
cp_seq_len, batch_seq_mask = None, None
# TODO: After fullgraph supports MTP, this if branch needs to be deleted
assert self.cos_cache is not None
@@ -541,15 +525,8 @@ class AscendMLAMetadataBuilder:
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin,
cos=cos,
num_computed_tokens_of_pcp_dcp=
num_computed_tokens_of_pcp_dcp,
seq_mask_pcp=self.
seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
seq_mask_pcp_shape[1]],
seq_mask_dcp=self.
seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
seq_mask_dcp_shape[1]],
cp_seq_len=cp_seq_len)
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
else:
cos[:num_decode_tokens,
...] = self.cos_cache[input_positions].unsqueeze(
@@ -568,15 +545,8 @@ class AscendMLAMetadataBuilder:
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[:num_decode_tokens, ...],
cos=cos[:num_decode_tokens, ...],
num_computed_tokens_of_pcp_dcp=
num_computed_tokens_of_pcp_dcp,
seq_mask_pcp=self.
seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
seq_mask_pcp_shape[1]],
seq_mask_dcp=self.
seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
seq_mask_dcp_shape[1]],
cp_seq_len=cp_seq_len)
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
return self.metadata_cls( # type: ignore
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
@@ -1663,8 +1633,6 @@ class AscendMLAImpl(MLAAttentionImpl):
q_nope = q_nope.view(num_tokens, num_heads, -1)
q_pe = q_pe.view(num_tokens, num_heads, -1)
# use the per-pcp/dcp split of computed token counts from the scheduler to compute the actual seq_len and seq_mask
seq_mask_pcp = decode_meta.seq_mask_pcp
seq_mask_dcp = decode_meta.seq_mask_dcp
seq_len = decode_meta.cp_seq_len
common_kwargs = {
@@ -1734,9 +1702,56 @@ class AscendMLAImpl(MLAAttentionImpl):
output=attn_output,
lse=softmax_lse)
# Update out&lse
attn_out_lse_list = self._process_attn_out_lse(attn_output,
softmax_lse,
decode_meta)
attn_output = self._npu_attention_update(attn_out_lse_list)
return self._v_up_proj(attn_output)
def _npu_attention_update(
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
attn_out_split_cp = []
attn_lse_split_cp = []
for attn_out_lse in attn_out_lse_list:
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
*torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
attn_out_split_cp.append(attn_out_allgather)
attn_lse_split_cp.append(attn_lse_allgather)
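# npu_attention_update fuses the out/lse merge that the removed
# _update_out_and_lse helper did chunk-by-chunk in Python: the gathered
# partial outputs are combined according to their softmax LSE values in a
# single NPU op (see the sketch in the PR description above).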
attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
attn_out_split_cp, 0)
attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
self.kv_lora_rank)
return attn_out
def _out_lse_reshape(self, attn_out: torch.Tensor,
attn_lse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
attn_out = attn_out.contiguous().view(
attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
attn_lse = attn_lse.contiguous().view(
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
return attn_out, attn_lse
def _process_attn_out_lse(
self,
attn_output: torch.Tensor,
softmax_lse: torch.Tensor,
decode_meta: AscendMLADecodeMetadata,
) -> List[torch.Tensor]:
attn_out_lse_list = []
out_mask = decode_meta.batch_seq_mask[:, None,
None].expand_as(attn_output)
attn_output = torch.where(out_mask, 0, attn_output)
lse_mask = decode_meta.batch_seq_mask[:, None,
None].expand_as(softmax_lse)
softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
softmax_lse = softmax_lse.to(torch.float32)
attn_output = attn_output.to(torch.float32)
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
if self.dcp_size > 1:
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -1745,24 +1760,12 @@ class AscendMLAImpl(MLAAttentionImpl):
group=self.dcp_group)
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
attn_out_lse_split_on_seq = list(
if self.pcp_size > 1:
attn_out_lse = attn_out_lse_all2all.contiguous()
attn_out_lse_list = list(
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
# Update out&lse
attn_out_g = None
attn_lse_g = None
for i, attn_out_lse_l in enumerate(attn_out_lse_split_on_seq):
attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
[self.kv_lora_rank, 1],
dim=-1)
attn_out_g, attn_lse_g = self._update_out_and_lse(
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
seq_mask_dcp[:, i])
attn_output = attn_out_g
softmax_lse = attn_lse_g
if self.pcp_size > 1:
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
# AllGather out&lse within PCP group
attn_out_lse_list = [
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
@@ -1770,45 +1773,12 @@ class AscendMLAImpl(MLAAttentionImpl):
dist.all_gather(attn_out_lse_list,
attn_out_lse,
group=self.pcp_group)
# Update out&lse
attn_out_g = None
attn_lse_g = None
for i, attn_out_lse_l in enumerate(attn_out_lse_list):
attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
[self.kv_lora_rank, 1],
dim=-1)
attn_out_g, attn_lse_g = self._update_out_and_lse(
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
seq_mask_pcp[:, i])
attn_output = attn_out_g
return self._v_up_proj(attn_output)
if self.dcp_size > 1 and self.pcp_size > 1:
attn_out_lse_list_pcp_dcp = []
for s in attn_out_lse_list:
attn_out_lse_list_split = list(
torch.chunk(s, self.dcp_size, dim=1))
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
attn_out_lse_list = attn_out_lse_list_pcp_dcp
# TODO use update op to replace this
def _update_out_and_lse(
self,
out: torch.Tensor,
lse: torch.Tensor,
block_out: torch.Tensor,
block_lse: torch.Tensor,
mask: torch.Tensor = None,
):
if out is None:
out = block_out.to(torch.float32)
lse = block_lse
else:
if mask is None:
mask = torch.ones([block_out.size(0)],
dtype=torch.uint8,
device=block_out.device)
out_mask = mask[:, None, None].expand_as(block_out)
lse_mask = mask[:, None, None].expand_as(block_lse)
block_out = block_out.to(torch.float32)
out_without_update = out.clone()
lse_without_update = lse.clone()
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
# mask
out = torch.where(out_mask, out, out_without_update)
lse = torch.where(lse_mask, lse, lse_without_update)
return out, lse
return attn_out_lse_list


@@ -1716,10 +1716,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
spec_decode_metadata = None
logits_indices = torch.from_numpy(
cu_num_tokens
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
logits_indices = logits_indices.to(self.device, non_blocking=True)
if self.pcp_size * self.dcp_size > 1:
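# With PCP enabled, each request's tokens are split and padded across pcp
# ranks, so the last-token index is cu_num_tokens * pcp_size minus that
# request's pcp padding, minus 1 (reading of the expression below).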
logits_indices = torch.from_numpy(
cu_num_tokens
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
logits_indices = logits_indices.to(self.device,
non_blocking=True)
else:
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
self.device, non_blocking=True)
else:
# pcp not supported now
assert self.pcp_size == 1