diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py
index 1ffd8679..2f0e3da7 100644
--- a/vllm_ascend/ops/mm_encoder_attention.py
+++ b/vllm_ascend/ops/mm_encoder_attention.py
@@ -16,10 +16,12 @@
 #
 
 import einops
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch_npu
 from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention  # type: ignore
+from vllm.v1.attention.backends.registry import AttentionBackendEnum
 
 MIN_PAD_SIZE: int = 64  # min_size to pad weight
 MAX_PAD_SIZE: int = 128  # max_size to pad weight
@@ -65,6 +67,21 @@ class AscendMMEncoderAttention(MMEncoderAttention):
         self.enable_pad = self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
         self.scale_value = self.head_size**-0.5
 
+    @classmethod
+    def maybe_compute_seq_lens(
+        cls,
+        attn_backend: AttentionBackendEnum,
+        cu_seqlens: np.ndarray | None,
+        device: torch.device,
+    ) -> torch.Tensor | None:
+        if cu_seqlens is None:
+            return None
+
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        seq_lens = torch.from_numpy(seq_lens).to("cpu", non_blocking=True)
+
+        return seq_lens
+
     def _reshape_qkv_to_3d(
         self,
         query: torch.Tensor,
@@ -89,6 +106,21 @@ class AscendMMEncoderAttention(MMEncoderAttention):
 
         return query, key, value
 
+    def _maybe_compute_cu_seqlens(
+        self,
+        bsz: int,
+        q_len: int,
+        cu_seqlens: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if cu_seqlens is not None:
+            return cu_seqlens
+
+        # If cu_seqlens is not provided, we create a default one assuming all sequences have the same length.
+        # This is used by models such as Hunyuan-OCR, which always pass None as cu_seqlens and rely on the
+        # operator to compute it internally.
+        cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu")
+        return cu_seqlens
+
     def forward_oot(
         self,
         query: torch.Tensor,
@@ -102,10 +134,16 @@ class AscendMMEncoderAttention(MMEncoderAttention):
         kv_len = key.size(1)
         is_reshaped = query.dim() == 4
 
-        # Directly use seq_lens cpu cache to avoid d2h copy.
-        if cu_seqlens is None:
-            cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu")
-        seq_lens_cpu = torch.diff(cu_seqlens).to("cpu")
+        if sequence_lengths is not None:
+            # Use the seq_lens pre-computed before the vision blocks.
+            if sequence_lengths.device.type != "cpu":
+                sequence_lengths = sequence_lengths.to("cpu")
+            seq_lens_cpu = sequence_lengths
+        else:
+            # Convert cu_seqlens to seq_lens and move it to CPU, since FA requires CPU seq_lens.
+            # NOTE: This will considerably hurt performance.
+            cu_seqlens = self._maybe_compute_cu_seqlens(bsz, q_len, cu_seqlens)
+            seq_lens_cpu = torch.diff(cu_seqlens).to("cpu")
 
         # q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim]
         q, k, v = self._reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
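
For reviewers without NPU hardware, here is a minimal CPU-only sketch of the seq_lens bookkeeping the new helpers perform (the cumulative-sum differencing in maybe_compute_seq_lens and the equal-length default built by _maybe_compute_cu_seqlens). The input values are illustrative only, and torch_npu is not required:

```python
import numpy as np
import torch

# Illustrative cumulative sequence lengths for three sequences of
# lengths 4, 2, and 5 -- the shape of input maybe_compute_seq_lens expects.
cu_seqlens = np.array([0, 4, 6, 11], dtype=np.int32)

# Per-sequence lengths are the first differences of the cumulative sums.
seq_lens = torch.from_numpy(cu_seqlens[1:] - cu_seqlens[:-1])
assert seq_lens.tolist() == [4, 2, 5]

# Default cu_seqlens built when the caller passes None (the
# _maybe_compute_cu_seqlens path): bsz equal-length sequences of q_len tokens.
bsz, q_len = 2, 3
default_cu = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
assert default_cu.tolist() == [0, 3, 6]

# forward_oot then recovers per-sequence lengths for FA via torch.diff.
assert torch.diff(default_cu).tolist() == [3, 3]
```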