clean 0.15.0 support (#6852)

Clean up vllm 0.15.0 related code

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2026-02-28 09:20:57 +08:00
committed by GitHub
parent 84b00695f8
commit 3d563292f3
8 changed files with 17 additions and 36 deletions

View File

@@ -21,8 +21,6 @@ import torch.nn.functional as F
import torch_npu
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention # type: ignore
from vllm_ascend.utils import vllm_version_is
MIN_PAD_SIZE: int = 64 # min_size to pad weight
MAX_PAD_SIZE: int = 128 # max_size to pad weight
@@ -64,9 +62,7 @@ class AscendMMEncoderAttention(MMEncoderAttention):
prefix=prefix,
)
if not vllm_version_is("0.15.0"):
self.layer_index = int("".join(filter(str.isdigit, prefix)))
self.layer_index = int("".join(filter(str.isdigit, prefix)))
self.enable_pad = self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
self.scale_value = self.head_size**-0.5
@@ -106,19 +102,13 @@ class AscendMMEncoderAttention(MMEncoderAttention):
kv_len = key.size(1)
is_reshaped = query.dim() == 4
if vllm_version_is("0.15.0"):
# Directly use seq_lens cpu cache to avoid d2h copy.
global seq_lens_cpu_cache
if self.layer_index == 0:
if cu_seqlens is None:
cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu")
seq_lens_cpu = torch.diff(cu_seqlens).to("cpu")
else:
global seq_lens_cpu_cache
if self.layer_index == 0:
if cu_seqlens is None:
cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu")
# Update seq_lens cpu cache.
seq_lens_cpu_cache = torch.diff(cu_seqlens).to("cpu")
# Directly use seq_lens cpu cache to avoid d2h copy.
seq_lens_cpu = seq_lens_cpu_cache
# Update seq_lens cpu cache.
seq_lens_cpu_cache = torch.diff(cu_seqlens).to("cpu")
# q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim]
q, k, v = self._reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
@@ -138,7 +128,7 @@ class AscendMMEncoderAttention(MMEncoderAttention):
query=q,
key=k,
value=v,
seq_len=seq_lens_cpu,
seq_len=seq_lens_cpu_cache,
scale_value=self.scale_value,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,