[Feature] Support for cross-attention and whisper model (#5592)
### What this PR does / why we need it?
Resolves https://github.com/vllm-project/vllm-ascend/issues/2262:
- Support cross-attention when the model is an encoder-decoder model.
- Support the Whisper model.

- vLLM version: v0.13.0
- vLLM main: 7157596103
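
For context, a rough offline-transcription sketch of the workload this enables. The model name, `<|startoftranscript|>` prompt, and `max_model_len=448` follow the upstream vLLM Whisper example; the audio file and `librosa` loading are placeholders, and exact prompt handling may differ across vLLM versions:

```python
from vllm import LLM, SamplingParams
import librosa

# Load a 16 kHz mono waveform; any short clip works for a smoke test.
audio, sr = librosa.load("sample.wav", sr=16000)

# Whisper is an encoder-decoder model, so this exercises the new
# cross-attention path on Ascend.
llm = LLM(model="openai/whisper-large-v3", max_model_len=448)

outputs = llm.generate(
    {
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {"audio": (audio, sr)},
    },
    SamplingParams(temperature=0, max_tokens=200),
)
print(outputs[0].outputs[0].text)
```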
Signed-off-by: gh924 <guihao2@huawei.com>
Co-authored-by: Aoxuan Chen <43376869+chenaoxuan@users.noreply.github.com>
```diff
@@ -32,7 +32,7 @@ from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                               AttentionMetadataBuilder)
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
 
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.context_parallel.common_cp import (
```
```diff
@@ -256,6 +256,9 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
 
         slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
+        if isinstance(self.kv_cache_spec, CrossAttentionSpec):
+            seq_lens = common_attn_metadata.seq_lens
+            slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32)
         attn_state = common_attn_metadata.attn_state
 
         # Get attn_mask and swa_mask from singleton AttentionMaskBuilder
```
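The `CrossAttentionSpec` branch keeps the unsliced `seq_lens` and `slot_mapping` because the key/value side of cross-attention is sized by the encoder output, not by the decoder tokens scheduled this step. A minimal illustration with hypothetical numbers (Whisper's encoder emits 1500 positions per 30 s clip):

```python
import torch

# Hypothetical decode step for a 2-request Whisper batch (illustration only):
num_reqs, num_actual_tokens = 2, 2             # one new decoder token per request
encoder_seq_lens = torch.tensor([1500, 1500])  # KV side is sized by the encoder output

# Self-attention metadata is sliced to the scheduled decoder tokens
# (seq_lens_cpu[:num_reqs], slot_mapping[:num_actual_tokens]); cross-attention
# keeps the encoder-sized tensors so each decoder token attends over all of them.
print(encoder_seq_lens.to(torch.int32).tolist())  # [1500, 1500]
```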
```diff
@@ -502,6 +505,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
             block_size = 128
             block_table = None
             actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
+            if self.attn_type == AttentionType.ENCODER_DECODER:
+                actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens,
+                                                     dim=0).tolist()
         elif attn_metadata.attn_state == \
                 AscendAttentionState.PrefillCacheHit:
             batch_size = attn_metadata.seq_lens.shape[0]
```
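`actual_seq_lengths_kv` holds cumulative end offsets per request; in the encoder-decoder case they must be rebuilt from the encoder-side `seq_lens` instead of being reused from the query side. A small sketch with hypothetical lengths:

```python
import torch

# Hypothetical 3-request batch (illustration only): each Whisper request has
# 1500 encoder positions, so the KV end offsets become a running sum.
seq_lens = torch.tensor([1500, 1500, 1500])
actual_seq_lengths_kv = torch.cumsum(seq_lens, dim=0).tolist()
print(actual_seq_lengths_kv)  # [1500, 3000, 4500]
```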
```diff
@@ -583,7 +589,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             = self._get_fia_params(key, value, attn_metadata)
         num_tokens = attn_metadata.actual_seq_lengths_q[-1]
         query = query[:num_tokens]
-        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER:
             key = key[:num_tokens]
             value = value[:num_tokens]
         # Get workspace from cache or calculate it if not present.
```
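The added `self.attn_type != AttentionType.ENCODER_DECODER` condition is needed because `num_tokens` counts decoder tokens while cross-attention `key`/`value` are derived from the encoder output; slicing them would silently drop most of the encoder states. The same mismatch is why the cache-write hunk below also bypasses the `num_actual_tokens` slicing. A shape sketch with hypothetical values:

```python
import torch

# Hypothetical first decoder step of one Whisper request (illustration only):
num_tokens = 4                   # decoder prompt tokens scheduled this step
key = torch.randn(1500, 8, 64)   # encoder-derived cross-attention keys (1500 positions)

# key[:num_tokens] would keep only 4 of the 1500 encoder positions,
# which is why the prefill slice is skipped for ENCODER_DECODER.
print(key[:num_tokens].shape[0], key.shape[0])  # 4 1500
```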
```diff
@@ -675,23 +681,29 @@ class AscendAttentionBackendImpl(AttentionImpl):
         if self.key_cache is None:
             self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
         slots = attn_metadata.slot_mapping
+        encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER)
         if get_ascend_device_type() == AscendDeviceType.A5:
             # TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
             # Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
             # If it's necessary, the slots should be sliced.
             torch_npu.npu_scatter_pa_kv_cache(
-                key=key[:attn_metadata.num_actual_tokens],
-                value=value[:attn_metadata.num_actual_tokens].contiguous(),
+                key=key[:attn_metadata.num_actual_tokens]
+                if not encoder_decoder else key,
+                value=value[:attn_metadata.num_actual_tokens].contiguous()
+                if not encoder_decoder else value,
                 key_cache=self.key_cache,
                 value_cache=self.value_cache,
                 slot_mapping=slots)
         else:
             torch_npu._npu_reshape_and_cache(
-                key=key[:attn_metadata.num_actual_tokens],
-                value=value[:attn_metadata.num_actual_tokens],
+                key=key[:attn_metadata.num_actual_tokens]
+                if not encoder_decoder else key,
+                value=value[:attn_metadata.num_actual_tokens]
+                if not encoder_decoder else value,
                 key_cache=self.key_cache,
                 value_cache=self.value_cache,
-                slot_indices=slots[:attn_metadata.num_actual_tokens])
+                slot_indices=slots[:attn_metadata.num_actual_tokens]
+                if not encoder_decoder else slots)
         if self.is_kv_producer:
             attn_metadata.reshape_cache_event.record()
         return key, value
```
```diff
@@ -747,18 +759,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 " for AscendAttentionBackendImpl")
 
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-        attn_type = self.attn_type
-        if attn_type not in [
-                AttentionType.DECODER, AttentionType.ENCODER_ONLY
-        ]:
-            raise NotImplementedError("Encoder/Decoder cross-attention "
-                                      "is not implemented for "
-                                      "PallasAttentionBackendImpl")
         num_tokens = query.shape[0]
         if attn_metadata is None:
             return output.fill_(0)
-        key, value = self.reshape_and_cache(key, value, kv_cache,
-                                            attn_metadata)
+        if key is not None and value is not None:
+            key, value = self.reshape_and_cache(key, value, kv_cache,
+                                                attn_metadata)
         # pooling model branch
         if attn_metadata.model_runner_type == "pooling":
             attn_output = self._forward_encoder_attention(
```
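Two things happen here: the guard that rejected attention types other than DECODER/ENCODER_ONLY (with an error message referencing `PallasAttentionBackendImpl`) is removed, and the KV-cache write is wrapped in a None check, presumably because decode-only steps of an encoder-decoder model pass `key=None`/`value=None` for cross-attention and reuse the encoder KV written on the first pass. A simplified sketch of that intent (hypothetical helper, not the actual vLLM code):

```python
# Illustration only: skip the cache write when cross-attention key/value are
# not recomputed (decode steps read the encoder KV already in the cache).
def write_kv_if_present(key, value, kv_cache, attn_metadata, cache_write):
    if key is not None and value is not None:
        key, value = cache_write(key, value, kv_cache, attn_metadata)
    return key, value
```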