[Feat] update op for mla (#4000)
### What this PR does / why we need it?
1、in mla_v1 module, add torch_npu.npu_attention_update op when pcp and dcp
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: LookAround <lixushi@huawei.com>
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
|
||||
TypeVar)
|
||||
from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
|
||||
Type, TypeVar)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
@@ -140,10 +139,8 @@ class AscendMLADecodeMetadata:
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
sin: torch.Tensor = None
|
||||
cos: torch.Tensor = None
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
seq_mask_pcp: torch.Tensor = None
|
||||
seq_mask_dcp: torch.Tensor = None
|
||||
cp_seq_len: torch.Tensor = None
|
||||
batch_seq_mask: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -263,9 +260,10 @@ class AscendMLAMetadataBuilder:
|
||||
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||
self.cos_cache = None
|
||||
self.sin_cache = None
|
||||
|
||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||
) if prefill_context_parallel_enable() else 1
|
||||
self.cp_rank = get_prefill_context_model_parallel_rank(
|
||||
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||
) if self.pcp_size > 1 else 0
|
||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||
@@ -273,6 +271,9 @@ class AscendMLAMetadataBuilder:
|
||||
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
|
||||
0)
|
||||
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
|
||||
self.batch_seq_mask_buf = torch.empty(max_num_seqs,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
|
||||
self.pcp_size,
|
||||
dtype=torch.uint8,
|
||||
@@ -489,36 +490,19 @@ class AscendMLAMetadataBuilder:
|
||||
num_computed_tokens_of_cp_dcp_array = np.array(
|
||||
num_computed_tokens_of_pcp_dcp
|
||||
)[:num_decodes] # [bs, pcp_size, dcp_size]
|
||||
seq_mask_pcp = torch.where(
|
||||
torch.tensor(
|
||||
num_computed_tokens_of_cp_dcp_array.sum(2)) == 0, 0,
|
||||
1).to(torch.uint8)
|
||||
self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0], :seq_mask_pcp.
|
||||
shape[1]].copy_(seq_mask_pcp,
|
||||
non_blocking=True)
|
||||
seq_mask_pcp_shape = (seq_mask_pcp.shape[0],
|
||||
seq_mask_pcp.shape[1])
|
||||
|
||||
seq_mask_dcp = torch.where(
|
||||
torch.tensor(
|
||||
num_computed_tokens_of_cp_dcp_array[:,
|
||||
self.cp_rank, :])
|
||||
== 0, 0, 1).to(torch.uint8)
|
||||
self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0], :seq_mask_dcp.
|
||||
shape[1]].copy_(seq_mask_dcp,
|
||||
non_blocking=True)
|
||||
seq_mask_dcp_shape = (seq_mask_dcp.shape[0],
|
||||
seq_mask_dcp.shape[1])
|
||||
|
||||
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
|
||||
self.cp_rank,
|
||||
self.pcp_rank,
|
||||
self.dcp_rank]
|
||||
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
|
||||
batch_seq_mask = (cp_seq_len == 0)
|
||||
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
|
||||
batch_seq_mask, non_blocking=True)
|
||||
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.
|
||||
shape[0]]
|
||||
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
|
||||
else:
|
||||
seq_mask_pcp_shape = (0, 0)
|
||||
seq_mask_dcp_shape = (0, 0)
|
||||
cp_seq_len = None
|
||||
cp_seq_len, batch_seq_mask = None, None
|
||||
|
||||
# TODO: After the fullgraph supports MTP, the if branch needs to deleted
|
||||
assert self.cos_cache is not None
|
||||
@@ -541,15 +525,8 @@ class AscendMLAMetadataBuilder:
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
num_computed_tokens_of_pcp_dcp=
|
||||
num_computed_tokens_of_pcp_dcp,
|
||||
seq_mask_pcp=self.
|
||||
seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
|
||||
seq_mask_pcp_shape[1]],
|
||||
seq_mask_dcp=self.
|
||||
seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
|
||||
seq_mask_dcp_shape[1]],
|
||||
cp_seq_len=cp_seq_len)
|
||||
cp_seq_len=cp_seq_len,
|
||||
batch_seq_mask=batch_seq_mask)
|
||||
else:
|
||||
cos[:num_decode_tokens,
|
||||
...] = self.cos_cache[input_positions].unsqueeze(
|
||||
@@ -568,15 +545,8 @@ class AscendMLAMetadataBuilder:
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin[:num_decode_tokens, ...],
|
||||
cos=cos[:num_decode_tokens, ...],
|
||||
num_computed_tokens_of_pcp_dcp=
|
||||
num_computed_tokens_of_pcp_dcp,
|
||||
seq_mask_pcp=self.
|
||||
seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
|
||||
seq_mask_pcp_shape[1]],
|
||||
seq_mask_dcp=self.
|
||||
seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
|
||||
seq_mask_dcp_shape[1]],
|
||||
cp_seq_len=cp_seq_len)
|
||||
cp_seq_len=cp_seq_len,
|
||||
batch_seq_mask=batch_seq_mask)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
||||
@@ -1663,8 +1633,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
q_nope = q_nope.view(num_tokens, num_heads, -1)
|
||||
q_pe = q_pe.view(num_tokens, num_heads, -1)
|
||||
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
|
||||
seq_mask_pcp = decode_meta.seq_mask_pcp
|
||||
seq_mask_dcp = decode_meta.seq_mask_dcp
|
||||
seq_len = decode_meta.cp_seq_len
|
||||
|
||||
common_kwargs = {
|
||||
@@ -1734,9 +1702,56 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
output=attn_output,
|
||||
lse=softmax_lse)
|
||||
|
||||
# Update out&lse
|
||||
attn_out_lse_list = self._process_attn_out_lse(attn_output,
|
||||
softmax_lse,
|
||||
decode_meta)
|
||||
attn_output = self._npu_attention_update(attn_out_lse_list)
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
def _npu_attention_update(
|
||||
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
|
||||
attn_out_split_cp = []
|
||||
attn_lse_split_cp = []
|
||||
|
||||
for attn_out_lse in attn_out_lse_list:
|
||||
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
|
||||
*torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
|
||||
attn_out_split_cp.append(attn_out_allgather)
|
||||
attn_lse_split_cp.append(attn_lse_allgather)
|
||||
attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
|
||||
attn_out_split_cp, 0)
|
||||
attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
|
||||
self.kv_lora_rank)
|
||||
return attn_out
|
||||
|
||||
def _out_lse_reshape(self, attn_out: torch.Tensor,
|
||||
attn_lse: torch.Tensor) -> torch.Tensor:
|
||||
attn_out = attn_out.contiguous().view(
|
||||
attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
|
||||
attn_lse = attn_lse.contiguous().view(
|
||||
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
|
||||
return attn_out, attn_lse
|
||||
|
||||
def _process_attn_out_lse(
|
||||
self,
|
||||
attn_output: torch.Tensor,
|
||||
softmax_lse: torch.Tensor,
|
||||
decode_meta: AscendMLADecodeMetadata,
|
||||
) -> List[torch.Tensor]:
|
||||
attn_out_lse_list = []
|
||||
out_mask = decode_meta.batch_seq_mask[:, None,
|
||||
None].expand_as(attn_output)
|
||||
attn_output = torch.where(out_mask, 0, attn_output)
|
||||
lse_mask = decode_meta.batch_seq_mask[:, None,
|
||||
None].expand_as(softmax_lse)
|
||||
softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
|
||||
|
||||
softmax_lse = softmax_lse.to(torch.float32)
|
||||
attn_output = attn_output.to(torch.float32)
|
||||
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
||||
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
|
||||
if self.dcp_size > 1:
|
||||
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
||||
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
|
||||
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
|
||||
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
|
||||
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
|
||||
@@ -1745,24 +1760,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
group=self.dcp_group)
|
||||
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
|
||||
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
|
||||
attn_out_lse_split_on_seq = list(
|
||||
if self.pcp_size > 1:
|
||||
attn_out_lse = attn_out_lse_all2all.contiguous()
|
||||
attn_out_lse_list = list(
|
||||
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
|
||||
# Update out&lse
|
||||
attn_out_g = None
|
||||
attn_lse_g = None
|
||||
for i, attn_out_lse_l in enumerate(attn_out_lse_split_on_seq):
|
||||
attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
|
||||
[self.kv_lora_rank, 1],
|
||||
dim=-1)
|
||||
attn_out_g, attn_lse_g = self._update_out_and_lse(
|
||||
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
|
||||
seq_mask_dcp[:, i])
|
||||
attn_output = attn_out_g
|
||||
softmax_lse = attn_lse_g
|
||||
|
||||
if self.pcp_size > 1:
|
||||
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
||||
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
|
||||
# AllGather out&lse within PCP group
|
||||
attn_out_lse_list = [
|
||||
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
|
||||
@@ -1770,45 +1773,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
dist.all_gather(attn_out_lse_list,
|
||||
attn_out_lse,
|
||||
group=self.pcp_group)
|
||||
# Update out&lse
|
||||
attn_out_g = None
|
||||
attn_lse_g = None
|
||||
for i, attn_out_lse_l in enumerate(attn_out_lse_list):
|
||||
attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
|
||||
[self.kv_lora_rank, 1],
|
||||
dim=-1)
|
||||
attn_out_g, attn_lse_g = self._update_out_and_lse(
|
||||
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
|
||||
seq_mask_pcp[:, i])
|
||||
attn_output = attn_out_g
|
||||
return self._v_up_proj(attn_output)
|
||||
if self.dcp_size > 1 and self.pcp_size > 1:
|
||||
attn_out_lse_list_pcp_dcp = []
|
||||
for s in attn_out_lse_list:
|
||||
attn_out_lse_list_split = list(
|
||||
torch.chunk(s, self.dcp_size, dim=1))
|
||||
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
|
||||
attn_out_lse_list = attn_out_lse_list_pcp_dcp
|
||||
|
||||
# TODO use update op to replace this
|
||||
def _update_out_and_lse(
|
||||
self,
|
||||
out: torch.Tensor,
|
||||
lse: torch.Tensor,
|
||||
block_out: torch.Tensor,
|
||||
block_lse: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
):
|
||||
if out is None:
|
||||
out = block_out.to(torch.float32)
|
||||
lse = block_lse
|
||||
else:
|
||||
if mask is None:
|
||||
mask = torch.ones([block_out.size(0)],
|
||||
dtype=torch.uint8,
|
||||
device=block_out.device)
|
||||
out_mask = mask[:, None, None].expand_as(block_out)
|
||||
lse_mask = mask[:, None, None].expand_as(block_lse)
|
||||
block_out = block_out.to(torch.float32)
|
||||
out_without_update = out.clone()
|
||||
lse_without_update = lse.clone()
|
||||
|
||||
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
|
||||
lse = lse - F.logsigmoid(lse - block_lse)
|
||||
# mask
|
||||
out = torch.where(out_mask, out, out_without_update)
|
||||
lse = torch.where(lse_mask, lse, lse_without_update)
|
||||
return out, lse
|
||||
return attn_out_lse_list
|
||||
|
||||
@@ -1716,10 +1716,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# We will ignore the sampled tokens from the partial requests.
|
||||
# TODO: Support prompt logprobs.
|
||||
spec_decode_metadata = None
|
||||
logits_indices = torch.from_numpy(
|
||||
cu_num_tokens
|
||||
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
|
||||
logits_indices = logits_indices.to(self.device, non_blocking=True)
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
logits_indices = torch.from_numpy(
|
||||
cu_num_tokens
|
||||
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
|
||||
logits_indices = logits_indices.to(self.device,
|
||||
non_blocking=True)
|
||||
else:
|
||||
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
|
||||
self.device, non_blocking=True)
|
||||
else:
|
||||
# pcp not supported now
|
||||
assert self.pcp_size == 1
|
||||
|
||||
Reference in New Issue
Block a user