Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -418,6 +418,125 @@ def should_use_deepgemm_for_fp8_linear(
)
def fp8_mqa_logits_torch(
q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor],
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging (CUDA fallback).
This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
kv_fp8, scale = kv
seq_len_kv = kv_fp8.shape[0]
k = kv_fp8.to(torch.bfloat16)
q = q.to(torch.bfloat16)
mask_lo = (
torch.arange(0, seq_len_kv, device=q.device)[None, :] >= cu_seqlen_ks[:, None]
)
mask_hi = (
torch.arange(0, seq_len_kv, device=q.device)[None, :] < cu_seqlen_ke[:, None]
)
mask = mask_lo & mask_hi
score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float("-inf"))
return logits
def fp8_paged_mqa_logits_torch(
q: torch.Tensor,
kv_cache: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
max_model_len: int,
) -> torch.Tensor:
"""Compute FP8 MQA logits using paged KV-cache (CUDA fallback).
This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
Handles head_dim = 132 (128 + 4 for RoPE).
Args:
q: Query tensor of shape [B, next_n, H, D].
kv_cache: Paged KV-cache in packed FP8+scale layout with shape
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
4 bytes per (block,pos) store the `float` dequant scale.
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
context_lens: Tensor of shape [B], dtype int32; effective context length
for each batch element.
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
block indices to physical blocks in the paged cache.
max_model_len: Maximum sequence length used to size the logits output.
Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
fp8_dtype = current_platform.fp8_dtype()
batch_size, next_n, heads, dim = q.size()
kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
scale = scale.contiguous().view(torch.float)
q = q.float()
kv_cache = kv_cache.view(fp8_dtype).float() * scale
num_blocks, block_size, _, dim = kv_cache.size()
logits = torch.full(
[batch_size * next_n, max_model_len],
float("-inf"),
device=q.device,
dtype=torch.float32,
)
for i in range(batch_size):
context_len = context_lens[i].item()
q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
weight_slice = (
weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
)
for block_idx in range(cdiv(context_len, block_size)):
block_id = block_tables[i][block_idx]
qx, kx = q[i], kv_cache[block_id]
k_offsets = torch.arange(
block_idx * block_size, (block_idx + 1) * block_size, device=q.device
)
mask = (k_offsets[None, :] < context_len) & (
k_offsets[None, :] <= q_offsets[:, None]
)
s = torch.where(
mask[None, :, :],
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
logits.dtype
),
float("-inf"),
)
s = torch.relu(s) * weight_slice[..., None]
s = s.sum(dim=0)
logits[
i * next_n : (i + 1) * next_n,
block_idx * block_size : (block_idx + 1) * block_size,
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
return logits
__all__ = [
"calc_diff",
"DeepGemmQuantScaleFMT",
@@ -425,7 +544,9 @@ __all__ = [
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
"fp8_mqa_logits",
"fp8_mqa_logits_torch",
"fp8_paged_mqa_logits",
"fp8_paged_mqa_logits_torch",
"get_paged_mqa_logits_metadata",
"per_block_cast_to_fp8",
"is_deep_gemm_e8m0_used",