Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -55,6 +55,16 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
|
||||
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
|
||||
return RocmAttentionMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""RocmAiterUnifiedAttention supports all attention types."""
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY,
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
|
||||
|
||||
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
@@ -143,6 +153,19 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
# Handle encoder attention differently - no KV cache needed
|
||||
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return self._forward_encoder_attention(
|
||||
query[:num_actual_tokens],
|
||||
key[:num_actual_tokens],
|
||||
value[:num_actual_tokens],
|
||||
output[:num_actual_tokens],
|
||||
attn_metadata,
|
||||
layer,
|
||||
)
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
@@ -195,6 +218,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
@@ -224,6 +251,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
layer_slot_mapping: torch.Tensor,
|
||||
):
|
||||
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
flash_layout = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user