[Feature] Support for cross-attention and whisper model (#5592)
### What this PR does / why we need it?
To solve the problem of the
issue:https://github.com/vllm-project/vllm-ascend/issues/2262
- support for cross-attention when the model is encoder-decoder
- support for whisper model
- vLLM version: v0.13.0
- vLLM main:
7157596103
Signed-off-by: gh924 <guihao2@huawei.com>
Co-authored-by: Aoxuan Chen <43376869+chenaoxuan@users.noreply.github.com>
This commit is contained in:
@@ -320,26 +320,3 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_fused_infer_attention_score.assert_called_once()
|
||||
|
||||
assert output.shape == (10, 8, 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
def test_forward_raise_error(self, mock_paged_attention):
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.query_lens = torch.tensor([10])
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
metadata.num_decodes = 0
|
||||
metadata.num_prefills = 10
|
||||
layer = self.layer_no_quant
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.impl_error.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
|
||||
Reference in New Issue
Block a user