Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -6,6 +6,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm import _custom_ops as ops

logger = init_logger(__name__)

@@ -151,7 +152,34 @@ def flash_mla_with_kvcache_fp8(
        descale_k,
    )
    return out, softmax_lse


def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel.

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
      Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: the dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse). For the definitions of output, max_logits,
      and lse, please refer to README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits: [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    results = ops.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
    return results


#
# TODO: Add fake functions

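A minimal usage sketch for the new kernel (not part of this diff): the concrete sizes and d_qk = 576 below are assumptions for illustration, the kernel itself only fixes d_v = 512.

    # Illustrative call of flash_mla_sparse_prefill defined above; shapes and d_qk are assumed.
    s_q, s_kv, h_q, h_kv, topk, d_qk = 4, 1024, 16, 1, 64, 576
    q = torch.randn(s_q, h_q, d_qk, dtype=torch.bfloat16, device="cuda")
    kv = torch.randn(s_kv, h_kv, d_qk, dtype=torch.bfloat16, device="cuda")
    indices = torch.randint(0, s_kv, (s_q, h_kv, topk), dtype=torch.int32, device="cuda")
    indices[..., topk // 2:] = -1  # mark unused slots as invalid per the docstring
    output, max_logits, lse = flash_mla_sparse_prefill(q, kv, indices, sm_scale=d_qk**-0.5)
    # output: [s_q, h_q, 512] bfloat16; max_logits, lse: [s_q, h_q] float
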
@@ -37,8 +37,8 @@ def flash_attn_maxseqlen_wrapper(
    else:
        from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

    # if not current_platform.is_rocm() and fa_version is not None:
    #     kwargs["fa_version"] = fa_version
    if not current_platform.is_rocm() and fa_version is not None:
        kwargs["fa_version"] = fa_version

    q_len = q.size(1)
    if cu_seqlens is None:

@@ -268,3 +268,91 @@ def vit_torch_sdpa_wrapper(
    return torch.ops.vllm.torch_sdpa_wrapper(
        q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa
    )


def flashinfer_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
    sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache

    is_reshaped = q.dim() == 4

    if is_reshaped:
        reshape_batch_size = q.shape[0]
        q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    # cuDNN <= 9.10.2.21 requires q, k to be contiguous;
    # this comes with no cost for ViTs with RoPE because
    # RoPE has already made q and k contiguous.
    q, k = q.contiguous(), k.contiguous()

    assert len(cu_seqlens) % 2 == 0, "cu_seqlens must be divisible by 2"
    cu_seqlength = len(cu_seqlens) // 2
    batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)
    batch_offsets_v = cu_seqlens[cu_seqlength:].view(-1, 1, 1, 1)
    sequence_lengths = sequence_lengths.view(-1, 1, 1, 1)
    max_seqlen = max_seqlen.item()

    output, _ = cudnn_batch_prefill_with_kv_cache(
        q,
        k,
        v,
        scale,
        workspace_buffer,
        max_token_per_sequence=max_seqlen,
        max_sequence_kv=max_seqlen,
        actual_seq_lens_q=sequence_lengths,
        actual_seq_lens_kv=sequence_lengths,
        causal=False,
        return_lse=False,
        batch_offsets_q=batch_offsets_qko,
        batch_offsets_k=batch_offsets_qko,
        batch_offsets_v=batch_offsets_v,
        batch_offsets_o=batch_offsets_qko,
    )

    if is_reshaped:
        output = einops.rearrange(output, "(b s) h d -> b s h d", b=reshape_batch_size)

    return output


def vit_flashinfer_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
    sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
    return torch.empty_like(q)


direct_register_custom_op(
    op_name="flashinfer_wrapper",
    op_func=flashinfer_wrapper,
    fake_impl=vit_flashinfer_wrapper_fake,
)


def vit_flashinfer_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
    sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
    return torch.ops.vllm.flashinfer_wrapper(
        q, k, v, scale, workspace_buffer, cu_seqlens, max_seqlen, sequence_lengths
    )

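An end-to-end usage sketch for the new op (not part of this diff): flashinfer_wrapper slices cu_seqlens into two halves, using the first half as batch offsets for q, k, and the output and the second half for v, so a caller packs the two offset arrays back to back. The packing convention, per-image patch counts, head sizes, and workspace size below are assumptions inferred from the code above.

    # Hypothetical caller of vit_flashinfer_wrapper; conventions inferred, not confirmed.
    seq_lens = torch.tensor([256, 196], dtype=torch.int32, device="cuda")  # patches per image
    starts = torch.zeros_like(seq_lens)
    starts[1:] = torch.cumsum(seq_lens, dim=0)[:-1]        # start offset of each sequence
    num_heads, head_dim = 16, 80
    q = torch.randn(int(seq_lens.sum()), num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # size assumed
    out = vit_flashinfer_wrapper(
        q, k, v, head_dim**-0.5, workspace,
        cu_seqlens=torch.cat([starts, starts]),            # q/k/o offsets, then v offsets
        max_seqlen=seq_lens.max(),                         # .item() is taken inside the wrapper
        sequence_lengths=seq_lens,                         # reshaped to [-1, 1, 1, 1] inside
    )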