Upgrade to vLLM 0.17.0 (CoreX v4.1 overlay)

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -10,6 +10,7 @@ from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from vllm import envs
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CacheConfig,
@@ -34,7 +35,8 @@ from vllm.model_executor.layers.fla.ops import (
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
)
from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule,
fused_recurrent_gated_delta_rule_packed_decode,
fused_sigmoid_gating_delta_rule_update,
)
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
@@ -114,7 +116,7 @@ def fi_chunk_gated_delta_rule(
beta: torch.Tensor,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True,
):
from flashinfer.gdn_prefill import (
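
The cu_seqlens annotation is loosened from torch.LongTensor to torch.Tensor throughout this file; the warmup code added below builds cu_seqlens with dtype=torch.int32, which the stricter LongTensor annotation would rule out. A minimal sketch of the varlen convention assumed by these kernels:

import torch

# cu_seqlens packs cumulative sequence boundaries for a varlen batch:
# one entry more than there are sequences, starting at 0.
seq_lens = [5, 3, 8]
cu_seqlens = torch.cumsum(
    torch.tensor([0] + seq_lens), dim=0
).to(torch.int32)
# tensor([ 0,  5,  8, 16], dtype=torch.int32)
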
@@ -153,21 +155,13 @@ def fi_chunk_gated_delta_rule(
class ChunkGatedDeltaRule(CustomOp):
def __init__(self) -> None:
super().__init__()
# if current_platform.is_cuda() and current_platform.is_device_capability(90):
# logger.info_once(
# "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
# )
# self._forward_method = self.forward_cuda
# else:
# logger.info_once(
# "Using FlashAttn GDN prefill kernel on CUDA compute capability 90"
# )
# self._forward_method = self.forward_native
logger.info_once(
"Using FlashAttn GDN prefill kernel on CUDA compute capability 90"
)
self._forward_method = self.forward_native
if current_platform.is_cuda() and current_platform.is_device_capability(90):
logger.info_once(
"Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
)
self._forward_method = self.forward_cuda
else:
self._forward_method = self.forward_native
def forward_cuda(
self,
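
The net effect of this hunk is to re-enable the SM90 gate that the previous revision had commented out and hard-coded to forward_native. A standalone sketch of the same capability check in plain torch (illustrative; the file itself goes through vLLM's current_platform helpers):

import torch

def pick_gdn_prefill_impl() -> str:
    # FlashInfer's GDN prefill kernel targets compute capability 9.0
    # (Hopper); every other device falls back to the native FLA path.
    if torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0):
        return "flashinfer"  # -> forward_cuda
    return "fla_native"      # -> forward_native
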
@@ -178,10 +172,10 @@ class ChunkGatedDeltaRule(CustomOp):
beta: torch.Tensor,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True,
):
return fi_chunk_gated_delta_rule(
return self.forward_native(
q=q,
k=k,
v=v,
@@ -202,7 +196,7 @@ class ChunkGatedDeltaRule(CustomOp):
beta: torch.Tensor,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True,
):
return fla_chunk_gated_delta_rule(
@@ -420,6 +414,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
prefix=f"{prefix}.in_proj_qkvz",
)
# ba_proj doesn't support blockwise fp8 quantization;
# in_proj_ba is defined as MergedColumnParallelLinear for
# compatibility with Qwen3_5.
self.in_proj_ba = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.num_v_heads] * 2,
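
For context, a hedged sketch of how a merged [b, a] projection like this is typically consumed: one GEMM produces both gating inputs, and downstream code splits the last dimension in halves (names and shapes here are illustrative, not this file's exact code):

import torch

def split_ba(ba: torch.Tensor, heads_per_rank: int):
    # output_sizes=[num_v_heads] * 2 means the fused output stacks the
    # b and a projections back to back along the last dimension.
    return torch.split(ba, [heads_per_rank, heads_per_rank], dim=-1)

b, a = split_ba(torch.randn(4, 16), heads_per_rank=8)
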
@@ -469,7 +465,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
group_size=None,
norm_before_gate=True,
device=current_platform.current_device(),
dtype=config.dtype,
)
self.out_proj = RowParallelLinear(
@@ -482,6 +477,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
)
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
self.enable_packed_recurrent_decode = (
envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
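
The packed recurrent-decode fast path is opt-in via an environment flag read through vllm.envs. A usage sketch (the flag's default and exact bool parsing are assumptions):

import os

# Must be set before the engine process reads vllm.envs; "1" is assumed
# to parse as true, as with other vLLM boolean flags.
os.environ["VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE"] = "1"
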
@@ -631,6 +629,106 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out)
def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
"""Warm up GDN prefill kernels during V1 profiling.
During V1 profile runs, ``_forward_core`` returns early because
``attn_metadata`` is ``None``, so the autotuned kernels used by
``chunk_gated_delta_rule`` (e.g. ``solve_tril``,
``chunk_scaled_dot_kkt``) are never invoked. After profiling,
vLLM allocates the KV cache using most of the remaining GPU memory.
When the first real inference then triggers the autotuner, it OOMs
because there is not enough memory left for benchmarking.
This method runs minimal forward passes through
``chunk_gated_delta_rule`` with small dummy tensors to force
autotuning while GPU memory is still plentiful. The autotuner
results are cached globally, so only the first layer incurs
actual benchmarking cost.
Most kernels use a fixed ``BT = chunk_size`` (64), but
``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence
length: ``min(64, max(16, next_power_of_2(T)))``. Since ``BT``
is part of its autotune key, we run warmup passes with T = 16,
32, and 64 to cover all possible ``BT`` values.
The decode path uses ``fused_sigmoid_gating_delta_rule_update``,
which has fixed kernel parameters (no autotuning), so only the
prefill (chunked) path needs warming up.
"""
if hasattr(self, "_prefill_kernels_warmed_up"):
return
self._prefill_kernels_warmed_up = True
device = mixed_qkv.device
dtype = mixed_qkv.dtype
num_k_heads = self.num_k_heads // self.tp_size
num_v_heads = self.num_v_heads // self.tp_size
_, state_dtype = self.get_state_dtype()
# Run warmup for each possible BT value of chunk_fwd_kernel_o:
# T=16 → BT=16, T=32 → BT=32, T=64 → BT=64.
# Other kernels always use BT=chunk_size(64), so their autotune
# cache is populated on the first pass and reused thereafter.
for T in (16, 32, 64):
q = torch.randn(
1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
)
k = torch.randn(
1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
)
v = torch.randn(
1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype
)
# NOTE: g and beta must have the same dtypes as during
# inference, so we construct them with the same function
# (fused_gdn_gating). dummy_a and dummy_b are throwaway
# inputs required by that function.
dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
state = torch.zeros(
1,
num_v_heads,
self.head_v_dim,
self.head_k_dim,
device=device,
dtype=state_dtype,
)
cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
try:
self.chunk_gated_delta_rule(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=state,
output_final_state=True,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True,
)
except Exception:
logger.warning(
"GDN prefill kernel warmup (T=%d) failed for "
"layer %s. First inference may OOM due to "
"autotuner.",
T,
self.prefix,
exc_info=True,
)
else:
logger.debug(
"GDN prefill kernel warmup (T=%d) completed for layer %s",
T,
self.prefix,
)
finally:
del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
torch.cuda.empty_cache()
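# Worked check of the docstring's BT coverage claim (illustrative):
#   T = 16      -> next_power_of_2(16) = 16 -> BT = 16
#   T in 17..32 -> next_power_of_2(T) = 32  -> BT = 32
#   T in 33..64 -> next_power_of_2(T) = 64  -> BT = 64
#   T > 64      -> min(64, ...) caps BT at 64
# so warming up with T in (16, 32, 64) hits every BT the autotuner keys on.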
def _forward_core(
self,
mixed_qkv: torch.Tensor,
@@ -638,19 +736,34 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
a: torch.Tensor,
core_attn_out: torch.Tensor,
):
"""
Core attention computation (called by custom op).
"""
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is None:
# V1 profile run
# V1 profile run — warm up prefill kernels so that
# autotuning completes before KV cache allocation.
self._warmup_prefill_kernels(mixed_qkv)
return
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
if (
self.enable_packed_recurrent_decode
and attn_metadata.spec_sequence_masks is None
and attn_metadata.num_prefills == 0
and attn_metadata.num_decodes > 0
):
return self._forward_core_decode_non_spec(
mixed_qkv=mixed_qkv,
b=b,
a=a,
core_attn_out=core_attn_out,
attn_metadata=attn_metadata,
virtual_engine=forward_context.virtual_engine,
)
has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
@@ -738,41 +851,40 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_non_spec
)
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
if attn_metadata.num_prefills > 0:
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
else:
g_non_spec = g
beta_non_spec = beta
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta
g_non_spec = None
beta_non_spec = None
# 2. Recurrent attention
# 2.1: Process the multi-query part
if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_spec,
k=key_spec,
v=value_spec,
g=g_spec,
beta=beta_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
core_attn_out_spec, last_recurrent_state = (
fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log,
a=a,
b=b,
dt_bias=self.dt_bias,
q=query_spec,
k=key_spec,
v=value_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[
: attn_metadata.num_spec_decodes + 1
],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
)
)
else:
core_attn_out_spec, last_recurrent_state = None, None
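
Both decode call sites in this diff switch from fused_recurrent_gated_delta_rule, which takes precomputed g and beta, to fused_sigmoid_gating_delta_rule_update, which takes the raw A_log, a, b, and dt_bias and fuses the gating itself. A reference sketch of that gating math, assumed to mirror the fused_gdn_gating used on the prefill path:

import torch
import torch.nn.functional as F

def gdn_gating_reference(A_log, a, b, dt_bias):
    # Decay gate in log space and sigmoid update gate (sketch):
    g = -A_log.float().exp() * F.softplus(a.float() + dt_bias.float())
    beta = b.float().sigmoid()
    return g, beta
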
@@ -801,12 +913,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = (
fused_recurrent_gated_delta_rule(
fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log,
a=a,
b=b,
dt_bias=self.dt_bias,
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[
@@ -834,6 +948,55 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
def _forward_core_decode_non_spec(
self,
mixed_qkv: torch.Tensor,
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
attn_metadata: GDNAttentionMetadata,
virtual_engine: int,
):
"""
Core attention computation with a packed non-spec decode fast path.
"""
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
mixed_qkv = mixed_qkv[:num_actual_tokens]
b = b[:num_actual_tokens]
a = a[:num_actual_tokens]
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
validate_data=False,
)
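# NOTE: unsqueeze(1) below returns a view, so the packed-decode kernel
# writes its result straight into core_attn_out with no extra copy.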
out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
fused_recurrent_gated_delta_rule_packed_decode(
mixed_qkv=mixed_qkv_non_spec,
a=a,
b=b,
A_log=self.A_log,
dt_bias=self.dt_bias,
scale=self.head_k_dim**-0.5,
initial_state=ssm_state,
out=out_buf,
ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
use_qk_l2norm_in_kernel=True,
)
return
class Qwen3NextAttention(nn.Module):
def __init__(
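
For intuition, a single-token reference of the gated delta-rule recurrence that fused_recurrent_gated_delta_rule_packed_decode fuses per head (a minimal sketch under the standard formulation; the in-kernel L2 norm and exact scaling are omitted, and shapes are assumptions):

import torch

def gated_delta_rule_decode_step(S, q, k, v, g, beta, scale):
    # S: [V, K] recurrent state; q, k: [K]; v: [V];
    # g: scalar log-space decay; beta: scalar update gate.
    alpha = g.exp()
    S = alpha * S + beta * torch.outer(v - alpha * (S @ k), k)
    return S, (S @ q) * scale  # updated state, attention output [V]
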
@@ -961,7 +1124,7 @@ class Qwen3NextDecoderLayer(nn.Module):
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_text_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
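
Switching from hf_config to hf_text_config keeps this layer working when the checkpoint ships a multimodal wrapper config; for text-only checkpoints the two are assumed to resolve to the same object. Sketch (assumes a constructed VllmConfig named vllm_config):

# hf_text_config yields the nested text config that actually carries
# hidden_size, num_hidden_layers, etc. for multimodal wrappers.
config = vllm_config.model_config.hf_text_config
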
@@ -1024,7 +1187,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1,
1,
config.hidden_size,
dtype=config.dtype,
),
)
self.ffn_layer_scale = torch.nn.Parameter(
@@ -1032,7 +1194,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1,
1,
config.hidden_size,
dtype=config.dtype,
),
)
@@ -1299,6 +1460,8 @@ class QwenNextMixtureOfExperts(MixtureOfExperts):
self.moe_layers = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
layer.mlp, Qwen3NextSparseMoeBlock
):
@@ -1334,6 +1497,8 @@ class Qwen3NextForCausalLM(
"v_proj",
],
"gate_up_proj": ["gate_proj", "up_proj"],
"in_proj_qkvz": ["in_proj_qkvz"],
"in_proj_ba": ["in_proj_ba"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
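
The two new identity entries register the fused GDN input projections with the stacked-parameters machinery. An illustrative sketch (not vLLM's actual loader) of how such a mapping drives checkpoint-name remapping:

packed_modules_mapping = {
    "gate_up_proj": ["gate_proj", "up_proj"],
    "in_proj_ba": ["in_proj_ba"],  # identity entry: assumed already fused on disk
}
# Invert to (fused module, shard index) per checkpoint shard name:
remap = {
    shard: (fused, idx)
    for fused, shards in packed_modules_mapping.items()
    for idx, shard in enumerate(shards)
}
assert remap["up_proj"] == ("gate_up_proj", 1)
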