[Feature] Support for cross-attention and whisper model (#5592)

### What this PR does / why we need it?
Fixes the problem reported in issue:
https://github.com/vllm-project/vllm-ascend/issues/2262

- support for cross-attention when the model is encoder-decoder
- support for whisper model

- vLLM version: v0.13.0
- vLLM main:
7157596103

Signed-off-by: gh924 <guihao2@huawei.com>
Co-authored-by: Aoxuan Chen <43376869+chenaoxuan@users.noreply.github.com>
This commit is contained in:
gh924
2026-01-11 11:38:45 +08:00
committed by GitHub
parent db12c1e2c8
commit 6880c1b383
5 changed files with 103 additions and 68 deletions

View File

@@ -21,6 +21,8 @@ import os
import pytest
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset
from tests.e2e.conftest import VllmRunner
@@ -32,6 +34,10 @@ MINICPM_MODELS = [
"OpenBMB/MiniCPM4-0.5B",
]
WHISPER_MODELS = [
"openai-mirror/whisper-large-v3-turbo",
]
@pytest.mark.parametrize("model", MINICPM_MODELS)
def test_minicpm(model) -> None:
@@ -44,3 +50,26 @@ def test_minicpm(model) -> None:
max_model_len=512,
gpu_memory_utilization=0.7) as runner:
runner.generate_greedy(example_prompts, max_tokens)
@pytest.mark.parametrize("model", WHISPER_MODELS)
def test_whisper(model) -> None:
    """Smoke-test audio transcription with a Whisper encoder-decoder model."""
    # Whisper decoder prompt: start-of-transcript, language, task and
    # no-timestamps special tokens.
    transcription_prompts = [
        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
    ]
    sample_audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
    params = SamplingParams(temperature=0.2, max_tokens=10, stop_token_ids=None)
    with VllmRunner(snapshot_download(model),
                    max_model_len=448,
                    max_num_seqs=5,
                    dtype="bfloat16",
                    block_size=128,
                    gpu_memory_utilization=0.9) as vllm_model:
        results = vllm_model.generate(prompts=transcription_prompts,
                                      audios=sample_audios,
                                      sampling_params=params)
        assert results is not None, "Generated outputs should not be None."
        assert len(results) > 0, "Generated outputs should not be empty."

View File

@@ -320,26 +320,3 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8, 64)
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_raise_error(self, mock_reshape_and_cache):
    """forward() must raise NotImplementedError for the unsupported impl.

    The patched ``torch_npu._npu_reshape_and_cache`` keeps the KV-cache
    write from touching real NPU kernels; the error is expected to be
    raised by ``self.impl_error`` before any attention math runs.
    """
    # NOTE: the mock parameter was previously named ``mock_paged_attention``,
    # which did not match the patch target above; renamed for clarity.
    query = torch.randn(10, 8 * 64)
    key = torch.randn(10, 8 * 64)
    value = torch.randn(10, 8 * 64)
    kv_cache = torch.empty(2, 5, 128, 8, 64)
    output = torch.empty_like(query)
    metadata = self.attn_metadata
    # Pure-prefill batch: one sequence of 10 tokens, no decode tokens.
    metadata.attn_mask = torch.randn(1, 1, 10, 10)
    metadata.query_lens = torch.tensor([10])
    metadata.seq_lens = torch.tensor([10])
    metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
    metadata.num_actual_tokens = 10
    metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
    metadata.num_decodes = 0
    metadata.num_prefills = 10
    layer = self.layer_no_quant
    with self.assertRaises(NotImplementedError):
        self.impl_error.forward(layer, query, key, value, kv_cache,
                                metadata, output)