Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -86,6 +86,26 @@ class AttentionBackend(ABC):
    ) -> tuple[int, ...]:
        raise NotImplementedError

    @classmethod
    def get_kv_cache_block_dim(
        cls,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> int:
        """Discover which tensor dim is the block index, since different
        backends lay out dims differently."""
        _S = 1234567
        shape = cls.get_kv_cache_shape(
            _S,
            block_size,
            num_kv_heads,
            head_size,
            cache_dtype_str=cache_dtype_str,
        )
        return shape.index(_S)
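
The sentinel trick works because get_kv_cache_shape threads its num_blocks argument into exactly one position of the returned tuple, so probing with an improbable value and searching for it recovers the block dim without hard-coding any per-backend layout. A minimal sketch of the mechanism, using a hypothetical backend layout (_ProbeDemoBackend and its shape are illustrative, not from this diff):

class _ProbeDemoBackend:
    @classmethod
    def get_kv_cache_shape(cls, num_blocks, block_size, num_kv_heads,
                           head_size, cache_dtype_str="auto"):
        # Hypothetical layout: (K/V split, block index, block size, heads, head size).
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @classmethod
    def get_kv_cache_block_dim(cls, block_size, num_kv_heads, head_size,
                               cache_dtype_str="auto"):
        _S = 1234567  # sentinel unlikely to collide with any real dim size
        shape = cls.get_kv_cache_shape(
            _S, block_size, num_kv_heads, head_size,
            cache_dtype_str=cache_dtype_str,
        )
        return shape.index(_S)

assert _ProbeDemoBackend.get_kv_cache_block_dim(16, 8, 128) == 1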

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
@@ -301,10 +321,13 @@ class CommonAttentionMetadata:

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    key_start_loc: torch.Tensor
    """(batch_size + 1,), the start location of each request in key/value Tensor (non-cross-attention)"""
    seq_lens: torch.Tensor
    """(batch_size,), the length of each request, including both computed and new tokens"""
    seq_lens_np: np.ndarray

    num_reqs: int
    """Number of requests"""
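
The *_start_loc fields are CSR-style offset arrays: request i owns query tokens query_start_loc[i]:query_start_loc[i+1], which is why they carry batch_size + 1 entries while seq_lens carries batch_size. A toy illustration (the values are made up for the example):

import torch

query_start_loc = torch.tensor([0, 4, 5, 7])  # 3 requests: 4, 1, and 2 query tokens
num_reqs = query_start_loc.numel() - 1
for i in range(num_reqs):
    start, end = query_start_loc[i].item(), query_start_loc[i + 1].item()
    print(f"request {i}: query tokens [{start}, {end})")
# request 0: query tokens [0, 4)
# request 1: query tokens [4, 5)
# request 2: query tokens [5, 7)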
@@ -394,7 +417,9 @@ class CommonAttentionMetadata:
        return CommonAttentionMetadata(
            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
            key_start_loc=self.key_start_loc[: num_actual_reqs + 1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            seq_lens_np=self.seq_lens_np[:num_actual_reqs],
            _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
            if self._seq_lens_cpu is not None
            else None,
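
Note the deliberate off-by-one: offset arrays keep num_actual_reqs + 1 entries so the last kept request still has an end offset, while per-request arrays keep num_actual_reqs. A quick check with hypothetical values:

import torch

query_start_loc = torch.tensor([0, 4, 5, 7, 9])  # offsets for 4 requests
seq_lens = torch.tensor([10, 3, 6, 8])
num_actual_reqs = 2  # e.g. the trailing requests are padding

assert query_start_loc[: num_actual_reqs + 1].tolist() == [0, 4, 5]
assert seq_lens[:num_actual_reqs].tolist() == [10, 3]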
@@ -811,6 +836,28 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
        """MQA-style decode forward pass."""
        raise NotImplementedError

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
    ) -> None:
        if kv_cache.numel() == 0:
            return
        from vllm import _custom_ops as ops

        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )
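
For intuition, here is a simplified pure-Python sketch of what the concat_and_cache_mla op does, assuming the MLA cache layout (num_blocks, block_size, kv_lora_rank + pe_dim) and ignoring the fp8 quantization path that kv_cache_dtype/k_scale select in the real kernel (concat_and_cache_mla_reference is a name invented here):

import torch

def concat_and_cache_mla_reference(
    kv_c_normed: torch.Tensor,   # [num_tokens, kv_lora_rank]
    k_pe: torch.Tensor,          # [num_tokens, pe_dim], already squeezed
    kv_cache: torch.Tensor,      # [num_blocks, block_size, kv_lora_rank + pe_dim]
    slot_mapping: torch.Tensor,  # [num_tokens], flat slot ids
) -> None:
    block_size = kv_cache.shape[1]
    # Concatenate each token's compressed KV latent with its rotary part,
    # then scatter the rows into the paged cache at their assigned slots.
    entries = torch.cat([kv_c_normed, k_pe], dim=-1)
    for tok, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:
            continue  # negative slots conventionally mark padding tokens
        kv_cache[slot // block_size, slot % block_size] = entries[tok]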

class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
    """Sparse MLA attention implementation with only forward_mqa method.
@@ -856,6 +903,28 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
        """MQA-style decode forward pass."""
        raise NotImplementedError

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
    ) -> None:
        if kv_cache.numel() == 0:
            return
        from vllm import _custom_ops as ops

        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )

def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype.startswith("fp8")
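
So any fp8 variant of the cache dtype string counts as quantized, for example:

assert is_quantized_kv_cache("fp8")
assert is_quantized_kv_cache("fp8_e4m3")
assert not is_quantized_kv_cache("auto")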