[Feature] Support for cross-attention and whisper model (#5592)

### What this PR does / why we need it?
To solve the problem of the
issue:https://github.com/vllm-project/vllm-ascend/issues/2262

- support for cross-attention when the model is encoder-decoder
- support for whisper model

- vLLM version: v0.13.0
- vLLM main:
7157596103

Signed-off-by: gh924 <guihao2@huawei.com>
Co-authored-by: Aoxuan Chen <43376869+chenaoxuan@users.noreply.github.com>
This commit is contained in:
gh924
2026-01-11 11:38:45 +08:00
committed by GitHub
parent db12c1e2c8
commit 6880c1b383
5 changed files with 103 additions and 68 deletions

View File

@@ -32,7 +32,7 @@ from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import (
@@ -256,6 +256,9 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
if isinstance(self.kv_cache_spec, CrossAttentionSpec):
seq_lens = common_attn_metadata.seq_lens
slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32)
attn_state = common_attn_metadata.attn_state
# Get attn_mask and swa_mask from singleton AttentionMaskBuilder
@@ -502,6 +505,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
if self.attn_type == AttentionType.ENCODER_DECODER:
actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens,
dim=0).tolist()
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.seq_lens.shape[0]
@@ -583,7 +589,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
= self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
query = query[:num_tokens]
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER:
key = key[:num_tokens]
value = value[:num_tokens]
# Get workspace from cache or calculate it if not present.
@@ -675,23 +681,29 @@ class AscendAttentionBackendImpl(AttentionImpl):
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER)
if get_ascend_device_type() == AscendDeviceType.A5:
# TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
# Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
# If it's necessary, the slots should be sliced.
torch_npu.npu_scatter_pa_kv_cache(
key=key[:attn_metadata.num_actual_tokens],
value=value[:attn_metadata.num_actual_tokens].contiguous(),
key=key[:attn_metadata.num_actual_tokens]
if not encoder_decoder else key,
value=value[:attn_metadata.num_actual_tokens].contiguous()
if not encoder_decoder else value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_mapping=slots)
else:
torch_npu._npu_reshape_and_cache(
key=key[:attn_metadata.num_actual_tokens],
value=value[:attn_metadata.num_actual_tokens],
key=key[:attn_metadata.num_actual_tokens]
if not encoder_decoder else key,
value=value[:attn_metadata.num_actual_tokens]
if not encoder_decoder else value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots[:attn_metadata.num_actual_tokens])
slot_indices=slots[:attn_metadata.num_actual_tokens]
if not encoder_decoder else slots)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
return key, value
@@ -747,18 +759,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
" for AscendAttentionBackendImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/Decoder cross-attention "
"is not implemented for "
"PallasAttentionBackendImpl")
num_tokens = query.shape[0]
if attn_metadata is None:
return output.fill_(0)
key, value = self.reshape_and_cache(key, value, kv_cache,
attn_metadata)
if key is not None and value is not None:
key, value = self.reshape_and_cache(key, value, kv_cache,
attn_metadata)
# pooling model branch
if attn_metadata.model_runner_type == "pooling":
attn_output = self._forward_encoder_attention(