[CI] Fix broken CI (#6599)

Revert 4fb3d5e1b2: it breaks the E2E test.

- vLLM version: v0.15.0
- vLLM main: d7e17aaacd
@@ -19,15 +19,18 @@ import einops
 import torch
 import torch.nn.functional as F
 import torch_npu
 from vllm.config import MultiModalConfig
 from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention  # type: ignore

 import vllm_ascend.envs as envs_ascend

 MIN_PAD_SIZE = 64  # min_size to pad weight
 MAX_PAD_SIZE = 128  # max_size to pad weight


 class AscendMMEncoderAttention(MMEncoderAttention):

     def __init__(
         self,
         num_heads: int,
@@ -79,12 +82,13 @@ class AscendMMEncoderAttention(MMEncoderAttention):
         return query, key, value

     def forward_oot(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
+        max_seqlen: torch.Tensor
+        | None = None,  # Only used for Flash Attention
     ):
         bsz, q_len = query.size()[:2]
         kv_len = key.size(1)
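A note on shapes: query, key and value arrive as [batch, seq, heads, head_dim], and reshape_qkv_to_3d flattens batch and sequence together before the attention call. Its body is outside this diff, so the sketch below is an assumption based on the comment in the next hunk:

import torch

# Hypothetical standalone equivalent of the flattening step; the real
# reshape_qkv_to_3d is a method of AscendMMEncoderAttention whose body
# is not shown in this diff.
def reshape_qkv_to_3d(query: torch.Tensor, key: torch.Tensor,
                      value: torch.Tensor, bsz: int, q_len: int,
                      kv_len: int):
    # [b, s, head, head_dim] -> [b * s, head, head_dim]
    q = query.reshape(bsz * q_len, *query.shape[2:])
    k = key.reshape(bsz * kv_len, *key.shape[2:])
    v = value.reshape(bsz * kv_len, *value.shape[2:])
    return q, k, v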
@@ -93,7 +97,9 @@ class AscendMMEncoderAttention(MMEncoderAttention):
         # q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim]
         q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)

-        enable_pad = envs_ascend.USE_OPTIMIZED_MODEL and self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
+        enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL
+                      and self.head_size > MIN_PAD_SIZE
+                      and self.head_size < MAX_PAD_SIZE)

         if enable_pad:
             origin_shape = q.shape[-1]
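The pad path fires only when head_size lies strictly between MIN_PAD_SIZE (64) and MAX_PAD_SIZE (128): origin_shape is recorded before padding so the output can be sliced back afterwards. A minimal sketch of that pad-then-slice pattern, assuming F.pad on the last axis (the exact kernel-facing code is not in this hunk):

import torch.nn.functional as F

if enable_pad:
    origin_shape = q.shape[-1]
    # Pad the head_dim axis up to MAX_PAD_SIZE so the NPU attention
    # kernel sees an aligned head size (illustrative, not the exact code).
    q = F.pad(q, (0, MAX_PAD_SIZE - origin_shape))
    k = F.pad(k, (0, MAX_PAD_SIZE - origin_shape))
    v = F.pad(v, (0, MAX_PAD_SIZE - origin_shape))
# ... attention runs on the padded tensors, then the extra columns
# are dropped again: context_layer = context_layer[..., :origin_shape]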
@@ -108,7 +114,10 @@ class AscendMMEncoderAttention(MMEncoderAttention):
         context_layer = torch.empty_like(q)

         if cu_seqlens is None:
-            cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device)
+            cu_seqlens = torch.arange(0, (bsz + 1) * q_len,
+                                      step=q_len,
+                                      dtype=torch.int32,
+                                      device=query.device)

         cu_seqlens = torch.diff(cu_seqlens).to("cpu")
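When cu_seqlens is not supplied, every sequence shares the same length q_len, so the arange builds cumulative offsets 0, q_len, 2*q_len, ... and torch.diff recovers the per-sequence lengths. A quick worked example with hypothetical sizes bsz=2, q_len=3:

import torch

bsz, q_len = 2, 3
cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
# -> tensor([0, 3, 6], dtype=torch.int32)  (cumulative offsets)
seq_lens = torch.diff(cu_seqlens).to("cpu")
# -> tensor([3, 3], dtype=torch.int32)  (one length per sequence, on CPU)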
@@ -128,7 +137,11 @@ class AscendMMEncoderAttention(MMEncoderAttention):
             context_layer = context_layer[..., :origin_shape]

         if is_reshaped:
-            context_layer = einops.rearrange(context_layer, "(b s) h d -> b s h d", b=bsz).contiguous()
+            context_layer = einops.rearrange(context_layer,
+                                             "(b s) h d -> b s h d",
+                                             b=bsz).contiguous()
         else:
-            context_layer = einops.rearrange(context_layer, "(b s) h d -> b s (h d)", b=bsz).contiguous()
+            context_layer = einops.rearrange(context_layer,
+                                             "(b s) h d -> b s (h d)",
+                                             b=bsz).contiguous()
         return context_layer
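Both rearrange patterns split the flattened (b s) axis back apart; the second additionally merges heads and head_dim into one hidden axis. A shape check with hypothetical sizes:

import einops
import torch

bsz, q_len, heads, head_dim = 2, 3, 4, 64  # hypothetical sizes
x = torch.randn(bsz * q_len, heads, head_dim)

einops.rearrange(x, "(b s) h d -> b s h d", b=bsz).shape    # [2, 3, 4, 64]
einops.rearrange(x, "(b s) h d -> b s (h d)", b=bsz).shape  # [2, 3, 256]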