[1/N][Refactor][Qwen3-Next] remove redundant Qwen3NextSparseMoeBlock and Qwen3NextAttention (#3019)
### What this PR does / why we need it?
Remove the redundant `Qwen3NextSparseMoeBlock` and `Qwen3NextAttention` classes from the Ascend Qwen3-Next model; both are now imported from upstream vLLM instead of being duplicated here.
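As shown in the diff below, the local copies are replaced by a direct import of the upstream implementations:
```
from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
                                                   Qwen3NextSparseMoeBlock)
```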
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
```
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
    # Create an LLM.
    llm = LLM(
        # model="/root/.cache/modelscope/hub/models/Qwen/Qwen3-30B-A3B",
        model="Qwen/Qwen3-Next-80B-A3B-Instruct",
        tensor_parallel_size=4,
        enforce_eager=True,
        trust_remote_code=True,
        max_model_len=256,
        gpu_memory_utilization=0.7,
        block_size=64,
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```
- vLLM version: v0.10.2
- vLLM main: 9d1c50a5ac
---------
Signed-off-by: Icey <1790571317@qq.com>
@@ -11,11 +11,11 @@ from einops import rearrange
 from torch import nn
 from transformers.activations import ACT2FN
 from vllm import envs
-from vllm.attention import Attention, AttentionBackend, AttentionMetadata
+from vllm.attention import AttentionBackend, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
                          VllmConfig, get_current_vllm_config)
-from vllm.distributed import (divide, get_ep_group, get_pp_group,
+from vllm.distributed import (divide, get_pp_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -27,8 +27,6 @@ from vllm.model_executor.layers.layernorm import \
 # yapf: enable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ -37,10 +35,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import \
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import \
-    GPTQMarlinConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -50,6 +44,8 @@ from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                    SupportsLoRA, SupportsPP)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
+from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
+                                                   Qwen3NextSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, PPMissingLayer, extract_layer_index,
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
@@ -68,112 +64,6 @@ from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating
 from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
 
 
-class Qwen3NextSparseMoeBlock(nn.Module):
-
-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        enable_eplb: bool = False,
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-
-        self.ep_group = get_ep_group().device_group
-        self.ep_rank = self.ep_group.rank()
-        self.ep_size = self.ep_group.size()
-        self.n_routed_experts = config.num_experts
-
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        # Load balancing settings.
-        vllm_config = get_current_vllm_config()
-        eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
-
-        self.n_logical_experts = self.n_routed_experts
-        self.n_redundant_experts = eplb_config.num_redundant_experts
-        self.n_physical_experts = (self.n_logical_experts +
-                                   self.n_redundant_experts)
-        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
-
-        self.physical_expert_start = (self.ep_rank *
-                                      self.n_local_physical_experts)
-        self.physical_expert_end = (self.physical_expert_start +
-                                    self.n_local_physical_experts)
-
-        self.experts = FusedMoE(num_experts=self.n_routed_experts,
-                                top_k=config.num_experts_per_tok,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=config.norm_topk_prob,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts",
-                                enable_eplb=self.enable_eplb,
-                                num_redundant_experts=self.n_redundant_experts)
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=self._maybe_ignore_quant_config(quant_config),
-            prefix=f"{prefix}.gate")
-
-        if config.shared_expert_intermediate_size > 0:
-            self.shared_expert = Qwen3NextMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.shared_expert_intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-                reduce_results=self.experts.must_reduce_shared_expert_outputs(
-                ),
-            )
-        else:
-            self.shared_expert = None
-        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
-                                                  1,
-                                                  bias=False)
-
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid gate quantization.
-        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return quant_config
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # NOTE: hidden_states can have either 1D or 2D shape.
-        orig_shape = hidden_states.shape
-        hidden_dim = hidden_states.shape[-1]
-        hidden_states = hidden_states.view(-1, hidden_dim)
-
-        shared_output = None
-        if self.shared_expert is not None:
-            shared_output = self.shared_expert(hidden_states)
-            if self.shared_expert_gate is not None:
-                shared_output = F.sigmoid(
-                    self.shared_expert_gate(hidden_states)) * shared_output
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
-
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
-                final_hidden_states)
-
-        return final_hidden_states.view(orig_shape)
-
-
 def torch_chunk_gated_delta_rule(
     query,
     key,
@@ -473,7 +363,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         output: torch.Tensor,
         cache_params: Optional[MambaCacheParams] = None,
     ):
-        return torch.ops.vllm.gdn_attention(
+        return torch.ops.vllm.npu_gdn_attention(
             hidden_states,
             output,
             self.prefix,
@@ -737,123 +627,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
 
 
-class Qwen3NextAttention(nn.Module):
-
-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
-        self.total_num_heads = config.num_attention_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
-        self.total_num_kv_heads = config.num_key_value_heads
-        if self.total_num_kv_heads >= tp_size:
-            # Number of KV heads is greater than TP size, so we partition
-            # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
-        else:
-            # Number of KV heads is less than TP size, so we replicate
-            # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
-        self.q_size = self.num_heads * self.head_dim
-        self.kv_size = self.num_kv_heads * self.head_dim
-        self.scaling = self.head_dim**-0.5
-        self.dual_chunk_attention_config = getattr(
-            config, "dual_chunk_attention_config", None)
-        self.attn_output_gate = getattr(config, "attn_output_gate", True)
-
-        self.qkv_proj = QKVParallelLinear(
-            config.hidden_size,
-            self.head_dim,
-            self.total_num_heads * (1 + self.attn_output_gate),
-            self.total_num_kv_heads,
-            bias=getattr(config, "qkv_bias", False),
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-
-        self.o_proj = RowParallelLinear(
-            self.total_num_heads * self.head_dim,
-            config.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
-        )
-
-        self.rotary_emb = get_rope(
-            head_size=self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=config.max_position_embeddings,
-            base=config.rope_theta,
-            rope_scaling=config.rope_scaling,
-            partial_rotary_factor=config.partial_rotary_factor,
-            dual_chunk_attention_config=self.dual_chunk_attention_config,
-        )
-
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-            **{
-                "layer_idx": extract_layer_index(prefix),
-                "dual_chunk_attention_config":
-                self.dual_chunk_attention_config,
-            } if self.dual_chunk_attention_config else {},
-        )
-
-        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-    ):
-        qkv, _ = self.qkv_proj(hidden_states)
-
-        if self.attn_output_gate:
-            q_gate, k, v = qkv.split(
-                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
-            orig_shape = q_gate.shape[:-1]
-            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
-            q, gate = torch.chunk(q_gate, 2, dim=-1)
-            q = q.reshape(*orig_shape, -1)
-            gate = gate.reshape(*orig_shape, -1)
-        else:
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-
-        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
-            -1, self.num_heads * self.head_dim)
-        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
-            -1, self.num_kv_heads * self.head_dim)
-
-        q, k = self.rotary_emb(positions, q, k)
-
-        attn_output = self.attn(q, k, v)
-
-        if self.attn_output_gate:
-            gate = torch.sigmoid(gate)
-            attn_output = attn_output * gate
-
-        output[:], _ = self.o_proj(attn_output)
-
-
 class Qwen3NextDecoderLayer(nn.Module):
 
     def __init__(
@@ -1325,7 +1098,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         return self.model.get_expert_mapping()
 
 
-def gdn_attention(
+def npu_gdn_attention(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1335,7 +1108,7 @@ def gdn_attention(
     self._forward(hidden_states=hidden_states, output=output)
 
 
-def gdn_attention_fake(
+def npu_gdn_attention_fake(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1344,9 +1117,9 @@ def gdn_attention_fake(
 
 
 direct_register_custom_op(
-    op_name="gdn_attention",
-    op_func=gdn_attention,
+    op_name="npu_gdn_attention",
+    op_func=npu_gdn_attention,
     mutates_args=["output"],
-    fake_impl=gdn_attention_fake,
+    fake_impl=npu_gdn_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )