Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -8,7 +8,11 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
from vllm.utils.deep_gemm import (
get_paged_mqa_logits_metadata,
is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionBackend,
@@ -21,6 +25,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__)
@@ -68,11 +73,15 @@ class DeepseekV32IndexerPrefillChunkMetadata:
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor
cu_seqlens_q: torch.Tensor
token_to_seq: torch.Tensor
total_seq_lens: int
token_start: int
token_end: int
num_reqs: int
max_context_len: int
max_q_len: int # Maximum query length for dsa_indexer_mqa_logits_with_blocks
max_kv_len: int # Maximum key-value length for dsa_indexer_mqa_logits_with_blocks
@dataclass
@@ -86,9 +95,16 @@ class DeepSeekV32IndexerDecodeMetadata:
seq_lens: torch.Tensor
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
# schedule_metadata: torch.Tensor
use_large_context_topk: bool
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seqlens_kv: torch.Tensor
cu_seqlens_q: torch.Tensor
max_context_len: int
max_q_len: int
max_kv_len: int
@dataclass
@@ -211,20 +227,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if self.vllm_config.speculative_config
else 0
)
if self.num_speculative_tokens > 1:
raise ValueError(
"Sparse MLA only supports "
"num_speculative_tokens <= 1 because the DeepGEMM "
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
f"Got num_speculative_tokens={self.num_speculative_tokens}."
)
self.reorder_batch_threshold += self.num_speculative_tokens
sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count
self.decode_lens_buffer = torch.empty(
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
# Pre-allocated buffers for flattening (spec decode).
self.arange_buffer = torch.arange(
scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
dtype=torch.int32,
device=self.device,
)
self.expanded_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
max_num_blocks_per_req = cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size * get_total_cp_world_size(),
)
self.expanded_block_table_buffer = torch.zeros(
(
scheduler_config.max_num_batched_tokens,
max_num_blocks_per_req,
),
dtype=torch.int32,
device=self.device,
)
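For context, a minimal sketch (with assumed configuration values, not taken from this diff) of how the expanded block-table buffer above is sized: cdiv rounds up, so a partially filled last KV block still gets a slot.

def cdiv(a: int, b: int) -> int:
    # Ceiling division; stand-in for vllm.utils.math_utils.cdiv.
    return -(a // -b)

max_model_len = 163_840         # assumed model context length
block_size = 64                 # assumed KV-cache block size
cp_world_size = 1               # assumed get_total_cp_world_size() result
max_num_batched_tokens = 8192   # assumed scheduler limit

max_num_blocks_per_req = cdiv(max_model_len, block_size * cp_world_size)  # 2560
# One row per flattened decode token, one column per possible KV block.
expanded_block_table_shape = (max_num_batched_tokens, max_num_blocks_per_req)
print(expanded_block_table_shape)  # (8192, 2560)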
# See: DeepGEMM/csrc/apis/attention.hpp
@@ -260,18 +295,88 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
.to(torch.int32)
.to(self.device)
)
cu_seqlens_q = prefill_query_start_loc.to(torch.int32).to(self.device)
max_context_len = seq_lens_cpu[reqs_start:reqs_end].max().item()
# max_q_len is the maximum query length among all batches in this chunk
# prefill_query_start_loc is cumsum of lengths with shape [batch+1]
max_q_len = (prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]).max().item()
return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens,
cu_seqlens_q=cu_seqlens_q,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,
token_end=token_end,
num_reqs=reqs_end - reqs_start,
max_context_len=max_context_len,
max_q_len=max_q_len,
max_kv_len=max_context_len
)
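A small worked example (illustrative values only) of the max_q_len computation above: adjacent differences of the cumulative prefill_query_start_loc recover the per-request query lengths, and their maximum is max_q_len.

import torch

# prefill_query_start_loc is the cumulative sum of per-request query lengths,
# shape [num_reqs + 1]; here three requests with query lengths [3, 1, 5].
prefill_query_start_loc = torch.tensor([0, 3, 4, 9], dtype=torch.int32)
query_lens = prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]  # [3, 1, 5]
max_q_len = query_lens.max().item()  # 5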
def build_decode_metadata(
self, common_attn_metadata, num_decodes, decode_lens, use_large_context_topk, offsets
):
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
assert (
decode_lens_cpu.max().item()
== decode_lens_cpu.min().item()
== 1
), "Only support single token decode in dsa_indexer backend"
# Calculate decode metadata parameters
seq_lens_decode = common_attn_metadata.seq_lens_cpu[:num_decodes]
max_context_len = seq_lens_decode.max().item()
max_kv_len = max_context_len
max_q_len = 1 # Single token decode
# Create cu_seqlens_q: cumulative sum of query lengths (all 1s)
cu_seqlens_q = torch.arange(
num_decodes + 1, dtype=torch.int32, device=self.device
)
# Create cu_seqlens_kv and related tensors using kv_spans_from_batches
decode_query_start_loc = torch.arange(
num_decodes + 1, dtype=torch.long
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
decode_query_start_loc, seq_lens_decode, self.device
)
cu_seqlens_kv = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=self.device),
torch.cumsum(seq_lens_decode.to(self.device), dim=0)
.to(torch.int32),
]
)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[
:num_decodes, ...
],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=(
decode_lens_cpu.max() > decode_lens_cpu.min()
).item(),
use_large_context_topk=use_large_context_topk,
offsets=offsets,
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q=cu_seqlens_q,
max_context_len=max_context_len,
max_q_len=max_q_len,
max_kv_len=max_kv_len,
# schedule_metadata=self.scheduler_metadata_buffer,
)
return decode_metadata
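A worked example (illustrative lengths) of the cumulative-length tensors built in build_decode_metadata: with single-token decode every request contributes exactly one query token, so cu_seqlens_q is simply 0..num_decodes, while cu_seqlens_kv prepends a zero to the running sum of per-request KV lengths.

import torch

seq_lens_decode = torch.tensor([10, 7, 12], dtype=torch.int32)  # assumed KV lengths
num_decodes = seq_lens_decode.numel()

# One query token per request -> cumulative query lengths are 0..num_decodes.
cu_seqlens_q = torch.arange(num_decodes + 1, dtype=torch.int32)  # [0, 1, 2, 3]
# Running sum of KV lengths with a leading zero.
cu_seqlens_kv = torch.cat(
    [
        torch.zeros(1, dtype=torch.int32),
        torch.cumsum(seq_lens_decode, dim=0).to(torch.int32),
    ]
)  # [0, 10, 17, 29]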
def build(
self,
common_prefix_len: int,
@@ -323,45 +428,103 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
# Use CPU values to avoid a GPU sync, which would break async scheduling
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
# Decide which top-k kernel to use based on batch size and sequence length
batch_size = num_decodes
_is_large_context = common_attn_metadata.max_seq_len > 8192
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
use_large_context_topk = batch_size <= 128 and _is_large_context
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
else:
offsets = None
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
# Padded CUDA graph requests have block_table entries of -1.
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
# This is safe because padded requests have seq_lens=0, so the
# kernel produces no meaningful output for those rows.
block_table.clamp_(min=0)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=block_table,
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
use_large_context_topk=use_large_context_topk,
offsets=offsets,
max_decode_len = int(decode_lens_cpu.max().item())
if max_decode_len > 1:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1.
# Example used in the comments below: 4 requests with
# seq_lens [10, 7, 12, 0] (the final req is padding) and
# decode_lens [3, 1, 4, 0].
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# actual_expanded = 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(
seq_lens - decode_lens, decode_lens
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
)
# [0, 1, 2, 0, 0, 1, 2, 3]
positions_within = (
self.arange_buffer[:actual_expanded] - expanded_starts
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.expanded_seq_lens_buffer[:actual_expanded] = (
expanded_base + positions_within + 1
)
self.expanded_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(block_table, decode_lens, dim=0)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
offsets = None
batch_size = num_decode_tokens
else:
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(
next_n, device=self.device, dtype=torch.int32
)
else:
offsets = None
batch_size = num_decodes
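A standalone sketch of the multi-token flattening branch above, using the same example numbers as its comments (seq_lens [10, 7, 12, 0], decode_lens [3, 1, 4, 0]); the pre-allocated buffers and block-table expansion are omitted for brevity.

import torch

seq_lens = torch.tensor([10, 7, 12, 0], dtype=torch.int32)
decode_lens = torch.tensor([3, 1, 4, 0], dtype=torch.int32)
# First num_decodes entries of the query-start cumsum.
query_start_loc = torch.tensor([0, 3, 4, 8], dtype=torch.int32)

actual_expanded = int(decode_lens.sum().item())  # 8 flattened single-token entries
# Context length of each flattened entry's request: [7, 7, 7, 6, 8, 8, 8, 8].
expanded_base = torch.repeat_interleave(seq_lens - decode_lens, decode_lens)
# Start offset of each entry's request: [0, 0, 0, 3, 4, 4, 4, 4].
expanded_starts = torch.repeat_interleave(query_start_loc, decode_lens)
# Position of each token within its request: [0, 1, 2, 0, 0, 1, 2, 3].
positions_within = torch.arange(actual_expanded, dtype=torch.int32) - expanded_starts
# Effective KV length seen by each flattened single-token entry:
# [8, 9, 10, 7, 9, 10, 11, 12].
expanded_seq_lens = expanded_base + positions_within + 1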
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and is_deep_gemm_supported():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens,
self.kv_cache_spec.block_size,
self.num_sms,
)
# Decide which top-k kernel to use based on batch size and sequence length
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
_is_large_context = common_attn_metadata.max_seq_len > 8192
use_large_context_topk = batch_size <= 128 and _is_large_context
# decode_metadata = DeepSeekV32IndexerDecodeMetadata(
# block_table=block_table,
# seq_lens=seq_lens,
# decode_lens=decode_lens,
# requires_padding=False,
# # schedule_metadata=self.scheduler_metadata_buffer,
# use_large_context_topk=use_large_context_topk,
# offsets=offsets,
# )
decode_metadata = self.build_decode_metadata(
common_attn_metadata, num_decodes, decode_lens, use_large_context_topk, offsets
)
attn_metadata = DeepseekV32IndexerMetadata(