diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 9e7e3c32..f2717851 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1,11 +1,10 @@
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
-                    TypeVar)
+from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
+                    Type, TypeVar)
 
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.nn.functional as F
 import torch_npu
 from torch import nn
 from vllm.attention.backends.abstract import (AttentionBackend,
@@ -140,10 +139,8 @@ class AscendMLADecodeMetadata:
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
-    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
-    seq_mask_pcp: torch.Tensor = None
-    seq_mask_dcp: torch.Tensor = None
     cp_seq_len: torch.Tensor = None
+    batch_seq_mask: torch.Tensor = None
 
 
 @dataclass
@@ -263,9 +260,10 @@ class AscendMLAMetadataBuilder:
         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
         self.cos_cache = None
         self.sin_cache = None
+
         self.pcp_size = get_prefill_context_model_parallel_world_size(
         ) if prefill_context_parallel_enable() else 1
-        self.cp_rank = get_prefill_context_model_parallel_rank(
+        self.pcp_rank = get_prefill_context_model_parallel_rank(
         ) if self.pcp_size > 1 else 0
         self.dcp_size = get_decode_context_model_parallel_world_size()
         self.dcp_rank = get_decode_context_model_parallel_rank(
@@ -273,6 +271,9 @@ class AscendMLAMetadataBuilder:
         decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
                                       0)
         max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
+        self.batch_seq_mask_buf = torch.empty(max_num_seqs,
+                                              dtype=torch.uint8,
+                                              device=device)
         self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
                                             self.pcp_size,
                                             dtype=torch.uint8,
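
Note: the builder now keeps one persistent `batch_seq_mask_buf` on device and refreshes it with an async copy each build step, instead of materializing the two per-rank `seq_mask_pcp`/`seq_mask_dcp` tensors. A minimal, standalone sketch of this buffer-reuse pattern (the sizes and the `"cpu"` device are illustrative assumptions, not the real config):

    import torch

    max_num_seqs = 8  # hypothetical; comes from scheduler_config in the real code
    # Allocated once at init; "cpu" stands in for the NPU device here.
    batch_seq_mask_buf = torch.empty(max_num_seqs, dtype=torch.uint8, device="cpu")

    # Per step: derive the mask on host, async-copy it into a view of the buffer.
    cp_seq_len = torch.tensor([0, 5, 3, 0], dtype=torch.int32)
    mask = cp_seq_len == 0  # bool mask; copy_ casts it into the uint8 buffer
    batch_seq_mask_buf[:mask.shape[0]].copy_(mask, non_blocking=True)
    batch_seq_mask = batch_seq_mask_buf[:mask.shape[0]]  # stable storage across steps
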
@@ -489,36 +490,19 @@ class AscendMLAMetadataBuilder:
             num_computed_tokens_of_cp_dcp_array = np.array(
                 num_computed_tokens_of_pcp_dcp
             )[:num_decodes]  # [bs, pcp_size, dcp_size]
-            seq_mask_pcp = torch.where(
-                torch.tensor(
-                    num_computed_tokens_of_cp_dcp_array.sum(2)) == 0, 0,
-                1).to(torch.uint8)
-            self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0], :seq_mask_pcp.
-                                  shape[1]].copy_(seq_mask_pcp,
-                                                  non_blocking=True)
-            seq_mask_pcp_shape = (seq_mask_pcp.shape[0],
-                                  seq_mask_pcp.shape[1])
-
-            seq_mask_dcp = torch.where(
-                torch.tensor(
-                    num_computed_tokens_of_cp_dcp_array[:,
-                                                        self.cp_rank, :])
-                == 0, 0, 1).to(torch.uint8)
-            self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0], :seq_mask_dcp.
-                                  shape[1]].copy_(seq_mask_dcp,
-                                                  non_blocking=True)
-            seq_mask_dcp_shape = (seq_mask_dcp.shape[0],
-                                  seq_mask_dcp.shape[1])
 
             cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
-                                                             self.cp_rank,
+                                                             self.pcp_rank,
                                                              self.dcp_rank]
             cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
+            batch_seq_mask = (cp_seq_len == 0)
+            self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+                batch_seq_mask, non_blocking=True)
+            batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.
+                                                     shape[0]]
             cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
         else:
-            seq_mask_pcp_shape = (0, 0)
-            seq_mask_dcp_shape = (0, 0)
-            cp_seq_len = None
+            cp_seq_len, batch_seq_mask = None, None
 
         # TODO: After the fullgraph supports MTP, the if branch needs to deleted
         assert self.cos_cache is not None
@@ -541,15 +525,8 @@ class AscendMLAMetadataBuilder:
                     actual_seq_lengths_q=actual_seq_lengths_q,
                     sin=sin,
                     cos=cos,
-                    num_computed_tokens_of_pcp_dcp=
-                    num_computed_tokens_of_pcp_dcp,
-                    seq_mask_pcp=self.
-                    seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
-                                     seq_mask_pcp_shape[1]],
-                    seq_mask_dcp=self.
-                    seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
-                                     seq_mask_dcp_shape[1]],
-                    cp_seq_len=cp_seq_len)
+                    cp_seq_len=cp_seq_len,
+                    batch_seq_mask=batch_seq_mask)
             else:
                 cos[:num_decode_tokens,
                     ...] = self.cos_cache[input_positions].unsqueeze(
@@ -568,15 +545,8 @@ class AscendMLAMetadataBuilder:
                     actual_seq_lengths_q=actual_seq_lengths_q,
                     sin=sin[:num_decode_tokens, ...],
                     cos=cos[:num_decode_tokens, ...],
-                    num_computed_tokens_of_pcp_dcp=
-                    num_computed_tokens_of_pcp_dcp,
-                    seq_mask_pcp=self.
-                    seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
-                                     seq_mask_pcp_shape[1]],
-                    seq_mask_dcp=self.
-                    seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
-                                     seq_mask_dcp_shape[1]],
-                    cp_seq_len=cp_seq_len)
+                    cp_seq_len=cp_seq_len,
+                    batch_seq_mask=batch_seq_mask)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
@@ -1663,8 +1633,6 @@ class AscendMLAImpl(MLAAttentionImpl):
         q_nope = q_nope.view(num_tokens, num_heads, -1)
         q_pe = q_pe.view(num_tokens, num_heads, -1)
         # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
-        seq_mask_pcp = decode_meta.seq_mask_pcp
-        seq_mask_dcp = decode_meta.seq_mask_dcp
         seq_len = decode_meta.cp_seq_len
 
         common_kwargs = {
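
Note: with the per-(pcp_rank, dcp_rank) split, a decode request may have zero computed tokens on this rank. The builder clamps such lengths to 1 so the attention kernel still receives a valid seq_len, and records the affected requests in `batch_seq_mask` so their partial results are discarded during the cross-rank merge. A small sketch of that derivation, with made-up lengths:

    import torch

    # Hypothetical per-request token counts held by this rank.
    cp_seq_len = torch.tensor([0, 7, 2, 0], dtype=torch.int32)

    batch_seq_mask = cp_seq_len == 0                           # [True, False, False, True]
    cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)   # dummy length 1
    # Later, _process_attn_out_lse zeroes the masked outputs and sets their
    # lse to -inf, so the merge gives them zero weight.
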
@@ -1734,9 +1702,56 @@ class AscendMLAImpl(MLAAttentionImpl):
                 output=attn_output,
                 lse=softmax_lse)
 
+        # Update out&lse
+        attn_out_lse_list = self._process_attn_out_lse(attn_output,
+                                                       softmax_lse,
+                                                       decode_meta)
+        attn_output = self._npu_attention_update(attn_out_lse_list)
+        return self._v_up_proj(attn_output)
+
+    def _npu_attention_update(
+            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
+        attn_out_split_cp = []
+        attn_lse_split_cp = []
+
+        for attn_out_lse in attn_out_lse_list:
+            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
+                *torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
+            attn_out_split_cp.append(attn_out_allgather)
+            attn_lse_split_cp.append(attn_lse_allgather)
+        attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
+                                                     attn_out_split_cp, 0)
+        attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
+                                 self.kv_lora_rank)
+        return attn_out
+
+    def _out_lse_reshape(
+            self, attn_out: torch.Tensor,
+            attn_lse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        attn_out = attn_out.contiguous().view(
+            attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
+        attn_lse = attn_lse.contiguous().view(
+            attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
+        return attn_out, attn_lse
+
+    def _process_attn_out_lse(
+        self,
+        attn_output: torch.Tensor,
+        softmax_lse: torch.Tensor,
+        decode_meta: AscendMLADecodeMetadata,
+    ) -> List[torch.Tensor]:
+        attn_out_lse_list = []
+        out_mask = decode_meta.batch_seq_mask[:, None,
+                                              None].expand_as(attn_output)
+        attn_output = torch.where(out_mask, 0, attn_output)
+        lse_mask = decode_meta.batch_seq_mask[:, None,
+                                              None].expand_as(softmax_lse)
+        softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
+
+        softmax_lse = softmax_lse.to(torch.float32)
+        attn_output = attn_output.to(torch.float32)
+        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+        attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
         if self.dcp_size > 1:
-            # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-            attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
             # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
             attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
             attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -1745,24 +1760,12 @@ class AscendMLAImpl(MLAAttentionImpl):
                                    group=self.dcp_group)
             # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
-            attn_out_lse_split_on_seq = list(
+            if self.pcp_size > 1:
+                attn_out_lse = attn_out_lse_all2all.contiguous()
+            attn_out_lse_list = list(
                 torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
-            # Update out&lse
-            attn_out_g = None
-            attn_lse_g = None
-            for i, attn_out_lse_l in enumerate(attn_out_lse_split_on_seq):
-                attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
-                                                     [self.kv_lora_rank, 1],
-                                                     dim=-1)
-                attn_out_g, attn_lse_g = self._update_out_and_lse(
-                    attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
-                    seq_mask_dcp[:, i])
-            attn_output = attn_out_g
-            softmax_lse = attn_lse_g
 
         if self.pcp_size > 1:
-            # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-            attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
             # AllGather out&lse within PCP group
             attn_out_lse_list = [
                 torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
@@ -1770,45 +1773,12 @@ class AscendMLAImpl(MLAAttentionImpl):
             dist.all_gather(attn_out_lse_list,
                             attn_out_lse,
                             group=self.pcp_group)
-            # Update out&lse
-            attn_out_g = None
-            attn_lse_g = None
-            for i, attn_out_lse_l in enumerate(attn_out_lse_list):
-                attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
-                                                     [self.kv_lora_rank, 1],
-                                                     dim=-1)
-                attn_out_g, attn_lse_g = self._update_out_and_lse(
-                    attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
-                    seq_mask_pcp[:, i])
-            attn_output = attn_out_g
-        return self._v_up_proj(attn_output)
+            if self.dcp_size > 1 and self.pcp_size > 1:
+                attn_out_lse_list_pcp_dcp = []
+                for s in attn_out_lse_list:
+                    attn_out_lse_list_split = list(
+                        torch.chunk(s, self.dcp_size, dim=1))
+                    attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
+                attn_out_lse_list = attn_out_lse_list_pcp_dcp
 
-    # TODO use update op to replace this
-    def _update_out_and_lse(
-        self,
-        out: torch.Tensor,
-        lse: torch.Tensor,
-        block_out: torch.Tensor,
-        block_lse: torch.Tensor,
-        mask: torch.Tensor = None,
-    ):
-        if out is None:
-            out = block_out.to(torch.float32)
-            lse = block_lse
-        else:
-            if mask is None:
-                mask = torch.ones([block_out.size(0)],
-                                  dtype=torch.uint8,
-                                  device=block_out.device)
-            out_mask = mask[:, None, None].expand_as(block_out)
-            lse_mask = mask[:, None, None].expand_as(block_lse)
-            block_out = block_out.to(torch.float32)
-            out_without_update = out.clone()
-            lse_without_update = lse.clone()
-
-            out = out - F.sigmoid(block_lse - lse) * (out - block_out)
-            lse = lse - F.logsigmoid(lse - block_lse)
-            # mask
-            out = torch.where(out_mask, out, out_without_update)
-            lse = torch.where(lse_mask, lse, lse_without_update)
-        return out, lse
+        return attn_out_lse_list
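
Note: the removed `_update_out_and_lse` performed the pairwise log-sum-exp merge in eager PyTorch; the rewrite collects all partial (out, lse) pairs and hands them to the fused `torch_npu.npu_attention_update` in a single call. For reference, a plain-PyTorch sketch of the reduction the fused op replaces, mirroring the removed helper without its masking branches (list layout and shapes are illustrative):

    import torch
    import torch.nn.functional as F

    def merge_partial_attention(outs, lses):
        # Log-sum-exp merge of per-rank partial attention outputs.
        out, lse = outs[0].to(torch.float32), lses[0].to(torch.float32)
        for block_out, block_lse in zip(outs[1:], lses[1:]):
            block_out = block_out.to(torch.float32)
            # sigmoid(block_lse - lse) is this block's share of the softmax mass ...
            out = out - torch.sigmoid(block_lse - lse) * (out - block_out)
            # ... and the running lse becomes logaddexp(lse, block_lse).
            lse = lse - F.logsigmoid(lse - block_lse)
        return out, lse

A masked entry whose lse was set to -inf receives sigmoid(-inf) = 0 weight in this merge, which is why `_process_attn_out_lse` writes -inf (and zeroed outputs) for requests with no tokens on the local rank.
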
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 66868dd6..1cdbd976 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1716,10 +1716,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
             spec_decode_metadata = None
-            logits_indices = torch.from_numpy(
-                cu_num_tokens
-            ) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
-            logits_indices = logits_indices.to(self.device, non_blocking=True)
+            if self.pcp_size * self.dcp_size > 1:
+                logits_indices = torch.from_numpy(
+                    cu_num_tokens
+                ) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
+                logits_indices = logits_indices.to(self.device,
+                                                   non_blocking=True)
+            else:
+                logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
+                    self.device, non_blocking=True)
         else:
             # pcp not supported now
             assert self.pcp_size == 1
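
Note: sampling reads hidden states at each request's last scheduled token. When pcp/dcp is active, the formula rescales the cumulative counts by `pcp_size` and corrects for per-request padding; the new `else` branch restores plain `cu_num_tokens - 1` indexing when neither form of context parallelism is enabled. A worked example with hypothetical values:

    import numpy as np

    cu_num_tokens = np.array([4, 6, 9])   # cumulative scheduled tokens, 3 requests
    pcp_size = 2
    num_pcp_pads = np.array([1, 0, 1])    # pads appended per request for PCP

    with_cp = cu_num_tokens * pcp_size - num_pcp_pads - 1  # -> [ 6 11 16]
    without_cp = cu_num_tokens - 1                         # -> [3 5 8]
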