[3/N][Refactor][Qwen3-Next] Refactor model structure and fix bug by vllm #25400 (#3142)

### What this PR does / why we need it?
Refactor the model structure in qwen3_next.py to reduce duplicated code: the Ascend classes now inherit from the upstream vLLM Qwen3-Next implementation and override only the NPU-specific parts.
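
A minimal sketch of the resulting class layout (simplified from the diff below; in the actual code `CustomQwen3NextGatedDeltaNet` also mixes in `MambaBase`, and each subclass overrides the NPU-specific methods):

```
# Sketch only -- the real definitions are in the diff below.
from vllm.model_executor.models.qwen3_next import (
    Qwen3NextDecoderLayer, Qwen3NextForCausalLM, Qwen3NextGatedDeltaNet,
    Qwen3NextModel)


class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet):
    """Reuses the upstream gated-delta-net; overrides the NPU forward path."""


class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer):
    """Upstream decoder layer, but linear-attention layers use the class above."""


class CustomQwen3NextModel(Qwen3NextModel):
    """Upstream model whose layers are built from CustomQwen3NextDecoderLayer."""


class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM):
    """Entry point registered for Ascend; wraps CustomQwen3NextModel."""
```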

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
```
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The future of AI is",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100,
                                     temperature=0.6,
                                     top_k=40,
                                     top_p=0.95)

    # Create an LLM.
    llm = LLM(
        model="Qwen/Qwen3-Next-80B-A3B-Instruct",
        tensor_parallel_size=4,
        enforce_eager=True,
        trust_remote_code=True,
        max_model_len=256,
        gpu_memory_utilization=0.7,
        block_size=64,
    )

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```


- vLLM version: v0.10.2
- vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0

---------

Signed-off-by: Icey <1790571317@qq.com>
Author: Icey
Date: 2025-09-28 21:14:36 +08:00
Committed by: GitHub
Commit: dd56e9306b (parent: 4ff422c730)
2 changed files with 27 additions and 363 deletions


@@ -53,4 +53,4 @@ def register_model():
)
ModelRegistry.register_model(
"Qwen3NextForCausalLM",
"vllm_ascend.models.qwen3_next:Qwen3NextForCausalLM")
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")


@@ -14,10 +14,9 @@ from vllm.attention import AttentionBackend, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
VllmConfig, get_current_vllm_config)
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_rank,
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import RMSNormGated
from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
from vllm.model_executor.layers.fla.ops.fused_recurrent import \
@@ -44,27 +43,24 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
MixtureOfExperts,
SupportsLoRA, SupportsPP)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
Qwen3NextSparseMoeBlock,
fused_gdn_gating)
from vllm.model_executor.models.utils import (
AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix)
PPMissingLayer, extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.model_executor.models.qwen3_next import Qwen3NextAttention # isort: skip
from vllm.model_executor.models.qwen3_next import Qwen3NextDecoderLayer # isort: skip
from vllm.model_executor.models.qwen3_next import Qwen3NextForCausalLM # isort: skip
from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet # isort: skip
from vllm.model_executor.models.qwen3_next import Qwen3NextModel # isort: skip
from vllm.model_executor.models.qwen3_next import Qwen3NextSparseMoeBlock # isort: skip
from vllm.model_executor.models.qwen3_next import fused_gdn_gating # isort: skip
class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
@property
def mamba_type(self) -> str:
@@ -80,14 +76,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.gated_delta_net_state_shape(
self.tp_size,
self.num_k_heads,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
self.conv_kernel_size,
self.num_spec,
use_v1=True)
self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
self.head_v_dim, self.conv_kernel_size, self.num_spec)
def __init__(
self,
@@ -98,7 +88,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
speculative_config: Optional[SpeculativeConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.hidden_size = config.hidden_size
@@ -195,85 +185,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def fix_query_key_value_ordering(
self,
mixed_qkvz,
mixed_ba,
):
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
"""
new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
self.num_k_heads // self.tp_size,
(self.head_k_dim + self.head_k_dim +
(self.head_v_dim + self.head_v_dim) * self.num_v_heads //
self.num_k_heads),
)
new_tensor_shape_ba = mixed_qkvz.size()[:-1] + (
self.num_k_heads // self.tp_size,
2 * self.num_v_heads // self.num_k_heads,
)
mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
split_arg_list_qkvz = [
self.head_k_dim,
self.head_k_dim,
(self.num_v_heads // self.num_k_heads * self.head_v_dim),
(self.num_v_heads // self.num_k_heads * self.head_v_dim),
]
split_arg_list_ba = [
self.num_v_heads // self.num_k_heads,
self.num_v_heads // self.num_k_heads
]
# [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
# --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn],
# [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
(query, key, value, z) = torch.split(mixed_qkvz,
split_arg_list_qkvz,
dim=2)
(b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
# [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
value = value.reshape(value.size(0), -1, self.head_v_dim)
z = z.reshape(z.size(0), -1, self.head_v_dim)
b = b.reshape(b.size(0), self.num_v_heads // self.tp_size)
a = a.reshape(a.size(0), self.num_v_heads // self.tp_size)
return query, key, value, z, b, a
def rearrange_mixed_qkv(self, mixed_qkv):
if mixed_qkv is None:
return None, None, None
query, key, value = torch.split(
mixed_qkv,
[
self.key_dim // self.tp_size,
self.key_dim // self.tp_size,
self.value_dim // self.tp_size,
],
dim=-1,
)
query, key = map(
lambda x: rearrange(x, 'l (h d) -> 1 l h d', d=self.head_k_dim),
(query, key))
value = rearrange(value, 'l (h d) -> 1 l h d', d=self.head_v_dim)
return query, key, value
def forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
cache_params: Optional[MambaCacheParams] = None,
):
return torch.ops.vllm.npu_gdn_attention(
hidden_states,
output,
self.prefix,
)
def _forward(
self,
hidden_states: torch.Tensor,
@@ -340,24 +251,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv
# 2.1: process the mutli-query part
# if spec_sequence_masks is not None:
# mixed_qkv_spec = mixed_qkv_spec.view(
# attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
# mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
# mixed_qkv_spec = causal_conv1d_update(
# mixed_qkv_spec,
# conv_state,
# conv_weights,
# self.conv1d.bias,
# self.activation,
# conv_state_indices=spec_state_indices_tensor[:, 0]
# [:attn_metadata.num_spec_decodes],
# num_accepted_tokens=num_accepted_tokens,
# validate_data=False,
# )
# mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
# 2.2: process the remaining part
if attn_metadata.num_prefills > 0:
# - "cache_indices" updates the conv_state cache in positions
@@ -532,7 +425,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
class Qwen3NextDecoderLayer(nn.Module):
class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer):
def __init__(
self,
@@ -545,14 +438,14 @@ class Qwen3NextDecoderLayer(nn.Module):
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__()
nn.Module.__init__(self)
self.config = config
self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix)
if self.layer_type == "linear_attention":
self.linear_attn = Qwen3NextGatedDeltaNet(
self.linear_attn = CustomQwen3NextGatedDeltaNet(
config,
model_config=model_config,
cache_config=cache_config,
@@ -611,69 +504,12 @@ class Qwen3NextDecoderLayer(nn.Module):
dtype=config.torch_dtype,
), )
def forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
positions: torch.Tensor = None,
**kwargs: object,
):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
self_attention_output = torch.empty_like(hidden_states)
if self.layer_type == "linear_attention":
self.linear_attn(
hidden_states=hidden_states,
output=self_attention_output,
)
elif self.layer_type == "full_attention":
self.self_attn(
hidden_states=hidden_states,
output=self_attention_output,
positions=positions,
)
else:
raise ValueError("Invalid layer_type")
hidden_states = self_attention_output
if self.layer_scale:
if len(hidden_states.shape) == 2:
hidden_states = hidden_states * (
self.attn_layer_scale.to(hidden_states.dtype)[0] + 1)
else:
hidden_states = hidden_states * (
self.attn_layer_scale.to(hidden_states.dtype) + 1)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if self.layer_scale:
if len(hidden_states.shape) == 2:
hidden_states = hidden_states * (
self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1)
else:
assert len(hidden_states.shape) == len(
self.ffn_layer_scale.shape
), f'shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}' # noqa: E501
hidden_states = hidden_states * (
self.ffn_layer_scale.to(hidden_states.dtype) + 1)
return hidden_states, residual
@support_torch_compile
class Qwen3NextModel(nn.Module):
class CustomQwen3NextModel(Qwen3NextModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
nn.Module.__init__(self)
config: Qwen3NextConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
@@ -697,7 +533,7 @@ class Qwen3NextModel(nn.Module):
)
def get_layer(prefix: str):
return Qwen3NextDecoderLayer(
return CustomQwen3NextDecoderLayer(
config,
layer_type=config.layer_types[extract_layer_index(prefix)],
model_config=model_config,
@@ -717,52 +553,6 @@ class Qwen3NextModel(nn.Module):
self.norm = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers:
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
num_redundant_experts=self.num_redundant_experts)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@@ -842,10 +632,10 @@ class Qwen3NextModel(nn.Module):
return loaded_params
class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
MixtureOfExperts, IsHybrid):
class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
@@ -856,12 +646,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
"Qwen3Next currently does not support prefix caching"
assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.model = Qwen3NextModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.model = CustomQwen3NextModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -904,127 +692,3 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = (num_physical_experts -
self.num_logical_experts)
for layer in self.model.layers:
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype)
@classmethod
def get_mamba_state_shape_from_config(
cls, vllm_config: "VllmConfig"
) -> tuple[tuple[int, int], tuple[int, int]]:
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
tp_size = parallel_config.tensor_parallel_size
num_spec = (vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0)
return MambaStateShapeCalculator.gated_delta_net_state_shape(
tp_size,
hf_config.linear_num_key_heads,
hf_config.linear_num_value_heads,
hf_config.linear_key_head_dim,
hf_config.linear_value_head_dim,
hf_config.linear_conv_kernel_dim,
num_spec,
use_v1=True)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata=None, # type: ignore
) -> Optional[torch.Tensor]:
return self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=["mtp."],
)
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
def npu_gdn_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states, output=output)
def npu_gdn_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="npu_gdn_attention",
op_func=npu_gdn_attention,
mutates_args=["output"],
fake_impl=npu_gdn_attention_fake,
dispatch_key=current_platform.dispatch_key,
)