Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -10,6 +10,7 @@ from einops import rearrange
 from torch import nn
 from transformers.activations import ACT2FN
 
+from vllm import envs
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
     CacheConfig,
@@ -34,7 +35,8 @@ from vllm.model_executor.layers.fla.ops import (
     chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
 )
 from vllm.model_executor.layers.fla.ops import (
-    fused_recurrent_gated_delta_rule,
+    fused_recurrent_gated_delta_rule_packed_decode,
+    fused_sigmoid_gating_delta_rule_update,
 )
 from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
@@ -114,7 +116,7 @@ def fi_chunk_gated_delta_rule(
     beta: torch.Tensor,
     initial_state: torch.Tensor,
     output_final_state: bool,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = True,
 ):
     from flashinfer.gdn_prefill import (
@@ -153,21 +155,13 @@ def fi_chunk_gated_delta_rule(
|
||||
class ChunkGatedDeltaRule(CustomOp):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# if current_platform.is_cuda() and current_platform.is_device_capability(90):
|
||||
# logger.info_once(
|
||||
# "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
|
||||
# )
|
||||
# self._forward_method = self.forward_cuda
|
||||
# else:
|
||||
# logger.info_once(
|
||||
# "Using FlashAttn GDN prefill kernel on CUDA compute capability 90"
|
||||
# )
|
||||
# self._forward_method = self.forward_native
|
||||
|
||||
logger.info_once(
|
||||
"Using FlashAttn GDN prefill kernel on CUDA compute capability 90"
|
||||
)
|
||||
self._forward_method = self.forward_native
|
||||
if current_platform.is_cuda() and current_platform.is_device_capability(90):
|
||||
logger.info_once(
|
||||
"Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
|
||||
)
|
||||
self._forward_method = self.forward_cuda
|
||||
else:
|
||||
self._forward_method = self.forward_native
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
@@ -178,10 +172,10 @@ class ChunkGatedDeltaRule(CustomOp):
         beta: torch.Tensor,
         initial_state: torch.Tensor,
         output_final_state: bool,
-        cu_seqlens: torch.LongTensor | None = None,
+        cu_seqlens: torch.Tensor | None = None,
         use_qk_l2norm_in_kernel: bool = True,
     ):
-        return fi_chunk_gated_delta_rule(
+        return self.forward_native(
             q=q,
             k=k,
             v=v,
@@ -202,7 +196,7 @@ class ChunkGatedDeltaRule(CustomOp):
         beta: torch.Tensor,
         initial_state: torch.Tensor,
         output_final_state: bool,
-        cu_seqlens: torch.LongTensor | None = None,
+        cu_seqlens: torch.Tensor | None = None,
         use_qk_l2norm_in_kernel: bool = True,
     ):
         return fla_chunk_gated_delta_rule(
@@ -420,6 +414,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             prefix=f"{prefix}.in_proj_qkvz",
         )
         # ba_proj doesn't support blockwise fp8 quantization.
+        # in_proj_ba is defined as MergedColumnParallelLinear for
+        # compatibility with Qwen3_5.
         self.in_proj_ba = MergedColumnParallelLinear(
             input_size=self.hidden_size,
             output_sizes=[self.num_v_heads] * 2,
@@ -469,7 +465,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             group_size=None,
             norm_before_gate=True,
             device=current_platform.current_device(),
-            dtype=config.dtype,
         )
 
         self.out_proj = RowParallelLinear(
@@ -482,6 +477,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         )
 
         self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
+        self.enable_packed_recurrent_decode = (
+            envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
+        )
 
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
@@ -631,6 +629,106 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
         output[:num_tokens], _ = self.out_proj(core_attn_out)
 
+    def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
+        """Warm up GDN prefill kernels during V1 profiling.
+
+        During V1 profile runs, ``_forward_core`` returns early because
+        ``attn_metadata`` is ``None``, so the autotuned kernels used by
+        ``chunk_gated_delta_rule`` (e.g. ``solve_tril``,
+        ``chunk_scaled_dot_kkt``) are never invoked. After profiling,
+        vLLM allocates KV cache using most of the remaining GPU memory.
+        When the first real inference triggers the autotuner it OOMs
+        because there is not enough memory left for benchmarking.
+
+        This method runs minimal forward passes through
+        ``chunk_gated_delta_rule`` with small dummy tensors to force
+        autotuning while GPU memory is still plentiful. The autotuner
+        results are cached globally, so only the first layer incurs
+        actual benchmarking cost.
+
+        Most kernels use a fixed ``BT = chunk_size`` (64), but
+        ``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence
+        length: ``min(64, max(16, next_power_of_2(T)))``. Since ``BT``
+        is part of its autotune key, we run warmup passes with T = 16,
+        32, and 64 to cover all possible ``BT`` values.
+
+        The decode path uses ``fused_sigmoid_gating_delta_rule_update``
+        which has fixed kernel parameters (no autotuning), so only the
+        prefill (chunked) path needs warming up.
+        """
+        if hasattr(self, "_prefill_kernels_warmed_up"):
+            return
+        self._prefill_kernels_warmed_up = True
+
+        device = mixed_qkv.device
+        dtype = mixed_qkv.dtype
+        num_k_heads = self.num_k_heads // self.tp_size
+        num_v_heads = self.num_v_heads // self.tp_size
+        _, state_dtype = self.get_state_dtype()
+
+        # Run warmup for each possible BT value of chunk_fwd_kernel_o:
+        # T=16 → BT=16, T=32 → BT=32, T=64 → BT=64.
+        # Other kernels always use BT=chunk_size(64), so their autotune
+        # cache is populated on the first pass and reused thereafter.
+        for T in (16, 32, 64):
+            q = torch.randn(
+                1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
+            )
+            k = torch.randn(
+                1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
+            )
+            v = torch.randn(
+                1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype
+            )
+            # NOTE: g and beta must have the same dtypes as during
+            # inference, so we construct them with the same function
+            # (fused_gdn_gating). dummy_a and dummy_b are throwaway
+            # inputs required by that function.
+            dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
+            dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
+            g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
+            state = torch.zeros(
+                1,
+                num_v_heads,
+                self.head_v_dim,
+                self.head_k_dim,
+                device=device,
+                dtype=state_dtype,
+            )
+            cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
+
+            try:
+                self.chunk_gated_delta_rule(
+                    q=q,
+                    k=k,
+                    v=v,
+                    g=g,
+                    beta=beta,
+                    initial_state=state,
+                    output_final_state=True,
+                    cu_seqlens=cu_seqlens,
+                    use_qk_l2norm_in_kernel=True,
+                )
+            except Exception:
+                logger.warning(
+                    "GDN prefill kernel warmup (T=%d) failed for "
+                    "layer %s. First inference may OOM due to "
+                    "autotuner.",
+                    T,
+                    self.prefix,
+                    exc_info=True,
+                )
+            else:
+                logger.debug(
+                    "GDN prefill kernel warmup (T=%d) completed for layer %s",
+                    T,
+                    self.prefix,
+                )
+            finally:
+                del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
+
+        torch.cuda.empty_cache()
+
     def _forward_core(
         self,
         mixed_qkv: torch.Tensor,
@@ -638,19 +736,34 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         a: torch.Tensor,
         core_attn_out: torch.Tensor,
     ):
-        """
-        Core attention computation (called by custom op).
-        """
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
 
         if attn_metadata is None:
-            # V1 profile run
+            # V1 profile run — warm up prefill kernels so that
+            # autotuning completes before KV cache allocation.
+            self._warmup_prefill_kernels(mixed_qkv)
             return
 
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, GDNAttentionMetadata)
+
+        if (
+            self.enable_packed_recurrent_decode
+            and attn_metadata.spec_sequence_masks is None
+            and attn_metadata.num_prefills == 0
+            and attn_metadata.num_decodes > 0
+        ):
+            return self._forward_core_decode_non_spec(
+                mixed_qkv=mixed_qkv,
+                b=b,
+                a=a,
+                core_attn_out=core_attn_out,
+                attn_metadata=attn_metadata,
+                virtual_engine=forward_context.virtual_engine,
+            )
+
         has_initial_state = attn_metadata.has_initial_state
         spec_query_start_loc = attn_metadata.spec_query_start_loc
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
@@ -738,41 +851,40 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
                 mixed_qkv_non_spec
             )
 
-        g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
-
-        if spec_sequence_masks is not None:
-            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
-                g_spec = g
-                beta_spec = beta
-                g_non_spec = None
-                beta_non_spec = None
-            else:
-                g_spec = g.index_select(1, spec_token_indx)
-                beta_spec = beta.index_select(1, spec_token_indx)
+        if attn_metadata.num_prefills > 0:
+            g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
+            if spec_sequence_masks is not None:
                 g_non_spec = g.index_select(1, non_spec_token_indx)
                 beta_non_spec = beta.index_select(1, non_spec_token_indx)
+            else:
+                g_non_spec = g
+                beta_non_spec = beta
         else:
-            g_spec = None
-            beta_spec = None
-            g_non_spec = g
-            beta_non_spec = beta
+            g_non_spec = None
+            beta_non_spec = None
 
         # 2. Recurrent attention
 
         # 2.1: Process the multi-query part
         if spec_sequence_masks is not None:
-            core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
-                q=query_spec,
-                k=key_spec,
-                v=value_spec,
-                g=g_spec,
-                beta=beta_spec,
-                initial_state=ssm_state,
-                inplace_final_state=True,
-                cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
-                ssm_state_indices=spec_state_indices_tensor,
-                num_accepted_tokens=num_accepted_tokens,
-                use_qk_l2norm_in_kernel=True,
+            core_attn_out_spec, last_recurrent_state = (
+                fused_sigmoid_gating_delta_rule_update(
+                    A_log=self.A_log,
+                    a=a,
+                    b=b,
+                    dt_bias=self.dt_bias,
+                    q=query_spec,
+                    k=key_spec,
+                    v=value_spec,
+                    initial_state=ssm_state,
+                    inplace_final_state=True,
+                    cu_seqlens=spec_query_start_loc[
+                        : attn_metadata.num_spec_decodes + 1
+                    ],
+                    ssm_state_indices=spec_state_indices_tensor,
+                    num_accepted_tokens=num_accepted_tokens,
+                    use_qk_l2norm_in_kernel=True,
+                )
             )
         else:
             core_attn_out_spec, last_recurrent_state = None, None
@@ -801,12 +913,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             )
         elif attn_metadata.num_decodes > 0:
             core_attn_out_non_spec, last_recurrent_state = (
-                fused_recurrent_gated_delta_rule(
+                fused_sigmoid_gating_delta_rule_update(
+                    A_log=self.A_log,
+                    a=a,
+                    b=b,
+                    dt_bias=self.dt_bias,
                     q=query_non_spec,
                     k=key_non_spec,
                     v=value_non_spec,
-                    g=g_non_spec,
-                    beta=beta_non_spec,
                     initial_state=ssm_state,
                     inplace_final_state=True,
                     cu_seqlens=non_spec_query_start_loc[
@@ -834,6 +948,55 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         else:
             core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
 
+    def _forward_core_decode_non_spec(
+        self,
+        mixed_qkv: torch.Tensor,
+        b: torch.Tensor,
+        a: torch.Tensor,
+        core_attn_out: torch.Tensor,
+        attn_metadata: GDNAttentionMetadata,
+        virtual_engine: int,
+    ):
+        """
+        Core attention computation with a packed non-spec decode fast path.
+        """
+        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
+        self_kv_cache = self.kv_cache[virtual_engine]
+        conv_state = self_kv_cache[0].transpose(-1, -2)
+        ssm_state = self_kv_cache[1]
+        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        mixed_qkv = mixed_qkv[:num_actual_tokens]
+        b = b[:num_actual_tokens]
+        a = a[:num_actual_tokens]
+
+        conv_weights = self.conv1d.weight.view(
+            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+        )
+        mixed_qkv_non_spec = causal_conv1d_update(
+            mixed_qkv,
+            conv_state,
+            conv_weights,
+            self.conv1d.bias,
+            self.activation,
+            conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
+            validate_data=False,
+        )
+        out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
+        fused_recurrent_gated_delta_rule_packed_decode(
+            mixed_qkv=mixed_qkv_non_spec,
+            a=a,
+            b=b,
+            A_log=self.A_log,
+            dt_bias=self.dt_bias,
+            scale=self.head_k_dim**-0.5,
+            initial_state=ssm_state,
+            out=out_buf,
+            ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
+            use_qk_l2norm_in_kernel=True,
+        )
+        return
+
 
 class Qwen3NextAttention(nn.Module):
     def __init__(
@@ -961,7 +1124,7 @@ class Qwen3NextDecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
 
-        config = vllm_config.model_config.hf_config
+        config = vllm_config.model_config.hf_text_config
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
@@ -1024,7 +1187,6 @@ class Qwen3NextDecoderLayer(nn.Module):
                 1,
                 1,
                 config.hidden_size,
-                dtype=config.dtype,
             ),
         )
         self.ffn_layer_scale = torch.nn.Parameter(
@@ -1032,7 +1194,6 @@ class Qwen3NextDecoderLayer(nn.Module):
                 1,
                 1,
                 config.hidden_size,
-                dtype=config.dtype,
             ),
         )
 
@@ -1299,6 +1460,8 @@ class QwenNextMixtureOfExperts(MixtureOfExperts):
         self.moe_layers = []
         example_moe = None
         for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
             if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                 layer.mlp, Qwen3NextSparseMoeBlock
             ):
@@ -1334,6 +1497,8 @@ class Qwen3NextForCausalLM(
             "v_proj",
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
+        "in_proj_qkvz": ["in_proj_qkvz"],
+        "in_proj_ba": ["in_proj_ba"],
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):