Files
xc-llm-kunlun/vllm_kunlun/models/qwen3_moe.py
2025-12-10 12:05:39 +08:00

837 lines
31 KiB
Python

#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen3_moe.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
import os
from collections.abc import Iterable
from typing import Any, Optional, Union, Tuple, Set
import torch
import os
from torch import nn
from transformers import PretrainedConfig
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm_kunlun.ops.activation import SiluAndMul
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm_kunlun.ops.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
from vllm_kunlun.ops.rotary_embedding import Split_Norm_Rope
# Module-level logger; used below by Qwen3MoeModel.load_weights to warn about
# checkpoint kv-scales that have no matching model parameter.
logger = init_logger(__name__)
class Qwen3MoeMLP(nn.Module):
    """Dense SiLU-gated feed-forward block used on decoder layers without MoE.

    The gate and up projections share one MergedColumnParallelLinear; its
    output is fed through SiluAndMul and then the RowParallelLinear down
    projection.

    Args:
        hidden_size: model hidden dimension (input and output width).
        intermediate_size: width of each of the gate / up projections.
        hidden_act: activation name; only "silu" is accepted.
        quant_config: optional quantization config forwarded to both linears.
        reduce_results: forwarded to the down projection's ``reduce_results``.
        prefix: parameter-name prefix for the submodules.

    Raises:
        ValueError: if ``hidden_act`` is not "silu".
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        fused_widths = [intermediate_size, intermediate_size]  # [gate, up]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        # Validated after the linears are built, matching the original order.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-multiply, then down projection."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
class Qwen3MoeSparseMoeBlock(nn.Module):
    """Mixture-of-experts feed-forward block backed by the Kunlun FusedMoE op.

    The router ``gate`` is a ReplicatedLinear (never quantized). Its weights
    are always handed to the experts call as ``linear_weights``. FusedMoE is
    built with ``reduce_results=False``. When tensor parallelism is active,
    the output is passed through ``maybe_all_reduce_tensor_model_parallel``
    at the end of :meth:`forward`.

    Raises:
        ValueError: if the TP world size exceeds ``config.num_experts``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        if config.num_experts < self.tp_size:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )
        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )
        self.quant_config = quant_config
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts; output has the input's shape."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, input_shape[-1])
        router_weights = self.gate.get_weights()
        expert_kwargs = {"hidden_states": tokens}
        if self.quant_config is not None:
            # Quantized path: router logits are computed here and passed in.
            # Unquantized path passes only the router weights (presumably the
            # Kunlun kernel computes routing itself — TODO confirm).
            logits, _ = self.gate(tokens)
            expert_kwargs["router_logits"] = logits
        expert_kwargs["linear_weights"] = router_weights
        mixed = self.experts(**expert_kwargs)
        if self.tp_size > 1:
            mixed = self.experts.maybe_all_reduce_tensor_model_parallel(mixed)
        return mixed.view(input_shape)
class Qwen3MoeAttention(nn.Module):
    """Self-attention with per-head RMSNorm on Q and K (qk-norm) and RoPE.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    either partitioned (when there are at least ``tp_size`` of them) or
    replicated (when there are fewer).

    Two forward paths exist. When the environment variable FUSED_QK_ROPE_OP
    is "1" at construction time, the Kunlun ``Split_Norm_Rope`` kernel does
    split, qk-norm and rotary embedding in one call. Otherwise they are done
    step by step in PyTorch.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        head_dim: Optional[int] = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        if rope_scaling is not None:
            # Enlarge the position budget by the rope scaling factor. A
            # rope_scaling dict without a "factor" key is treated as unscaled
            # instead of raising KeyError here.
            scaling_factor = rope_scaling.get("factor", 1.0)
            self.max_position_embeddings = int(
                self.max_position_embeddings * scaling_factor
            )
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        # Read the fused-kernel switch once. forward() runs for every layer on
        # every step and is traced by torch.compile (the enclosing model is
        # @support_torch_compile), where the branch is fixed at trace time
        # anyway.
        self.use_fused_qk_rope = os.getenv("FUSED_QK_ROPE_OP") == "1"

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states`` at ``positions``."""
        qkv, _ = self.qkv_proj(hidden_states)
        if self.use_fused_qk_rope:
            # Rope fusion operators: split + q/k RMSNorm + RoPE in one kernel.
            q, k, v = Split_Norm_Rope(
                qkv,
                self.rotary_emb.cos_sin_cache,
                self.q_norm.weight,
                self.k_norm.weight,
                positions,
                self.max_position_embeddings,
                self.num_heads,
                self.num_kv_heads,
                self.head_dim,
            )
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            # qk-norm: RMSNorm over head_dim, applied per head.
            q_by_head = q.view(
                *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
            )
            q_by_head = self.q_norm(q_by_head)
            q = q_by_head.view(q.shape)
            k_by_head = k.view(
                *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
            )
            k_by_head = self.k_norm(k_by_head)
            k = k_by_head.view(k.shape)
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class Qwen3MoeDecoderLayer(nn.Module):
    """One transformer block: pre-norm self-attention then pre-norm FFN.

    The FFN is a Qwen3MoeSparseMoeBlock when the layer is not listed in
    ``config.mlp_only_layers``, the model has experts, and
    ``(layer_idx + 1)`` is a multiple of ``config.decoder_sparse_step``.
    Otherwise it is a dense Qwen3MoeMLP. The residual stream is threaded
    through the RMSNorm calls.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config, "max_position_embeddings", 8192),
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        layer_idx = extract_layer_index(prefix)
        dense_only_layers = getattr(config, "mlp_only_layers", [])
        # Short-circuit order matters: decoder_sparse_step is only consulted
        # once the earlier conditions hold.
        is_sparse = (
            layer_idx not in dense_only_layers
            and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if is_sparse:
            self.mlp = Qwen3MoeSparseMoeBlock(
                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the block; returns ``(ffn_output, residual)``.

        A ``residual`` of None (first layer) starts the residual stream from
        ``hidden_states``.
        """
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions, hidden_states=normed)
        ffn_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(ffn_in), residual
@support_torch_compile
class Qwen3MoeModel(nn.Module):
    """Qwen3-MoE decoder stack: token embeddings, decoder layers, final norm.

    Under pipeline parallelism this rank runs layers
    ``[start_layer, end_layer)``. The first rank embeds tokens (or uses
    ``inputs_embeds``). Later ranks resume from ``intermediate_tensors``.
    Every rank but the last returns ``IntermediateTensors`` instead of the
    normalized output.

    ``load_weights`` can also quantize MoE expert weights to INT8 while
    loading ("streaming quantization"). This happens when the target
    FusedMoE layer already holds int8 weights and float32 scales.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size, prefix=f"{prefix}.embed_tokens"
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen3MoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Returns:
            The final-normed hidden states on the last pipeline rank.
            Otherwise, ``IntermediateTensors`` with ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Checkpoint names are handled in three passes, in order:

        1. q/k/v and gate/up projections (outside ``mlp.experts``) are
           stacked into the fused ``qkv_proj`` / ``gate_up_proj`` params.
        2. Per-expert weights are mapped through
           ``FusedMoE.make_expert_params_mapping``. When the target layer is
           an INT8 FusedMoE, they go through streaming quantization instead
           of the param's own weight loader.
        3. Everything else is loaded 1:1, with FP8 ``kv_scale`` names
           remapped to ``.attn.kv_scale``.

        Returns:
            The set of model parameter names that received data.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # Names like mlp.experts.0.gate_proj also contain "gate_proj".
                # They must be skipped BEFORE renaming. Otherwise they would
                # become mlp.experts.0.gate_up_proj here, and the expert
                # mapping below would then produce
                # mlp.experts.0.gate_gate_up_proj, which breaks loading.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith((".bias", "_bias")) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                # Becomes True once any expert mapping matches this name.
                # Such a weight must never reach the generic loader below:
                # params_dict[name] with the raw per-expert checkpoint name
                # raises KeyError when that name is absent.
                is_expert_weight = False
                for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    is_expert_weight = True
                    # Map to the parameter name in the model; keep `name`
                    # untouched.
                    name_mapped = name.replace(weight_name, param_name)
                    # Layer/PP skip judgment
                    if is_pp_missing_parameter(name_mapped, self):
                        continue
                    if (
                        name_mapped.endswith((".bias", "_bias"))
                        and name_mapped not in params_dict
                    ):
                        continue
                    param = params_dict.get(name_mapped)
                    if param is None:
                        continue
                    # Streaming quantization only when the target MoE layer
                    # holds int8 weights + float32 scales and the name is a
                    # w13/w2 weight.
                    if self._should_stream_quantize(name_mapped):
                        self._stream_quantize_moe_weight(
                            name_mapped,
                            param,
                            loaded_weight,
                            expert_id=expert_id,
                            shard_id=shard_id,
                        )
                    else:
                        # Fallback: Normal weight loading (non-quantized)
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(
                            param,
                            loaded_weight,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    loaded_params.add(name_mapped)
                    break
                else:
                    if is_expert_weight:
                        # Expert weight with no local parameter (other PP
                        # stage, extra bias, or no FusedMoE counterpart):
                        # skip it.
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith((".bias", "_bias")) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale"
                        )
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint "
                                f"(e.g. {name}), but not found the expected "
                                f"name in the model "
                                f"(e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded."
                            )
                            continue
                        name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)
        return loaded_params

    def _is_moe_weight(self, name: str) -> bool:
        """Return True if ``name`` is a FusedMoE w13/w2 weight parameter."""
        return name.endswith(("w13_weight", "w2_weight"))

    def _is_expert_complete(self, cache_key):
        """Return True once both w13 shards (0 and 1) and w2 are cached."""
        cache = self._moe_weight_cache.get(cache_key)
        if cache is None:
            return False
        w13_ok = 0 in cache["w13_shards"] and 1 in cache["w13_shards"]
        w2_ok = cache["w2_weight"] is not None
        return w13_ok and w2_ok

    @staticmethod
    def _normalize_moe_shard_id(shard_id) -> int:
        """Map a FusedMoE shard id to its slot in the per-expert cache.

        "w1"/"gate" -> 0 (gate half of w13), "w3"/"up" -> 1 (up half), and
        "w2" -> 0 (w2 has a single shard). Other values must be int-like.

        Raises:
            ValueError: for an unrecognized string shard id. Silently mapping
                it to 0 would overwrite the gate half of w13.
        """
        named_slots = {"gate": 0, "w1": 0, "up": 1, "w3": 1, "w2": 0}
        if isinstance(shard_id, str):
            if shard_id in named_slots:
                return named_slots[shard_id]
            try:
                return int(shard_id)
            except ValueError:
                raise ValueError(
                    f"Unrecognized MoE shard_id {shard_id!r}"
                ) from None
        return int(shard_id)

    @torch.no_grad()
    def _stream_quantize_moe_weight(
        self,
        param_name: str,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        *,
        expert_id,
        shard_id,
    ):
        """Buffer one expert's checkpoint tensors; quantize once all arrive.

        Each w13 shard (gate=0, up=1) and the w2 tensor are cloned into a
        cache keyed by ``(module_path, expert_id)``. When all three are
        present, ``_quantize_expert_weights`` writes the INT8 data and scales
        into the target module, and the cache entry is dropped.

        Args:
            param_name: mapped parameter name ending in w13_weight/w2_weight.
            param: the target parameter. Unused; the owning module is
                resolved from ``param_name``.
            loaded_weight: the checkpoint tensor for this shard.
            expert_id: expert index; must be convertible with ``int()``.
            shard_id: see ``_normalize_moe_shard_id``.

        Raises:
            ValueError / TypeError: if expert_id or shard_id is not usable.
        """
        # Fail loudly on a bad expert id. A non-int such as None would later
        # index ``weight.data[None, ...]`` and broadcast over every expert.
        expert_id = int(expert_id)
        shard_id = self._normalize_moe_shard_id(shard_id)
        # Lazily create the cache on first use.
        if not hasattr(self, "_moe_weight_cache"):
            self._moe_weight_cache = {}
            self._expert_batch_count = 0  # experts quantized so far
        module_path = ".".join(param_name.split(".")[:-1])
        cache_key = (module_path, expert_id)
        cache = self._moe_weight_cache.get(cache_key)
        if cache is None:
            cache = {
                "w13_shards": {},
                "w2_weight": None,
                "target_module": self.get_submodule(module_path),
            }
            self._moe_weight_cache[cache_key] = cache
        # Cache weights (keep original precision).
        if "w13_weight" in param_name:
            cache["w13_shards"][shard_id] = loaded_weight.clone()
        elif "w2_weight" in param_name:
            cache["w2_weight"] = loaded_weight.clone()
        if not self._is_expert_complete(cache_key):
            return
        self._quantize_expert_weights(cache_key)
        self._moe_weight_cache.pop(cache_key, None)
        # Force a device synchronization every 4 quantized experts
        # (presumably to bound queued copies / memory — TODO confirm).
        self._expert_batch_count += 1
        if self._expert_batch_count % 4 == 0:
            torch.cuda.synchronize()

    @staticmethod
    def _tp_shard(tensor: torch.Tensor, dim: int, tp_rank: int, tp_size: int):
        """Return this rank's contiguous 1/tp_size slice of ``tensor`` on ``dim``."""
        if tp_size <= 1:
            return tensor
        per_rank = tensor.shape[dim] // tp_size
        return tensor.narrow(dim, tp_rank * per_rank, per_rank)

    @staticmethod
    def _quantize_rows_int8(weight: torch.Tensor, qmax: float = 127.0):
        """Symmetric per-row INT8 quantization over the last dimension.

        Returns:
            ``(q, scale)``. ``q`` is int8 with ``weight``'s shape. ``scale``
            is float32 with shape ``weight.shape[:-1]`` and equals
            ``max(row_abs_max, 1e-6) / qmax``, so that
            ``weight ~= q * scale[..., None]``.
        """
        w = weight.float()
        scale = torch.clamp(torch.amax(torch.abs(w), dim=-1), min=1e-6) / qmax
        q = torch.round(w / scale.unsqueeze(-1)).clamp_(-128, 127).to(torch.int8)
        return q, scale

    def _quantize_expert_weights(self, cache_key):
        """Quantize one cached expert to INT8 and write it into its FusedMoE.

        Under tensor parallelism the full checkpoint tensors are sliced to
        this rank: gate and up along dim 0, w2 along dim 1. w13 is written as
        ``[gate; up]``. The stored scales are pre-multiplied by 127 (i.e.
        they equal the clamped row abs-max).
        """
        from vllm.distributed import (
            get_tensor_model_parallel_rank,
            get_tensor_model_parallel_world_size,
        )

        _, expert_id = cache_key
        cache = self._moe_weight_cache[cache_key]
        target_module = cache["target_module"]
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        # NOTE(review): expert_id indexes dim 0 of the FusedMoE parameters
        # directly. With expert parallelism (global -> local expert ids) a
        # translation would be needed here — confirm.

        # w13: concatenate this rank's gate and up rows, then quantize.
        gate_weight = self._tp_shard(cache["w13_shards"][0], 0, tp_rank, tp_size)
        up_weight = self._tp_shard(cache["w13_shards"][1], 0, tp_rank, tp_size)
        w13_q, w13_scale = self._quantize_rows_int8(
            torch.cat([gate_weight, up_weight], dim=0)
        )
        w13_param = target_module.w13_weight
        w13_param.data[expert_id, :, :].copy_(w13_q.to(w13_param.device))
        w13_scale_param = target_module.w13_weight_scale
        w13_scale_param.data[expert_id, :].copy_(
            (w13_scale * 127.0).to(w13_scale_param.device)
        )

        # w2: slice this rank's input columns, then quantize per output row.
        w2_weight = self._tp_shard(cache["w2_weight"], 1, tp_rank, tp_size)
        w2_q, w2_scale = self._quantize_rows_int8(w2_weight)
        w2_param = target_module.w2_weight
        w2_param.data[expert_id, :, :].copy_(w2_q.to(w2_param.device))
        w2_scale_param = target_module.w2_weight_scale
        w2_scale_param.data[expert_id, :].copy_(
            (w2_scale * 127.0).to(w2_scale_param.device)
        )

        # Release the cached full-precision tensors.
        cache["w13_shards"].clear()
        cache["w2_weight"] = None

    def _is_int8_moe_target_module(self, module_path: str) -> bool:
        """Return True if ``module_path`` is a FusedMoE on the INT8 (W8A8) path.

        The decision is based on the parameters actually present: int8
        w13/w2 weights plus float32 w13/w2 scales. It does not rely on
        quant_config names.
        """
        try:
            mod = self.get_submodule(module_path)
        except AttributeError:
            return False
        required = ("w13_weight", "w2_weight", "w13_weight_scale", "w2_weight_scale")
        if not all(hasattr(mod, attr) for attr in required):
            return False
        try:
            return (
                mod.w13_weight.dtype == torch.int8
                and mod.w2_weight.dtype == torch.int8
                and mod.w13_weight_scale.dtype == torch.float32
                and mod.w2_weight_scale.dtype == torch.float32
            )
        except AttributeError:
            return False

    def _should_stream_quantize(self, param_name: str) -> bool:
        """Return True to route ``param_name`` through streaming quantization.

        Requires that (1) the name is a FusedMoE w13/w2 weight, and (2) its
        owning module is on the INT8 path. Otherwise the caller falls back to
        the default weight loader.
        """
        if not self._is_moe_weight(param_name):
            return False
        module_path = ".".join(param_name.split(".")[:-1])
        return self._is_int8_moe_target_module(module_path)
class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
    """Qwen3-MoE causal LM: the decoder stack plus an LM head and logits processor.

    When ``config.tie_word_embeddings`` is set, the LM head shares its
    weight with the input embedding table.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen3MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # Give the LM head its module prefix, as every other layer built in
        # this file receives one. Name-based quantization decisions can then
        # identify it as "lm_head".
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        kv_caches: Optional[list[torch.Tensor]] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack.

        ``kv_caches`` is accepted for signature compatibility and is not used.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via AutoWeightsLoader; returns loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)