[1/N][Refactor][Qwen3-Next] remove redundant Qwen3NextSparseMoeBlock and Qwen3NextAttention (#3019)
### What this PR does / why we need it?
remove redundant Qwen3NextSparseMoeBlock and Qwen3NextAttention
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
```
def main():
prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
# Create an LLM.
llm = LLM(
# model="/root/.cache/modelscope/hub/models/Qwen/Qwen3-30B-A3B",
model="Qwen/Qwen3-Next-80B-A3B-Instruct",
tensor_parallel_size=4,
enforce_eager=True,
trust_remote_code=True,
max_model_len=256,
gpu_memory_utilization=0.7,
block_size=64,
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
- vLLM version: v0.10.2
- vLLM main:
9d1c50a5ac
---------
Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
@@ -11,11 +11,11 @@ from einops import rearrange
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention import Attention, AttentionBackend, AttentionMetadata
|
from vllm.attention import AttentionBackend, AttentionMetadata
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
|
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
|
||||||
VllmConfig, get_current_vllm_config)
|
VllmConfig, get_current_vllm_config)
|
||||||
from vllm.distributed import (divide, get_ep_group, get_pp_group,
|
from vllm.distributed import (divide, get_pp_group,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
@@ -27,8 +27,6 @@ from vllm.model_executor.layers.layernorm import \
|
|||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
|
||||||
ReplicatedLinear,
|
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||||
@@ -37,10 +35,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import \
|
|||||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import \
|
|
||||||
GPTQMarlinConfig
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
@@ -50,6 +44,8 @@ from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
|||||||
SupportsLoRA, SupportsPP)
|
SupportsLoRA, SupportsPP)
|
||||||
from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
||||||
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
||||||
|
from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
|
||||||
|
Qwen3NextSparseMoeBlock)
|
||||||
from vllm.model_executor.models.utils import (
|
from vllm.model_executor.models.utils import (
|
||||||
AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||||
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
|
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
|
||||||
@@ -68,112 +64,6 @@ from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating
|
|||||||
from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
|
from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
|
||||||
|
|
||||||
|
|
||||||
class Qwen3NextSparseMoeBlock(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: Qwen3NextConfig,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = "",
|
|
||||||
enable_eplb: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
|
|
||||||
self.ep_group = get_ep_group().device_group
|
|
||||||
self.ep_rank = self.ep_group.rank()
|
|
||||||
self.ep_size = self.ep_group.size()
|
|
||||||
self.n_routed_experts = config.num_experts
|
|
||||||
|
|
||||||
if self.tp_size > config.num_experts:
|
|
||||||
raise ValueError(
|
|
||||||
f"Tensor parallel size {self.tp_size} is greater than "
|
|
||||||
f"the number of experts {config.num_experts}.")
|
|
||||||
|
|
||||||
# Load balancing settings.
|
|
||||||
vllm_config = get_current_vllm_config()
|
|
||||||
eplb_config = vllm_config.parallel_config.eplb_config
|
|
||||||
self.enable_eplb = enable_eplb
|
|
||||||
|
|
||||||
self.n_logical_experts = self.n_routed_experts
|
|
||||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
|
||||||
self.n_physical_experts = (self.n_logical_experts +
|
|
||||||
self.n_redundant_experts)
|
|
||||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
|
||||||
|
|
||||||
self.physical_expert_start = (self.ep_rank *
|
|
||||||
self.n_local_physical_experts)
|
|
||||||
self.physical_expert_end = (self.physical_expert_start +
|
|
||||||
self.n_local_physical_experts)
|
|
||||||
|
|
||||||
self.experts = FusedMoE(num_experts=self.n_routed_experts,
|
|
||||||
top_k=config.num_experts_per_tok,
|
|
||||||
hidden_size=config.hidden_size,
|
|
||||||
intermediate_size=config.moe_intermediate_size,
|
|
||||||
reduce_results=False,
|
|
||||||
renormalize=config.norm_topk_prob,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.experts",
|
|
||||||
enable_eplb=self.enable_eplb,
|
|
||||||
num_redundant_experts=self.n_redundant_experts)
|
|
||||||
|
|
||||||
self.gate = ReplicatedLinear(
|
|
||||||
config.hidden_size,
|
|
||||||
config.num_experts,
|
|
||||||
bias=False,
|
|
||||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
|
||||||
prefix=f"{prefix}.gate")
|
|
||||||
|
|
||||||
if config.shared_expert_intermediate_size > 0:
|
|
||||||
self.shared_expert = Qwen3NextMLP(
|
|
||||||
hidden_size=config.hidden_size,
|
|
||||||
intermediate_size=config.shared_expert_intermediate_size,
|
|
||||||
hidden_act=config.hidden_act,
|
|
||||||
quant_config=quant_config,
|
|
||||||
reduce_results=self.experts.must_reduce_shared_expert_outputs(
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.shared_expert = None
|
|
||||||
self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
|
|
||||||
1,
|
|
||||||
bias=False)
|
|
||||||
|
|
||||||
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
|
|
||||||
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
|
|
||||||
# seems to avoid gate quantization.
|
|
||||||
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
|
|
||||||
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
|
|
||||||
return None
|
|
||||||
return quant_config
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
||||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
|
||||||
orig_shape = hidden_states.shape
|
|
||||||
hidden_dim = hidden_states.shape[-1]
|
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
|
||||||
|
|
||||||
shared_output = None
|
|
||||||
if self.shared_expert is not None:
|
|
||||||
shared_output = self.shared_expert(hidden_states)
|
|
||||||
if self.shared_expert_gate is not None:
|
|
||||||
shared_output = F.sigmoid(
|
|
||||||
self.shared_expert_gate(hidden_states)) * shared_output
|
|
||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
|
||||||
router_logits, _ = self.gate(hidden_states)
|
|
||||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
|
||||||
router_logits=router_logits)
|
|
||||||
|
|
||||||
if shared_output is not None:
|
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
|
||||||
if self.tp_size > 1:
|
|
||||||
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
|
|
||||||
final_hidden_states)
|
|
||||||
|
|
||||||
return final_hidden_states.view(orig_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def torch_chunk_gated_delta_rule(
|
def torch_chunk_gated_delta_rule(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@@ -473,7 +363,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
cache_params: Optional[MambaCacheParams] = None,
|
cache_params: Optional[MambaCacheParams] = None,
|
||||||
):
|
):
|
||||||
return torch.ops.vllm.gdn_attention(
|
return torch.ops.vllm.npu_gdn_attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
output,
|
output,
|
||||||
self.prefix,
|
self.prefix,
|
||||||
@@ -737,123 +627,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
|
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
|
||||||
|
|
||||||
|
|
||||||
class Qwen3NextAttention(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: Qwen3NextConfig,
|
|
||||||
model_config: Optional[ModelConfig] = None,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = "",
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.total_num_heads = config.num_attention_heads
|
|
||||||
assert self.total_num_heads % tp_size == 0
|
|
||||||
self.num_heads = self.total_num_heads // tp_size
|
|
||||||
self.total_num_kv_heads = config.num_key_value_heads
|
|
||||||
if self.total_num_kv_heads >= tp_size:
|
|
||||||
# Number of KV heads is greater than TP size, so we partition
|
|
||||||
# the KV heads across multiple tensor parallel GPUs.
|
|
||||||
assert self.total_num_kv_heads % tp_size == 0
|
|
||||||
else:
|
|
||||||
# Number of KV heads is less than TP size, so we replicate
|
|
||||||
# the KV heads across multiple tensor parallel GPUs.
|
|
||||||
assert tp_size % self.total_num_kv_heads == 0
|
|
||||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
|
||||||
self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
|
|
||||||
self.q_size = self.num_heads * self.head_dim
|
|
||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
|
||||||
self.scaling = self.head_dim**-0.5
|
|
||||||
self.dual_chunk_attention_config = getattr(
|
|
||||||
config, "dual_chunk_attention_config", None)
|
|
||||||
self.attn_output_gate = getattr(config, "attn_output_gate", True)
|
|
||||||
|
|
||||||
self.qkv_proj = QKVParallelLinear(
|
|
||||||
config.hidden_size,
|
|
||||||
self.head_dim,
|
|
||||||
self.total_num_heads * (1 + self.attn_output_gate),
|
|
||||||
self.total_num_kv_heads,
|
|
||||||
bias=getattr(config, "qkv_bias", False),
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.qkv_proj",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.o_proj = RowParallelLinear(
|
|
||||||
self.total_num_heads * self.head_dim,
|
|
||||||
config.hidden_size,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.o_proj",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
|
||||||
head_size=self.head_dim,
|
|
||||||
rotary_dim=self.head_dim,
|
|
||||||
max_position=config.max_position_embeddings,
|
|
||||||
base=config.rope_theta,
|
|
||||||
rope_scaling=config.rope_scaling,
|
|
||||||
partial_rotary_factor=config.partial_rotary_factor,
|
|
||||||
dual_chunk_attention_config=self.dual_chunk_attention_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.attn = Attention(
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
self.scaling,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
cache_config=cache_config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.attn",
|
|
||||||
**{
|
|
||||||
"layer_idx": extract_layer_index(prefix),
|
|
||||||
"dual_chunk_attention_config":
|
|
||||||
self.dual_chunk_attention_config,
|
|
||||||
} if self.dual_chunk_attention_config else {},
|
|
||||||
)
|
|
||||||
|
|
||||||
self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
|
||||||
self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
output: torch.Tensor,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
):
|
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
|
||||||
|
|
||||||
if self.attn_output_gate:
|
|
||||||
q_gate, k, v = qkv.split(
|
|
||||||
[self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
|
|
||||||
orig_shape = q_gate.shape[:-1]
|
|
||||||
q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
|
|
||||||
q, gate = torch.chunk(q_gate, 2, dim=-1)
|
|
||||||
q = q.reshape(*orig_shape, -1)
|
|
||||||
gate = gate.reshape(*orig_shape, -1)
|
|
||||||
else:
|
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
|
||||||
dim=-1)
|
|
||||||
|
|
||||||
q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
|
|
||||||
-1, self.num_heads * self.head_dim)
|
|
||||||
k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
|
|
||||||
-1, self.num_kv_heads * self.head_dim)
|
|
||||||
|
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
|
||||||
|
|
||||||
attn_output = self.attn(q, k, v)
|
|
||||||
|
|
||||||
if self.attn_output_gate:
|
|
||||||
gate = torch.sigmoid(gate)
|
|
||||||
attn_output = attn_output * gate
|
|
||||||
|
|
||||||
output[:], _ = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3NextDecoderLayer(nn.Module):
|
class Qwen3NextDecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1325,7 +1098,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
|||||||
return self.model.get_expert_mapping()
|
return self.model.get_expert_mapping()
|
||||||
|
|
||||||
|
|
||||||
def gdn_attention(
|
def npu_gdn_attention(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
@@ -1335,7 +1108,7 @@ def gdn_attention(
|
|||||||
self._forward(hidden_states=hidden_states, output=output)
|
self._forward(hidden_states=hidden_states, output=output)
|
||||||
|
|
||||||
|
|
||||||
def gdn_attention_fake(
|
def npu_gdn_attention_fake(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
@@ -1344,9 +1117,9 @@ def gdn_attention_fake(
|
|||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="gdn_attention",
|
op_name="npu_gdn_attention",
|
||||||
op_func=gdn_attention,
|
op_func=npu_gdn_attention,
|
||||||
mutates_args=["output"],
|
mutates_args=["output"],
|
||||||
fake_impl=gdn_attention_fake,
|
fake_impl=npu_gdn_attention_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user