Upgrade to vllm 0.17.0 corex v4.1 overlay

commit 938d0854a5
parent 8fac6062e4
2026-04-29 19:38:22 +08:00
430 changed files with 35969 additions and 14511 deletions


@@ -9,58 +9,108 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
-from vllm.utils.import_utils import has_deep_gemm
+from vllm.utils.deep_gemm import (
+    fp8_mqa_logits,
+    fp8_mqa_logits_torch,
+    fp8_paged_mqa_logits,
+    fp8_paged_mqa_logits_torch,
+    is_deep_gemm_supported,
+)
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.mla.indexer import (
     DeepseekV32IndexerMetadata,
 )
 from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
 from vllm.v1.worker.workspace import current_workspace_manager
+from vllm.utils.math_utils import cdiv
 
 if current_platform.is_cuda_alike():
     from vllm import _custom_ops as ops
 elif current_platform.is_xpu():
     from vllm._xpu_ops import xpu_ops as ops
 
+import ixformer.inference.functions as ixfops
 
 logger = init_logger(__name__)
+@torch.inference_mode()
+def cp_gather_indexer_k_quant_cache(
+    kv_cache,  # [num_blocks, block_size, head_dim]
+    dst_value,  # [cu_seq_lens[-1], head_dim]
+    block_table,  # [batch_size, num_blocks]
+    cu_seq_lens,  # [batch_size + 1, ]
+    batch_size,
+):
+    num_blocks, block_size, _ = kv_cache.shape
+    head_dim = dst_value.shape[-1]
+    kv_cache = kv_cache.view(num_blocks, -1)
+    expected_value = []
+    # expected_scale = []
+    for b in range(batch_size):
+        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
+        if s == 0:
+            continue
+        tot = cdiv(s, block_size)
+        blocks = block_table[b, :tot]
+        full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32)
+        non_remaining_value = kv_cache[
+            blocks[full_block], : block_size * head_dim
+        ].view(-1, head_dim)
+        # non_remaining_scale = kv_cache[
+        #     blocks[full_block], block_size * head_dim :
+        # ].view(-1, 4)
+        remaining = s - (tot - 1) * block_size
+        value = torch.cat(
+            [
+                non_remaining_value,
+                kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim),
+            ],
+            dim=0,
+        )
+        # scale = torch.cat(
+        #     [
+        #         non_remaining_scale,
+        #         kv_cache[
+        #             blocks[-1],
+        #             block_size * head_dim : block_size * head_dim + remaining * 4,
+        #         ].view(-1, 4),
+        #     ],
+        #     dim=0,
+        # )
+        expected_value.append(value)
+        # expected_scale.append(scale)
+    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
+    # gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
+    gather_value = gather_value.view(torch.bfloat16)
+    # gather_scale = gather_scale.view(torch.float32)
+    dst_value.copy_(gather_value)
+    # dst_scale.copy_(gather_scale)
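For reference, a minimal sketch of how the new gather helper can be exercised on toy tensors. The sizes, and a bf16 cache layout without a per-token scale tail, are illustrative assumptions, not from the commit (the helper only reads the first `block_size * head_dim` values of each block, so any scale tail is simply ignored):

```python
import torch
from vllm.utils.math_utils import cdiv  # used by the helper above

# Toy paged cache: 4 blocks of 2 tokens each, head_dim 8, already bf16.
kv_cache = torch.randn(4, 2, 8, dtype=torch.bfloat16)
# Request 0 holds 3 tokens in blocks [0, 1]; request 1 holds 2 tokens in block [2].
block_table = torch.tensor([[0, 1, 0, 0], [2, 0, 0, 0]], dtype=torch.int32)
cu_seq_lens = [0, 3, 5]  # plain list keeps the sketch device-free

dst_value = torch.empty(cu_seq_lens[-1], 8, dtype=torch.bfloat16)
cp_gather_indexer_k_quant_cache(
    kv_cache, dst_value, block_table, cu_seq_lens, batch_size=2
)
# dst_value now holds request 0's 3 tokens followed by request 1's 2 tokens,
# gathered contiguously out of the paged cache.
```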
 def sparse_attn_indexer(
     hidden_states: torch.Tensor,
     k_cache_prefix: str,
     kv_cache: torch.Tensor,
-    q_fp8: torch.Tensor,
+    q: torch.Tensor,
     k: torch.Tensor,
     weights: torch.Tensor,
-    quant_block_size: int,
-    scale_fmt: str | None,
     topk_tokens: int,
     head_dim: int,
     max_model_len: int,
     total_seq_lens: int,
-    topk_indices_buffer: torch.Tensor,
+    topk_indices_buffer: torch.Tensor | None,
 ) -> torch.Tensor:
     # careful! this will be None in dummy run
     attn_metadata = get_forward_context().attn_metadata
     fp8_dtype = current_platform.fp8_dtype()
     # assert isinstance(attn_metadata, dict)
     if not isinstance(attn_metadata, dict):
         # Reserve workspace for indexer during profiling run
         current_workspace_manager().get_simultaneous(
             ((total_seq_lens, head_dim), torch.float8_e4m3fn),
             ((total_seq_lens, 4), torch.uint8),
         )
         return sparse_attn_indexer_fake(
             hidden_states,
             k_cache_prefix,
             kv_cache,
-            q_fp8,
+            q,
             k,
             weights,
-            quant_block_size,
-            scale_fmt,
             topk_tokens,
             head_dim,
             max_model_len,
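The early return above doubles as the profiling path: the dummy run reserves the indexer's workspace up front so the shared arena is sized for the worst case. A minimal sketch of that reserve-then-slice pattern, with a hypothetical stand-in manager (the real one lives in vllm.v1.worker.workspace; everything below is illustrative only):

```python
import torch

class ToyWorkspaceManager:
    """Illustrative stand-in: grows a byte arena to the high-water mark."""

    def __init__(self):
        self.arena = torch.empty(0, dtype=torch.uint8)

    def get_simultaneous(self, *specs):
        # Each spec is ((shape), dtype); all buffers coexist in one arena.
        sizes = [
            torch.Size(shape).numel() * torch.empty((), dtype=dtype).element_size()
            for shape, dtype in specs
        ]
        total = sum(sizes)
        if self.arena.numel() < total:  # profiling run sets the high-water mark
            self.arena = torch.empty(total, dtype=torch.uint8)
        out, offset = [], 0
        for (shape, dtype), size in zip(specs, sizes):
            out.append(self.arena[offset : offset + size].view(dtype).view(shape))
            offset += size
        return out

mgr = ToyWorkspaceManager()
k_fp8, k_scale = mgr.get_simultaneous(
    ((16, 128), torch.float8_e4m3fn),
    ((16, 4), torch.uint8),
)
```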
@@ -74,12 +124,118 @@ def sparse_attn_indexer(
     has_prefill = attn_metadata.num_prefills > 0
     num_decode_tokens = attn_metadata.num_decode_tokens
-    ops.indexer_k_quant_and_cache(
+    ops.indexer_k_cache(
         k,
         kv_cache,
         slot_mapping,
     )
+    # topk_indices_buffer[: hidden_states.shape[0]] = -1
+    if has_prefill:
+        prefill_metadata = attn_metadata.prefill
+        for chunk in prefill_metadata.chunks:
+            logits = ixfops.dsa_indexer_mqa_logits_with_blocks(
+                q[chunk.token_start : chunk.token_end],
+                chunk.cu_seqlens_q,
+                chunk.cu_seq_lens,
+                kv_cache,
+                chunk.block_table,
+                weights[chunk.token_start : chunk.token_end],
+                max_q_len=chunk.max_q_len,
+                max_kv_len=chunk.max_kv_len,
+                max_context_len=chunk.max_context_len,
+            )
+            ixfops.dsa_update_topk_indices(
+                logits,
+                chunk.cu_seqlen_ks,
+                chunk.cu_seqlen_ke,
+                topk_tokens,
+                topk_indices_buffer[chunk.token_start : chunk.token_end],
+            )
+    if has_decode:
+        decode_metadata = attn_metadata.decode
+        # TODO: support speculative decode
+        if decode_metadata.requires_padding:
+            raise NotImplementedError(
+                "Sparse attention indexer does not support requires_padding"
+            )
+        # Use dsa_indexer_mqa_logits_with_blocks similar to prefill
+        logits = ixfops.dsa_indexer_mqa_logits_with_blocks(
+            q[:num_decode_tokens],
+            decode_metadata.cu_seqlens_q,
+            decode_metadata.cu_seqlens_kv,
+            kv_cache,
+            decode_metadata.block_table,
+            weights[:num_decode_tokens],
+            max_q_len=decode_metadata.max_q_len,
+            max_kv_len=decode_metadata.max_kv_len,
+            max_context_len=decode_metadata.max_context_len,
+        )
+        ixfops.dsa_update_topk_indices(
+            logits,
+            decode_metadata.cu_seqlen_ks,
+            decode_metadata.cu_seqlen_ke,
+            topk_tokens,
+            topk_indices_buffer[:num_decode_tokens],
+        )
+    return topk_indices_buffer
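Both branches hand per-token logit rows plus the `[cu_seqlen_ks, cu_seqlen_ke)` index ranges to `dsa_update_topk_indices` (an ixformer op whose exact contract is not shown in this diff). A plain-PyTorch sketch of the convention the rest of the file relies on, assuming toy shapes: top-k is taken within each row's valid key span, and unused slots are marked with -1:

```python
import torch

# Hypothetical shapes: 2 query rows over a 6-token key space.
logits = torch.randn(2, 6)
ks = torch.tensor([0, 0])  # per-row start of the valid key range
ke = torch.tensor([4, 6])  # per-row end (exclusive)
topk = 3

# Mask out keys outside [ks, ke) before taking top-k.
pos = torch.arange(logits.shape[-1]).unsqueeze(0)
valid = (pos >= ks.unsqueeze(1)) & (pos < ke.unsqueeze(1))
masked = logits.masked_fill(~valid, float("-inf"))
topk_indices = masked.topk(topk, dim=-1).indices.to(torch.int32)

# Slots that landed on -inf (range shorter than topk) become -1, matching
# how topk_indices_buffer is consumed downstream.
gathered = masked.gather(-1, topk_indices.long())
topk_indices[gathered == float("-inf")] = -1
```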
+def sparse_attn_indexer_original(
+    hidden_states: torch.Tensor,
+    k_cache_prefix: str,
+    kv_cache: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    weights: torch.Tensor,
+    topk_tokens: int,
+    head_dim: int,
+    max_model_len: int,
+    total_seq_lens: int,
+    topk_indices_buffer: torch.Tensor,
+) -> torch.Tensor:
+    # careful! this will be None in dummy run
+    attn_metadata = get_forward_context().attn_metadata
+    # fp8_dtype = current_platform.fp8_dtype()
+    # assert isinstance(attn_metadata, dict)
+    if not isinstance(attn_metadata, dict):
+        # Reserve workspace for indexer during profiling run
+        current_workspace_manager().get_simultaneous(
+            ((total_seq_lens, head_dim), torch.float8_e4m3fn),
+            ((total_seq_lens, 4), torch.uint8),
+        )
+        return sparse_attn_indexer_fake(
+            hidden_states,
+            k_cache_prefix,
+            kv_cache,
+            q,
+            k,
+            weights,
+            topk_tokens,
+            head_dim,
+            max_model_len,
+            total_seq_lens,
+            topk_indices_buffer,
+        )
+    attn_metadata = attn_metadata[k_cache_prefix]
+    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
+    slot_mapping = attn_metadata.slot_mapping
+    has_decode = attn_metadata.num_decodes > 0
+    has_prefill = attn_metadata.num_prefills > 0
+    num_decode_tokens = attn_metadata.num_decode_tokens
+    # During speculative decoding, k may be padded to the CUDA graph batch
+    # size while slot_mapping only covers actual tokens. Truncate k to avoid
+    # out-of-bounds reads in the kernel.
+    num_tokens = slot_mapping.shape[0]
+    k = k[:num_tokens]
+    ops.indexer_k_cache(
+        k,
+        kv_cache,
+        slot_mapping,
-        quant_block_size,
-        scale_fmt,
+    )
+    topk_indices_buffer[: hidden_states.shape[0]] = -1
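The truncation above guards the cache write when CUDA-graph padding makes `k` longer than `slot_mapping`. A toy illustration of the shape mismatch it prevents (sizes are hypothetical):

```python
import torch

# CUDA graphs pad the batch to a fixed size: 8 key rows were computed...
k = torch.randn(8, 128, dtype=torch.bfloat16)
# ...but only 5 tokens are real, so the scheduler built 5 cache slots.
slot_mapping = torch.tensor([3, 7, 12, 20, 21])

# Writing all 8 rows would index past slot_mapping; truncate first.
num_tokens = slot_mapping.shape[0]
k = k[:num_tokens]
assert k.shape[0] == slot_mapping.shape[0]
```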
@@ -88,44 +244,42 @@ def sparse_attn_indexer(
         # Get the full shared workspace buffers once (will allocate on first use)
         workspace_manager = current_workspace_manager()
-        k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
-            ((total_seq_lens, head_dim), fp8_dtype),
-            ((total_seq_lens, 4), torch.uint8),
-        )
+        k_full = workspace_manager.get_simultaneous(
+            ((total_seq_lens, head_dim), torch.bfloat16),
+        )[0]
         for chunk in prefill_metadata.chunks:
-            k_fp8 = k_fp8_full[: chunk.total_seq_lens]
-            k_scale = k_scale_full[: chunk.total_seq_lens]
-            ops.cp_gather_indexer_k_quant_cache(
+            k = k_full[: chunk.total_seq_lens]
+            # k_scale = k_scale_full[: chunk.total_seq_lens]
+            cp_gather_indexer_k_quant_cache(
                 kv_cache,
-                k_fp8,
-                k_scale,
+                k,
                 chunk.block_table,
                 chunk.cu_seq_lens,
                 chunk.num_reqs,
             )
-            logits = fp8_mqa_logits(
-                q_fp8[chunk.token_start : chunk.token_end],
-                (k_fp8, k_scale.view(torch.float32).flatten()),
+            logits = ops.ref_mqa_logits(
+                q[chunk.token_start : chunk.token_end],
+                k,
                 weights[chunk.token_start : chunk.token_end],
                 chunk.cu_seqlen_ks,
                 chunk.cu_seqlen_ke,
+                clean_logits=False,
             )
-            num_rows = logits.shape[0]
-            topk_indices = topk_indices_buffer[
-                chunk.token_start : chunk.token_end, :topk_tokens
-            ]
-            torch.ops._C.top_k_per_row_prefill(
-                logits,
-                chunk.cu_seqlen_ks,
-                chunk.cu_seqlen_ke,
-                topk_indices,
-                num_rows,
-                logits.stride(0),
-                logits.stride(1),
-                topk_tokens,
-            )
+            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), dim=-1)[1]
+            topk_indices -= chunk.cu_seqlen_ks[:, None]
+            mask_lo = topk_indices >= 0
+            mask_hi = (
+                topk_indices - (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks)[:, None] < 0
+            )
+            mask = mask_lo & mask_hi
+            topk_indices = topk_indices.masked_fill(~mask, -1)
+            topk_indices_buffer[
+                chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
+            ] = topk_indices.to(dtype=torch.int32)
             # Compute lengths from row spans
             # lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
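The replacement top-k above recovers the same per-row causal window the removed `top_k_per_row_prefill` kernel enforced: selected indices are shifted to be relative to `cu_seqlen_ks`, and anything outside `[0, ke - ks)` becomes -1. A worked instance with concrete numbers (toy values, not from the commit):

```python
import torch

# One query row whose valid key span is absolute positions [2, 5).
cu_seqlen_ks = torch.tensor([2])
cu_seqlen_ke = torch.tensor([5])
# Suppose top-4 over the full row returned absolute key indices:
topk_indices = torch.tensor([[4, 1, 2, 6]])

topk_indices = topk_indices - cu_seqlen_ks[:, None]  # -> [[2, -1, 0, 4]]
mask_lo = topk_indices >= 0
mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
topk_indices = topk_indices.masked_fill(~(mask_lo & mask_hi), -1)
# -> [[2, -1, 0, -1]]: in-window hits stay (now relative), the rest are -1.
```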
@@ -147,63 +301,50 @@ def sparse_attn_indexer(
             # decode_threshold since we unstrictly split
             # prefill and decode by decode_threshold
             # (currently set to 1 + speculative tokens)
-            padded_q_fp8_decode_tokens = pack_seq_triton(
-                q_fp8[:num_decode_tokens], decode_lens
-            )
+            padded_q_decode_tokens = pack_seq_triton(
+                q[:num_decode_tokens], decode_lens
+            )
         else:
-            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
-                decode_lens.shape[0], -1, *q_fp8.shape[1:]
-            )
+            padded_q_decode_tokens = q[:num_decode_tokens].reshape(
+                decode_lens.shape[0], -1, *q.shape[1:]
+            )
         # TODO: move and optimize below logic with triton kernels
-        batch_size = padded_q_fp8_decode_tokens.shape[0]
-        next_n = padded_q_fp8_decode_tokens.shape[1]
+        batch_size = padded_q_decode_tokens.shape[0]
+        next_n = padded_q_decode_tokens.shape[1]
         assert batch_size == decode_metadata.seq_lens.shape[0]
         num_padded_tokens = batch_size * next_n
-        logits = fp8_paged_mqa_logits(
-            padded_q_fp8_decode_tokens,
+        logits = ops.ref_paged_mqa_logits(
+            padded_q_decode_tokens,
             kv_cache,
             weights[:num_padded_tokens],
             decode_metadata.seq_lens,
             decode_metadata.block_table,
-            decode_metadata.schedule_metadata,
             max_model_len=max_model_len,
+            clean_logits=False,
         )
-        num_rows = logits.shape[0]
-        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
-        if decode_metadata.use_large_context_topk:
-            if next_n == 1:
-                lengths = decode_metadata.seq_lens
-            else:
-                # (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
-                lengths = (
-                    decode_metadata.seq_lens.unsqueeze(1)
-                    - next_n
-                    + 1
-                    + decode_metadata.offsets
-                ).flatten()
-            torch.ops._C.large_context_topk(
-                logits,
-                topk_indices,
-                lengths,
-                None,
-            )
-        else:
-            torch.ops._C.top_k_per_row_decode(
-                logits,
-                next_n,
-                decode_metadata.seq_lens,
-                topk_indices,
-                num_rows,
-                logits.stride(0),
-                logits.stride(1),
-                topk_tokens,
-            )
+        # padded query len
+        current_device = padded_q_decode_tokens.device
+        padded_num_tokens = batch_size * next_n
+        positions = (
+            torch.arange(max_model_len, device=current_device)
+            .unsqueeze(0)
+            .expand(padded_num_tokens, -1)
+        )
+        row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
+        next_n_offset = torch.arange(padded_num_tokens, device=current_device) % next_n
+        index_end_pos = (
+            decode_metadata.seq_lens[row_indices] - next_n + next_n_offset
+        ).unsqueeze(1)
+        # index_end_pos: [B * N, 1]
+        mask = positions <= index_end_pos
+        # mask: [B * N, L]
+        logits = logits.masked_fill(~mask, float("-inf"))
+        topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32)  # [B * N, K]
+        # Masked (-inf) positions can still be selected when the context is
+        # shorter than K; clear anything past the row's last visible key.
+        topk_indices[topk_indices > index_end_pos] = -1
         if decode_metadata.requires_padding:
             # if padded, we need to unpack
             # the topk indices removing padded tokens
@@ -211,9 +352,8 @@ def sparse_attn_indexer(
                 topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                 decode_lens,
             )
-        topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
-            topk_indices
-        )
+        topk_indices_buffer[
+            :num_decode_tokens, : topk_indices.shape[-1]
+        ] = topk_indices.to(dtype=torch.int32)
     return topk_indices_buffer
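For the padded decode layout, a small worked example of how `index_end_pos` bounds each row's visible keys. Shapes are hypothetical, and `pack_seq_triton` is shown only via the no-padding reshape case the code itself uses:

```python
import torch

# Decode with speculative tokens: 2 requests, next_n = 2 draft positions each,
# flattened to 4 token rows.
batch_size, next_n, heads, dim = 2, 2, 4, 8
q = torch.randn(batch_size * next_n, heads, dim)

# Without padding, packing is just a reshape to [B, N, heads, dim] ...
padded = q.reshape(batch_size, next_n, *q.shape[1:])

# ... and each padded row's last visible key is seq_len - next_n + offset,
# which is exactly what index_end_pos encodes above.
seq_lens = torch.tensor([9, 13])
row = torch.arange(batch_size * next_n) // next_n
off = torch.arange(batch_size * next_n) % next_n
index_end_pos = seq_lens[row] - next_n + off
# -> tensor([ 7,  8, 11, 12])
```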
@@ -222,11 +362,9 @@ def sparse_attn_indexer_fake(
     hidden_states: torch.Tensor,
     k_cache_prefix: str,
     kv_cache: torch.Tensor,
-    q_fp8: torch.Tensor,
+    q: torch.Tensor,
     k: torch.Tensor,
     weights: torch.Tensor,
-    quant_block_size: int,
-    scale_fmt: str | None,
     topk_tokens: int,
     head_dim: int,
     max_model_len: int,
@@ -278,9 +416,12 @@ class SparseAttnIndexer(CustomOp):
         self.max_model_len = max_model_len
         self.max_total_seq_len = max_total_seq_len
         self.topk_indices_buffer = topk_indices_buffer
-        if current_platform.is_cuda() and not has_deep_gemm():
-            raise RuntimeError(
-                "Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
+        if current_platform.is_cuda() and not is_deep_gemm_supported():
+            logger.warning_once(
+                "DeepGEMM is not supported or available. SparseAttnIndexer will use a "
+                "less efficient PyTorch implementation. "
+                "Please make sure you have the required hardware and software setup "
+                "for DeepGEMM to achieve optimal performance."
             )
 
     def forward_native(
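With the hard failure relaxed to a warning, the module can fall back to the `*_torch` reference ops imported at the top of the file. A minimal sketch of that dispatch shape (illustrative; the real call sites pass the full argument lists shown elsewhere in this diff):

```python
def _mqa_logits_fn():
    # Prefer the DeepGEMM kernel; fall back to the PyTorch reference op.
    if is_deep_gemm_supported():
        return fp8_mqa_logits
    return fp8_mqa_logits_torch
```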
@@ -303,7 +444,7 @@ class SparseAttnIndexer(CustomOp):
     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
-        q_fp8: torch.Tensor,
+        q: torch.Tensor,
         k: torch.Tensor,
         weights: torch.Tensor,
     ):
@@ -311,11 +452,9 @@ class SparseAttnIndexer(CustomOp):
             hidden_states,
             self.k_cache.prefix,
             self.k_cache.kv_cache[0],
-            q_fp8,
+            q,
             k,
             weights,
-            self.quant_block_size,
-            self.scale_fmt,
             self.topk_tokens,
             self.head_dim,
             self.max_model_len,