diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py
index 85388db0..733cd888 100644
--- a/vllm_ascend/ops/mm_encoder_attention.py
+++ b/vllm_ascend/ops/mm_encoder_attention.py
@@ -21,8 +21,20 @@ import torch.nn.functional as F
 import torch_npu
 from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention  # type: ignore
 
-MIN_PAD_SIZE = 64  # min_size to pad weight
-MAX_PAD_SIZE = 128  # max_size to pad weight
+from vllm_ascend.utils import vllm_version_is
+
+MIN_PAD_SIZE: int = 64  # min_size to pad weight
+MAX_PAD_SIZE: int = 128  # max_size to pad weight
+
+# Use a seq_lens CPU cache to avoid frequent d2h copies.
+# AscendMMEncoderAttention copies cu_seqlens from NPU to CPU in every
+# forward, since the op _npu_flash_attention_unpad() requires CPU cu_seqlens
+# (otherwise it fails).
+# Thus, we use seq_lens_cpu_cache to cache this tensor: it is shared across
+# all layers but may change between forward steps. When the current
+# layer_index is 0 we update the cache; otherwise we reuse it, avoiding
+# repeated diff and copy operations, which are costly.
+seq_lens_cpu_cache: torch.Tensor = None
 
 
 class AscendMMEncoderAttention(MMEncoderAttention):
@@ -52,7 +64,13 @@ class AscendMMEncoderAttention(MMEncoderAttention):
             prefix=prefix,
         )
 
-    def reshape_qkv_to_3d(
+        if not vllm_version_is("0.15.0"):
+            self.layer_index = int("".join(filter(str.isdigit, prefix)))
+
+        self.enable_pad = self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
+        self.scale_value = self.head_size**-0.5
+
+    def _reshape_qkv_to_3d(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
@@ -88,41 +106,46 @@ class AscendMMEncoderAttention(MMEncoderAttention):
         kv_len = key.size(1)
         is_reshaped = query.dim() == 4
 
+        if vllm_version_is("0.15.0"):
+            if cu_seqlens is None:
+                cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu")
+            seq_lens_cpu = torch.diff(cu_seqlens).to("cpu")
+        else:
+            global seq_lens_cpu_cache
+            if self.layer_index == 0:
+                if cu_seqlens is None:
+                    cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu")
+                # Update the seq_lens CPU cache.
+                seq_lens_cpu_cache = torch.diff(cu_seqlens).to("cpu")
+            # Directly use the seq_lens CPU cache to avoid a d2h copy.
+            seq_lens_cpu = seq_lens_cpu_cache
+
         # q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim]
-        q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
+        q, k, v = self._reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
 
-        enable_pad = self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
-
-        if enable_pad:
+        if self.enable_pad:
             origin_shape = q.shape[-1]
             pad_len = MAX_PAD_SIZE - origin_shape
-            # Merge qkv to reduce the overhead of launching npu pad operation.
-            # [3, b*s, head, head_dim]
-            qkv = torch.stack([q, k, v], dim=0)
-            # qkv: [3, b * s, head, head_dim] -> [3, b * s, head, MAX_PAD_SIZE]
-            qkv = F.pad(qkv, (0, pad_len), mode="constant", value=0)
-            q, k, v = qkv.unbind(dim=0)
+            # [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
+            q = F.pad(q, (0, pad_len), mode="constant", value=0)
+            k = F.pad(k, (0, pad_len), mode="constant", value=0)
+            v = F.pad(v, (0, pad_len), mode="constant", value=0)
 
         context_layer = torch.empty_like(q)
 
-        if cu_seqlens is None:
-            cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device)
-
-        cu_seqlens = torch.diff(cu_seqlens).to("cpu")
-
         # operator requires pta version >= 2.5.1
         torch_npu._npu_flash_attention_unpad(
             query=q,
             key=k,
             value=v,
-            seq_len=cu_seqlens,
-            scale_value=self.head_size**-0.5,
+            seq_len=seq_lens_cpu,
+            scale_value=self.scale_value,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
             out=context_layer,
         )
 
-        if enable_pad:
+        if self.enable_pad:
             context_layer = context_layer[..., :origin_shape]
 
         if is_reshaped:
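Below is a minimal, standalone sketch (not part of the patch) of the per-layer seq_lens CPU cache pattern the diff introduces: only layer 0 performs the device-to-host copy in each forward step, and the remaining layers reuse the cached CPU tensor. The helper name seq_lens_for_layer and the toy layer loop are hypothetical; torch_npu and the real attention call are omitted.

    from typing import Optional

    import torch

    # Module-level CPU cache shared by all attention layers
    # (mirrors seq_lens_cpu_cache in the patch).
    seq_lens_cpu_cache: Optional[torch.Tensor] = None


    def seq_lens_for_layer(layer_index: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
        """Return per-sequence lengths on CPU, doing the d2h copy only
        when the first layer of a forward step runs."""
        global seq_lens_cpu_cache
        if layer_index == 0:
            # One diff + d2h copy per forward step.
            seq_lens_cpu_cache = torch.diff(cu_seqlens).to("cpu")
        # All other layers reuse the cached CPU tensor.
        return seq_lens_cpu_cache


    # Toy usage: three "layers" share one cu_seqlens tensor per step.
    cu_seqlens = torch.tensor([0, 4, 10, 16], dtype=torch.int32)
    for layer_index in range(3):
        seq_lens = seq_lens_for_layer(layer_index, cu_seqlens)
    print(seq_lens)  # tensor([4, 6, 6], dtype=torch.int32)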