Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -37,8 +37,8 @@ def flash_attn_maxseqlen_wrapper(
     else:
         from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

-    # if not current_platform.is_rocm() and fa_version is not None:
-    #     kwargs["fa_version"] = fa_version
+    if not current_platform.is_rocm() and fa_version is not None:
+        kwargs["fa_version"] = fa_version

     q_len = q.size(1)
     if cu_seqlens is None:
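
The first hunk re-enables the `fa_version` plumbing that the previous overlay left commented out: on non-ROCm platforms, an explicitly requested FlashAttention version is forwarded to `flash_attn_varlen_func` via `kwargs`. A minimal sketch of how the restored kwarg is consumed downstream, assuming the surrounding wrapper matches upstream vLLM (the call shape below is illustrative, not part of the diff):

kwargs = {}
if not current_platform.is_rocm() and fa_version is not None:
    # Select the FA kernel version explicitly; ROCm builds do not take this kwarg.
    kwargs["fa_version"] = fa_version

output = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    softmax_scale=scale,
    **kwargs,  # {"fa_version": ...} only when set above
)
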
@@ -268,3 +268,91 @@ def vit_torch_sdpa_wrapper(
     return torch.ops.vllm.torch_sdpa_wrapper(
         q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa
     )
+
+
+def flashinfer_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+    sequence_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
+
+    is_reshaped = q.dim() == 4
+
+    if is_reshaped:
+        reshape_batch_size = q.shape[0]
+        q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+    # cuDNN <= 9.10.2.21 requires q, k to be contiguous
+    # this comes with no cost for ViTs with RoPE because
+    # RoPE has already made q and k contiguous.
+    q, k = q.contiguous(), k.contiguous()
+
+    assert len(cu_seqlens) % 2 == 0, "cu_seqlens must be divisible by 2"
+    cu_seqlength = len(cu_seqlens) // 2
+    batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)
+    batch_offsets_v = cu_seqlens[cu_seqlength:].view(-1, 1, 1, 1)
+    sequence_lengths = sequence_lengths.view(-1, 1, 1, 1)
+    max_seqlen = max_seqlen.item()
+
+    output, _ = cudnn_batch_prefill_with_kv_cache(
+        q,
+        k,
+        v,
+        scale,
+        workspace_buffer,
+        max_token_per_sequence=max_seqlen,
+        max_sequence_kv=max_seqlen,
+        actual_seq_lens_q=sequence_lengths,
+        actual_seq_lens_kv=sequence_lengths,
+        causal=False,
+        return_lse=False,
+        batch_offsets_q=batch_offsets_qko,
+        batch_offsets_k=batch_offsets_qko,
+        batch_offsets_v=batch_offsets_v,
+        batch_offsets_o=batch_offsets_qko,
+    )
+
+    if is_reshaped:
+        output = einops.rearrange(output, "(b s) h d -> b s h d", b=reshape_batch_size)
+
+    return output
+
+
+def vit_flashinfer_wrapper_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+    sequence_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+    return torch.empty_like(q)
+
+
+direct_register_custom_op(
+    op_name="flashinfer_wrapper",
+    op_func=flashinfer_wrapper,
+    fake_impl=vit_flashinfer_wrapper_fake,
+)
+
+
+def vit_flashinfer_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+    sequence_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+    return torch.ops.vllm.flashinfer_wrapper(
+        q, k, v, scale, workspace_buffer, cu_seqlens, max_seqlen, sequence_lengths
+    )
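
A convention worth calling out in the new `flashinfer_wrapper`: `cu_seqlens` is two offset tables of equal length concatenated into one tensor, the first half giving per-sequence batch offsets for q, k, and the output, the second half the offsets for v. That is what the even-length assertion guards. A hypothetical example of the packing for two sequences of 16 and 32 tokens (the names below are illustrative; the snippet only exercises the split arithmetic, not the cuDNN call):

import torch

# One start offset per sequence; here q/k/o and v share the same
# packed layout, so the two halves happen to be identical.
offsets_qko = torch.tensor([0, 16], dtype=torch.int32)
offsets_v = torch.tensor([0, 16], dtype=torch.int32)

# The wrapper expects both tables concatenated into a single tensor.
cu_seqlens = torch.cat([offsets_qko, offsets_v])

n = len(cu_seqlens) // 2  # == cu_seqlength in the wrapper
assert torch.equal(cu_seqlens[:n], offsets_qko)  # -> batch_offsets_qko
assert torch.equal(cu_seqlens[n:], offsets_v)    # -> batch_offsets_v

# The companion metadata the wrapper also consumes.
sequence_lengths = torch.tensor([16, 32], dtype=torch.int32)
max_seqlen = sequence_lengths.max()  # passed as a tensor, .item()'d inside
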
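
The `direct_register_custom_op` call exposes the real implementation as `torch.ops.vllm.flashinfer_wrapper`, with `vit_flashinfer_wrapper_fake` as the fake (meta) implementation: during `torch.compile` tracing only shapes and dtypes matter, so the fake returns an empty tensor shaped like `q` instead of running the cuDNN kernel. The same pattern in miniature, assuming vLLM's `direct_register_custom_op` helper is importable from `vllm.utils` (`my_double` and its fake are hypothetical names):

import torch
from vllm.utils import direct_register_custom_op

def my_double(x: torch.Tensor) -> torch.Tensor:
    # Real kernel: does actual work on real tensors.
    return x * 2

def my_double_fake(x: torch.Tensor) -> torch.Tensor:
    # Fake impl: only propagates shape/dtype for tracing.
    return torch.empty_like(x)

direct_register_custom_op(
    op_name="my_double",
    op_func=my_double,
    fake_impl=my_double_fake,
)

out = torch.ops.vllm.my_double(torch.randn(4))
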
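
Two final notes on shape handling. `vit_flashinfer_wrapper` is the Python-level entry point and simply dispatches to the registered op, which keeps the cuDNN call opaque to `torch.compile`. And `flashinfer_wrapper` accepts either packed `(total_tokens, heads, head_dim)` inputs or batched 4-D `(batch, seq, heads, head_dim)` inputs, flattening with einops on the way in and restoring the batch dimension on the way out. A small, self-contained illustration of that round-trip (shapes are made up):

import einops
import torch

b, s, h, d = 2, 24, 16, 64
q = torch.randn(b, s, h, d)

# Flatten batch and sequence into the packed token dimension,
# exactly as the wrapper does when q.dim() == 4 ...
q_packed = einops.rearrange(q, "b s ... -> (b s) ...")
assert q_packed.shape == (b * s, h, d)

# ... and undo it on the output, as done before returning.
q_restored = einops.rearrange(q_packed, "(b s) h d -> b s h d", b=b)
assert torch.equal(q, q_restored)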