[Feat] update op for mla (#4000)
### What this PR does / why we need it?
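1. In the mla_v1 module, use the torch_npu.npu_attention_update op to merge the attention output and LSE when prefill context parallelism (PCP) and decode context parallelism (DCP) are enabled.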
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: LookAround <lixushi@huawei.com>
This commit is contained in:
@@ -1,11 +1,10 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
|
from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
|
||||||
TypeVar)
|
Type, TypeVar)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
@@ -140,10 +139,8 @@ class AscendMLADecodeMetadata:
|
|||||||
attn_mask: Optional[torch.Tensor] = None
|
attn_mask: Optional[torch.Tensor] = None
|
||||||
sin: torch.Tensor = None
|
sin: torch.Tensor = None
|
||||||
cos: torch.Tensor = None
|
cos: torch.Tensor = None
|
||||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
|
||||||
seq_mask_pcp: torch.Tensor = None
|
|
||||||
seq_mask_dcp: torch.Tensor = None
|
|
||||||
cp_seq_len: torch.Tensor = None
|
cp_seq_len: torch.Tensor = None
|
||||||
|
batch_seq_mask: torch.Tensor = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -263,9 +260,10 @@ class AscendMLAMetadataBuilder:
|
|||||||
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||||
self.cos_cache = None
|
self.cos_cache = None
|
||||||
self.sin_cache = None
|
self.sin_cache = None
|
||||||
|
|
||||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||||
) if prefill_context_parallel_enable() else 1
|
) if prefill_context_parallel_enable() else 1
|
||||||
self.cp_rank = get_prefill_context_model_parallel_rank(
|
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||||
) if self.pcp_size > 1 else 0
|
) if self.pcp_size > 1 else 0
|
||||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||||
@@ -273,6 +271,9 @@ class AscendMLAMetadataBuilder:
|
|||||||
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
|
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
|
||||||
0)
|
0)
|
||||||
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
|
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
|
||||||
|
self.batch_seq_mask_buf = torch.empty(max_num_seqs,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=device)
|
||||||
self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
|
self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
|
||||||
self.pcp_size,
|
self.pcp_size,
|
||||||
dtype=torch.uint8,
|
dtype=torch.uint8,
|
||||||
@@ -489,36 +490,19 @@ class AscendMLAMetadataBuilder:
|
|||||||
num_computed_tokens_of_cp_dcp_array = np.array(
|
num_computed_tokens_of_cp_dcp_array = np.array(
|
||||||
num_computed_tokens_of_pcp_dcp
|
num_computed_tokens_of_pcp_dcp
|
||||||
)[:num_decodes] # [bs, pcp_size, dcp_size]
|
)[:num_decodes] # [bs, pcp_size, dcp_size]
|
||||||
seq_mask_pcp = torch.where(
|
|
||||||
torch.tensor(
|
|
||||||
num_computed_tokens_of_cp_dcp_array.sum(2)) == 0, 0,
|
|
||||||
1).to(torch.uint8)
|
|
||||||
self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0], :seq_mask_pcp.
|
|
||||||
shape[1]].copy_(seq_mask_pcp,
|
|
||||||
non_blocking=True)
|
|
||||||
seq_mask_pcp_shape = (seq_mask_pcp.shape[0],
|
|
||||||
seq_mask_pcp.shape[1])
|
|
||||||
|
|
||||||
seq_mask_dcp = torch.where(
|
|
||||||
torch.tensor(
|
|
||||||
num_computed_tokens_of_cp_dcp_array[:,
|
|
||||||
self.cp_rank, :])
|
|
||||||
== 0, 0, 1).to(torch.uint8)
|
|
||||||
self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0], :seq_mask_dcp.
|
|
||||||
shape[1]].copy_(seq_mask_dcp,
|
|
||||||
non_blocking=True)
|
|
||||||
seq_mask_dcp_shape = (seq_mask_dcp.shape[0],
|
|
||||||
seq_mask_dcp.shape[1])
|
|
||||||
|
|
||||||
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
|
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
|
||||||
self.cp_rank,
|
self.pcp_rank,
|
||||||
self.dcp_rank]
|
self.dcp_rank]
|
||||||
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
|
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
|
||||||
|
batch_seq_mask = (cp_seq_len == 0)
|
||||||
|
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
|
||||||
|
batch_seq_mask, non_blocking=True)
|
||||||
|
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.
|
||||||
|
shape[0]]
|
||||||
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
|
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
|
||||||
else:
|
else:
|
||||||
seq_mask_pcp_shape = (0, 0)
|
cp_seq_len, batch_seq_mask = None, None
|
||||||
seq_mask_dcp_shape = (0, 0)
|
|
||||||
cp_seq_len = None
|
|
||||||
|
|
||||||
# TODO: After the fullgraph supports MTP, the if branch needs to be deleted
|
# TODO: After the fullgraph supports MTP, the if branch needs to be deleted
|
||||||
assert self.cos_cache is not None
|
assert self.cos_cache is not None
|
||||||
@@ -541,15 +525,8 @@ class AscendMLAMetadataBuilder:
|
|||||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||||
sin=sin,
|
sin=sin,
|
||||||
cos=cos,
|
cos=cos,
|
||||||
num_computed_tokens_of_pcp_dcp=
|
cp_seq_len=cp_seq_len,
|
||||||
num_computed_tokens_of_pcp_dcp,
|
batch_seq_mask=batch_seq_mask)
|
||||||
seq_mask_pcp=self.
|
|
||||||
seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
|
|
||||||
seq_mask_pcp_shape[1]],
|
|
||||||
seq_mask_dcp=self.
|
|
||||||
seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
|
|
||||||
seq_mask_dcp_shape[1]],
|
|
||||||
cp_seq_len=cp_seq_len)
|
|
||||||
else:
|
else:
|
||||||
cos[:num_decode_tokens,
|
cos[:num_decode_tokens,
|
||||||
...] = self.cos_cache[input_positions].unsqueeze(
|
...] = self.cos_cache[input_positions].unsqueeze(
|
||||||
@@ -568,15 +545,8 @@ class AscendMLAMetadataBuilder:
|
|||||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||||
sin=sin[:num_decode_tokens, ...],
|
sin=sin[:num_decode_tokens, ...],
|
||||||
cos=cos[:num_decode_tokens, ...],
|
cos=cos[:num_decode_tokens, ...],
|
||||||
num_computed_tokens_of_pcp_dcp=
|
cp_seq_len=cp_seq_len,
|
||||||
num_computed_tokens_of_pcp_dcp,
|
batch_seq_mask=batch_seq_mask)
|
||||||
seq_mask_pcp=self.
|
|
||||||
seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
|
|
||||||
seq_mask_pcp_shape[1]],
|
|
||||||
seq_mask_dcp=self.
|
|
||||||
seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
|
|
||||||
seq_mask_dcp_shape[1]],
|
|
||||||
cp_seq_len=cp_seq_len)
|
|
||||||
|
|
||||||
return self.metadata_cls( # type: ignore
|
return self.metadata_cls( # type: ignore
|
||||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
||||||
@@ -1663,8 +1633,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
q_nope = q_nope.view(num_tokens, num_heads, -1)
|
q_nope = q_nope.view(num_tokens, num_heads, -1)
|
||||||
q_pe = q_pe.view(num_tokens, num_heads, -1)
|
q_pe = q_pe.view(num_tokens, num_heads, -1)
|
||||||
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
|
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
|
||||||
seq_mask_pcp = decode_meta.seq_mask_pcp
|
|
||||||
seq_mask_dcp = decode_meta.seq_mask_dcp
|
|
||||||
seq_len = decode_meta.cp_seq_len
|
seq_len = decode_meta.cp_seq_len
|
||||||
|
|
||||||
common_kwargs = {
|
common_kwargs = {
|
||||||
@@ -1734,9 +1702,56 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
output=attn_output,
|
output=attn_output,
|
||||||
lse=softmax_lse)
|
lse=softmax_lse)
|
||||||
|
|
||||||
|
# Update out&lse
|
||||||
|
attn_out_lse_list = self._process_attn_out_lse(attn_output,
|
||||||
|
softmax_lse,
|
||||||
|
decode_meta)
|
||||||
|
attn_output = self._npu_attention_update(attn_out_lse_list)
|
||||||
|
return self._v_up_proj(attn_output)
|
||||||
|
|
||||||
|
def _npu_attention_update(
|
||||||
|
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
|
||||||
|
attn_out_split_cp = []
|
||||||
|
attn_lse_split_cp = []
|
||||||
|
|
||||||
|
for attn_out_lse in attn_out_lse_list:
|
||||||
|
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
|
||||||
|
*torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
|
||||||
|
attn_out_split_cp.append(attn_out_allgather)
|
||||||
|
attn_lse_split_cp.append(attn_lse_allgather)
|
||||||
|
attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
|
||||||
|
attn_out_split_cp, 0)
|
||||||
|
attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
|
||||||
|
self.kv_lora_rank)
|
||||||
|
return attn_out
|
||||||
|
|
||||||
|
def _out_lse_reshape(self, attn_out: torch.Tensor,
|
||||||
|
attn_lse: torch.Tensor) -> torch.Tensor:
|
||||||
|
attn_out = attn_out.contiguous().view(
|
||||||
|
attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
|
||||||
|
attn_lse = attn_lse.contiguous().view(
|
||||||
|
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
|
||||||
|
return attn_out, attn_lse
|
||||||
|
|
||||||
|
def _process_attn_out_lse(
|
||||||
|
self,
|
||||||
|
attn_output: torch.Tensor,
|
||||||
|
softmax_lse: torch.Tensor,
|
||||||
|
decode_meta: AscendMLADecodeMetadata,
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
attn_out_lse_list = []
|
||||||
|
out_mask = decode_meta.batch_seq_mask[:, None,
|
||||||
|
None].expand_as(attn_output)
|
||||||
|
attn_output = torch.where(out_mask, 0, attn_output)
|
||||||
|
lse_mask = decode_meta.batch_seq_mask[:, None,
|
||||||
|
None].expand_as(softmax_lse)
|
||||||
|
softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
|
||||||
|
|
||||||
|
softmax_lse = softmax_lse.to(torch.float32)
|
||||||
|
attn_output = attn_output.to(torch.float32)
|
||||||
|
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
||||||
|
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
|
||||||
if self.dcp_size > 1:
|
if self.dcp_size > 1:
|
||||||
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
|
||||||
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
|
|
||||||
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
|
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
|
||||||
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
|
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
|
||||||
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
|
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
|
||||||
@@ -1745,24 +1760,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
group=self.dcp_group)
|
group=self.dcp_group)
|
||||||
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
|
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
|
||||||
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
|
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
|
||||||
attn_out_lse_split_on_seq = list(
|
if self.pcp_size > 1:
|
||||||
|
attn_out_lse = attn_out_lse_all2all.contiguous()
|
||||||
|
attn_out_lse_list = list(
|
||||||
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
|
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
|
||||||
# Update out&lse
|
|
||||||
attn_out_g = None
|
|
||||||
attn_lse_g = None
|
|
||||||
for i, attn_out_lse_l in enumerate(attn_out_lse_split_on_seq):
|
|
||||||
attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
|
|
||||||
[self.kv_lora_rank, 1],
|
|
||||||
dim=-1)
|
|
||||||
attn_out_g, attn_lse_g = self._update_out_and_lse(
|
|
||||||
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
|
|
||||||
seq_mask_dcp[:, i])
|
|
||||||
attn_output = attn_out_g
|
|
||||||
softmax_lse = attn_lse_g
|
|
||||||
|
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
|
||||||
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
|
|
||||||
# AllGather out&lse within PCP group
|
# AllGather out&lse within PCP group
|
||||||
attn_out_lse_list = [
|
attn_out_lse_list = [
|
||||||
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
|
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
|
||||||
@@ -1770,45 +1773,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
dist.all_gather(attn_out_lse_list,
|
dist.all_gather(attn_out_lse_list,
|
||||||
attn_out_lse,
|
attn_out_lse,
|
||||||
group=self.pcp_group)
|
group=self.pcp_group)
|
||||||
# Update out&lse
|
if self.dcp_size > 1 and self.pcp_size > 1:
|
||||||
attn_out_g = None
|
attn_out_lse_list_pcp_dcp = []
|
||||||
attn_lse_g = None
|
for s in attn_out_lse_list:
|
||||||
for i, attn_out_lse_l in enumerate(attn_out_lse_list):
|
attn_out_lse_list_split = list(
|
||||||
attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
|
torch.chunk(s, self.dcp_size, dim=1))
|
||||||
[self.kv_lora_rank, 1],
|
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
|
||||||
dim=-1)
|
attn_out_lse_list = attn_out_lse_list_pcp_dcp
|
||||||
attn_out_g, attn_lse_g = self._update_out_and_lse(
|
|
||||||
attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
|
|
||||||
seq_mask_pcp[:, i])
|
|
||||||
attn_output = attn_out_g
|
|
||||||
return self._v_up_proj(attn_output)
|
|
||||||
|
|
||||||
# TODO use update op to replace this
|
return attn_out_lse_list
|
||||||
def _update_out_and_lse(
|
|
||||||
self,
|
|
||||||
out: torch.Tensor,
|
|
||||||
lse: torch.Tensor,
|
|
||||||
block_out: torch.Tensor,
|
|
||||||
block_lse: torch.Tensor,
|
|
||||||
mask: torch.Tensor = None,
|
|
||||||
):
|
|
||||||
if out is None:
|
|
||||||
out = block_out.to(torch.float32)
|
|
||||||
lse = block_lse
|
|
||||||
else:
|
|
||||||
if mask is None:
|
|
||||||
mask = torch.ones([block_out.size(0)],
|
|
||||||
dtype=torch.uint8,
|
|
||||||
device=block_out.device)
|
|
||||||
out_mask = mask[:, None, None].expand_as(block_out)
|
|
||||||
lse_mask = mask[:, None, None].expand_as(block_lse)
|
|
||||||
block_out = block_out.to(torch.float32)
|
|
||||||
out_without_update = out.clone()
|
|
||||||
lse_without_update = lse.clone()
|
|
||||||
|
|
||||||
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
|
|
||||||
lse = lse - F.logsigmoid(lse - block_lse)
|
|
||||||
# mask
|
|
||||||
out = torch.where(out_mask, out, out_without_update)
|
|
||||||
lse = torch.where(lse_mask, lse, lse_without_update)
|
|
||||||
return out, lse
|
|
||||||
|
|||||||
@@ -1716,10 +1716,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# We will ignore the sampled tokens from the partial requests.
|
# We will ignore the sampled tokens from the partial requests.
|
||||||
# TODO: Support prompt logprobs.
|
# TODO: Support prompt logprobs.
|
||||||
spec_decode_metadata = None
|
spec_decode_metadata = None
|
||||||
logits_indices = torch.from_numpy(
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
cu_num_tokens
|
logits_indices = torch.from_numpy(
|
||||||
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
|
cu_num_tokens
|
||||||
logits_indices = logits_indices.to(self.device, non_blocking=True)
|
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
|
||||||
|
logits_indices = logits_indices.to(self.device,
|
||||||
|
non_blocking=True)
|
||||||
|
else:
|
||||||
|
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
|
||||||
|
self.device, non_blocking=True)
|
||||||
else:
|
else:
|
||||||
# pcp not supported now
|
# pcp not supported now
|
||||||
assert self.pcp_size == 1
|
assert self.pcp_size == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user