Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -86,6 +86,26 @@ class AttentionBackend(ABC):
    ) -> tuple[int, ...]:
        raise NotImplementedError

    @classmethod
    def get_kv_cache_block_dim(
        cls,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> int:
        """Return the index of the block dimension in the KV-cache shape,
        since different backends lay out their cache dimensions differently."""
        _S = 1234567
        shape = cls.get_kv_cache_shape(
            _S,
            block_size,
            num_kv_heads,
            head_size,
            cache_dtype_str=cache_dtype_str,
        )
        return shape.index(_S)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
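The helper works by passing a sentinel block count into `get_kv_cache_shape` and locating it in the returned shape. A minimal sketch of the same probing trick against a hypothetical layout function (`toy_kv_cache_shape` is an assumption for illustration, not vLLM code):

```python
# Sentinel-probing sketch: find which dimension of the KV-cache shape carries
# the block index, without hard-coding any particular backend layout.
SENTINEL = 1234567


def toy_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                       head_size: int, cache_dtype_str: str = "auto") -> tuple[int, ...]:
    # Hypothetical layout with K and V stacked in front of the block dimension.
    return (2, num_blocks, block_size, num_kv_heads, head_size)


shape = toy_kv_cache_shape(SENTINEL, 16, 8, 128)
block_dim = shape.index(SENTINEL)
print(block_dim)  # 1 -> the block index lives in dimension 1 for this layout
```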
@@ -301,10 +321,13 @@ class CommonAttentionMetadata:
    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""
    key_start_loc: torch.Tensor
    """(batch_size + 1,), the start location of each request in the key/value Tensor (non-cross-attention)"""
    seq_lens: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""
    seq_lens_np: np.ndarray
    """(batch_size,), seq_lens as a NumPy array"""
    num_reqs: int
    """Number of requests"""
@@ -394,7 +417,9 @@ class CommonAttentionMetadata:
        return CommonAttentionMetadata(
            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
            key_start_loc=self.key_start_loc[: num_actual_reqs + 1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            seq_lens_np=self.seq_lens_np[:num_actual_reqs],
            _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
            if self._seq_lens_cpu is not None
            else None,
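The truncation keeps `num_actual_reqs + 1` entries of the start-loc arrays (they carry one extra prefix-sum entry) but only `num_actual_reqs` entries of the per-request arrays. A toy illustration of the off-by-one:

```python
import torch

query_start_loc = torch.tensor([0, 5, 8, 16, 20])  # padded to 4 requests
seq_lens = torch.tensor([12, 7, 30, 4])

num_actual_reqs = 2
print(query_start_loc[: num_actual_reqs + 1])  # tensor([0, 5, 8])
print(seq_lens[:num_actual_reqs])              # tensor([12,  7])
```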
@@ -811,6 +836,28 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""MQA-style decode forward pass."""
raise NotImplementedError
def do_kv_cache_update(
self,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
) -> None:
if kv_cache.numel() == 0:
return
from vllm import _custom_ops as ops
ops.concat_and_cache_mla(
kv_c_normed,
k_pe.squeeze(1),
kv_cache,
slot_mapping.flatten(),
kv_cache_dtype=kv_cache_dtype,
scale=k_scale,
)
class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""Sparse MLA attention implementation with only forward_mqa method.
@@ -856,6 +903,28 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""MQA-style decode forward pass."""
raise NotImplementedError
def do_kv_cache_update(
self,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
) -> None:
if kv_cache.numel() == 0:
return
from vllm import _custom_ops as ops
ops.concat_and_cache_mla(
kv_c_normed,
k_pe.squeeze(1),
kv_cache,
slot_mapping.flatten(),
kv_cache_dtype=kv_cache_dtype,
scale=k_scale,
)
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.startswith("fp8")
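Usage follows directly from the `startswith("fp8")` check:

```python
assert is_quantized_kv_cache("fp8")
assert is_quantized_kv_cache("fp8_e5m2")
assert not is_quantized_kv_cache("auto")
```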