Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -182,7 +182,7 @@ class RocmAttentionBackend(AttentionBackend):
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+        return [32, 64, 80, 96, 128, 160, 192, 224, 256]
 
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
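For context, the only change in this hunk is adding 80 to the list of supported head sizes, which lets models whose per-head dimension is 80 (seen in some vision encoders) pass the backend's head-size check. A minimal sketch of how a check built on get_supported_head_sizes() could look; the exact exception type and message used by vLLM's validate_head_size may differ:

    # Hypothetical helper, for illustration only.
    def check_head_size(head_size: int, supported: list[int]) -> None:
        if head_size not in supported:
            raise ValueError(
                f"Head size {head_size} is not supported; choose one of {supported}"
            )

    check_head_size(80, [32, 64, 80, 96, 128, 160, 192, 224, 256])  # passes after this change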
@@ -205,6 +205,16 @@ class RocmAttentionBackend(AttentionBackend):
     def get_impl_cls() -> type["RocmAttentionImpl"]:
         return RocmAttentionImpl
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        """RocmAttention supports all attention types."""
+        return attn_type in (
+            AttentionType.DECODER,
+            AttentionType.ENCODER,
+            AttentionType.ENCODER_ONLY,
+            AttentionType.ENCODER_DECODER,
+        )
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
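The new supports_attn_type classmethod gives callers a way to ask, before constructing an implementation, whether this backend can serve a given attention type. A hypothetical selection helper (pick_backend below is illustrative, not a vLLM API):

    # Illustrative only: choose the first backend that accepts the attention type.
    def pick_backend(backends, attn_type):
        for backend in backends:
            if backend.supports_attn_type(attn_type):
                return backend
        raise ValueError(f"no backend supports attention type {attn_type!r}")

With this change, RocmAttentionBackend would be eligible for ENCODER_ONLY models in addition to the decoder types it already handled.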
@@ -244,6 +254,7 @@ class RocmAttentionImpl(AttentionImpl):
         kv_sharing_target_layer_name: int | None = None,
         sinks: torch.Tensor | None = None,
     ) -> None:
+        self.attn_type = attn_type
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -266,11 +277,6 @@ class RocmAttentionImpl(AttentionImpl):
 
         RocmAttentionBackend.validate_head_size(head_size)
 
-        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
-            raise NotImplementedError(
-                "Encoder self-attention is not implemented for RocmAttentionImpl"
-            )
-
         self.fp8_dtype = current_platform.fp8_dtype()
 
         self.sinks = sinks
@@ -281,6 +287,54 @@ class RocmAttentionImpl(AttentionImpl):
                 f"num_heads: {num_heads}."
             )
 
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # For encoder attention, process FP8 quantization if needed
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError(
+                "quantization is not supported for encoder attention"
+            )
+
+        # Use encoder-specific metadata for sequence information
+        query_start_loc = attn_metadata.query_start_loc
+        seq_lens = attn_metadata.seq_lens
+        max_query_len = attn_metadata.max_query_len
+
+        # Call flash attention directly on Q, K, V tensors
+        from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
+
+        context_attention_fwd(
+            q=query,
+            k=key,
+            v=value,
+            o=output,
+            b_start_loc=query_start_loc,
+            b_seq_len=seq_lens,
+            max_input_len=max_query_len,
+            is_causal=False,
+            softmax_scale=self.scale,
+            sliding_window_q=self.sliding_window[0],
+            sliding_window_k=self.sliding_window[1],
+        )
+        return output
+
     def forward(
         self,
         layer: torch.nn.Module,
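_forward_encoder_attention runs bidirectional (non-causal) variable-length attention directly over the packed Q/K/V tensors, with no KV cache involved. As a rough functional reference for what the Triton kernel computes per sequence, here is a pure-PyTorch sketch that assumes num_kv_heads == num_heads and ignores the sliding_window arguments (for clarity only, not performance):

    import torch

    def encoder_attention_reference(q, k, v, query_start_loc, scale):
        # q, k, v: [num_tokens, num_heads, head_size]; query_start_loc: [num_seqs + 1]
        out = torch.empty_like(q)
        for i in range(query_start_loc.numel() - 1):
            s, e = int(query_start_loc[i]), int(query_start_loc[i + 1])
            qi = q[s:e].transpose(0, 1)      # [heads, seq, head_size]
            ki = k[s:e].transpose(0, 1)
            vi = v[s:e].transpose(0, 1)
            attn = torch.softmax((qi @ ki.transpose(-1, -2)) * scale, dim=-1)
            out[s:e] = (attn @ vi).transpose(0, 1)
        return out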
@@ -330,6 +384,16 @@ class RocmAttentionImpl(AttentionImpl):
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 
+        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            return self._forward_encoder_attention(
+                query[:num_actual_tokens],
+                key[:num_actual_tokens],
+                value[:num_actual_tokens],
+                output[:num_actual_tokens],
+                attn_metadata,
+                layer,
+            )
+
         key_cache, value_cache = PagedAttention.split_kv_cache(
             kv_cache, self.num_kv_heads, self.head_size
         )
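In forward(), the encoder dispatch slices query/key/value/output to num_actual_tokens because the batch tensors may be padded (for example for CUDA graph capture); slicing produces views, so the kernel's writes into output[:num_actual_tokens] land in the caller's padded buffer. A tiny illustration of that in-place behavior:

    import torch

    padded = torch.zeros(8, 2, 4)       # e.g. padded to 8 tokens
    num_actual_tokens = 5
    view = padded[:num_actual_tokens]   # a view, not a copy
    view += 1.0                         # stands in for the attention kernel's writes
    assert padded[:5].eq(1.0).all() and padded[5:].eq(0.0).all()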
@@ -380,6 +444,8 @@ class RocmAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         slot_mapping: torch.Tensor,
     ):
+        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            return
         key_cache, value_cache = PagedAttention.split_kv_cache(
             kv_cache, self.num_kv_heads, self.head_size
         )
@@ -432,6 +498,8 @@ class RocmAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         layer_slot_mapping: torch.Tensor,
     ):
+        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            return
         key_cache, value_cache = PagedAttention.split_kv_cache(
             kv_cache,
             layer.num_kv_heads,  # type: ignore[attr-defined]
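Both KV-cache write helpers now return early for ENCODER and ENCODER_ONLY attention: those paths never read from a paged cache, so there is nothing to populate. ENCODER_DECODER cross-attention still falls through to PagedAttention.split_kv_cache, since the decoder re-reads the cached encoder keys/values on every step. Stated as a small predicate (illustrative naming, not part of the diff):

    # Illustrative only: which attention types should populate the paged KV cache.
    def writes_kv_cache(attn_type: str) -> bool:
        return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)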