Upgrade to vLLM 0.17.0 (CoreX v4.1 overlay)

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -37,8 +37,8 @@ def flash_attn_maxseqlen_wrapper(
else:
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
# if not current_platform.is_rocm() and fa_version is not None:
# kwargs["fa_version"] = fa_version
if not current_platform.is_rocm() and fa_version is not None:
kwargs["fa_version"] = fa_version
q_len = q.size(1)
if cu_seqlens is None:
@@ -268,3 +268,91 @@ def vit_torch_sdpa_wrapper(
return torch.ops.vllm.torch_sdpa_wrapper(
q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa
)
def flashinfer_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
    sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run variable-length ViT attention via cuDNN batched prefill (flashinfer).

    Args:
        q, k, v: Attention inputs. Either already packed as
            ``(total_tokens, heads, head_dim)`` or batched as
            ``(batch, seq, heads, head_dim)``; the 4-D form is flattened to
            the packed form and the output is reshaped back on return.
        scale: Softmax scaling factor.
        workspace_buffer: Scratch buffer required by the cuDNN prefill kernel.
        cu_seqlens: Concatenation of the q/k/o batch offsets followed by the
            v batch offsets (hence the even-length requirement below).
        max_seqlen: Scalar tensor holding the maximum sequence length.
        sequence_lengths: Per-sequence actual lengths.

    Returns:
        Attention output with the same layout as ``q``.

    Raises:
        AssertionError: If a required argument is missing or ``cu_seqlens``
            does not split evenly into the two offset halves.
    """
    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache

    # These parameters are Optional only to match the sibling attention
    # wrappers' signatures; the cuDNN prefill path needs all of them.
    assert cu_seqlens is not None, "cu_seqlens is required"
    assert max_seqlen is not None, "max_seqlen is required"
    assert sequence_lengths is not None, "sequence_lengths is required"

    is_reshaped = q.dim() == 4
    if is_reshaped:
        # Flatten (batch, seq, ...) -> (batch*seq, ...) for the packed kernel.
        reshape_batch_size = q.shape[0]
        q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

    # cuDNN <= 9.10.2.21 requires q, k to be contiguous
    # this comes with no cost for ViTs with RoPE because
    # RoPE has already made q and k contiguous.
    q, k = q.contiguous(), k.contiguous()

    # First half of cu_seqlens holds the q/k/o offsets, second half the v
    # offsets; both are broadcast to (-1, 1, 1, 1) as the kernel expects.
    assert len(cu_seqlens) % 2 == 0, "cu_seqlens must be divisible by 2"
    cu_seqlength = len(cu_seqlens) // 2
    batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)
    batch_offsets_v = cu_seqlens[cu_seqlength:].view(-1, 1, 1, 1)
    sequence_lengths = sequence_lengths.view(-1, 1, 1, 1)
    max_seqlen = max_seqlen.item()

    output, _ = cudnn_batch_prefill_with_kv_cache(
        q,
        k,
        v,
        scale,
        workspace_buffer,
        max_token_per_sequence=max_seqlen,
        max_sequence_kv=max_seqlen,
        actual_seq_lens_q=sequence_lengths,
        actual_seq_lens_kv=sequence_lengths,
        causal=False,  # ViT attention is bidirectional
        return_lse=False,
        batch_offsets_q=batch_offsets_qko,
        batch_offsets_k=batch_offsets_qko,
        batch_offsets_v=batch_offsets_v,
        batch_offsets_o=batch_offsets_qko,
    )
    if is_reshaped:
        # Restore the original (batch, seq, heads, head_dim) layout.
        output = einops.rearrange(output, "(b s) h d -> b s h d", b=reshape_batch_size)
    return output
def vit_flashinfer_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float,
workspace_buffer: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(q)
# Register flashinfer_wrapper as a vLLM custom op (exposed as
# torch.ops.vllm.flashinfer_wrapper), with the fake implementation above
# supplying shape/dtype propagation for tracing/compilation.
direct_register_custom_op(
    op_name="flashinfer_wrapper",
    op_func=flashinfer_wrapper,
    fake_impl=vit_flashinfer_wrapper_fake,
)
def vit_flashinfer_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
    sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
    """Public entry point that dispatches ViT attention to the registered
    ``torch.ops.vllm.flashinfer_wrapper`` custom op.

    All arguments are forwarded unchanged; see the op implementation for
    their semantics.
    """
    return torch.ops.vllm.flashinfer_wrapper(
        q,
        k,
        v,
        scale,
        workspace_buffer,
        cu_seqlens,
        max_seqlen,
        sequence_lengths,
    )