Files
2026-03-10 13:31:25 +08:00

476 lines
19 KiB
Python

################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The ZhipuAI Team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
import typing
from collections.abc import Callable, Iterable
from typing import Optional, Union
import torch
import torch_br
from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig
import vllm
import vllm.model_executor.models.glm4_moe
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.glm4_moe import (
Glm4MoeAttention, Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name)
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
from vllm_br.v1.attention.backends.attention_v1 import (
SUPAFlashAttentionMetadata)
from .supa_module import MergedGateUpMLPSiluL2
class Glm4MoE(nn.Module):
    """SUPA implementation of the GLM-4.5 mixture-of-experts MLP block.

    Owns three pieces:

    * ``gate``: an unquantized ``ReplicatedLinear`` router created in
      float32, carrying an ``e_score_correction_bias`` parameter;
    * ``experts``: the routed experts as a ``FusedMoE`` layer (grouped
      top-k, sigmoid scoring, optional EPLB with redundant experts);
    * ``shared_experts``: a ``MergedGateUpMLPSiluL2``, created only when
      ``config.n_shared_experts`` is not None.

    ``forward`` never calls ``gate`` or ``shared_experts`` directly; it hands
    their weights to ``FusedMoE`` packed in a tuple in place of router logits.
    """

    def __init__(
        self,
        config: Glm4MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        # NOTE(review): stored but not applied anywhere in this class --
        # FusedMoE below receives routed_scaling_factor=1.0 and forward()
        # does not multiply by it. Presumably the SUPA fused MoE path applies
        # it; confirm.
        self.routed_scaling_factor = config.routed_scaling_factor

        # Expert-parallel topology of this rank.
        ep_device_group = get_ep_group().device_group
        self.ep_group = ep_device_group
        self.ep_rank = ep_device_group.rank()
        self.ep_size = ep_device_group.size()

        num_experts = config.n_routed_experts
        self.n_routed_experts: int = num_experts
        self.n_shared_experts: int = config.n_shared_experts

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")

        # Router: never quantized, parameters created in float32. The
        # correction bias hangs off the router and is also passed to FusedMoE.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     num_experts,
                                     bias=False,
                                     quant_config=None,
                                     params_dtype=torch.float32,
                                     prefix=f"{prefix}.gate")
        correction_bias = nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32))
        self.gate.e_score_correction_bias = correction_bias

        # Load balancing: logical experts plus redundant replicas give the
        # physical expert count, which is floor-divided across EP ranks.
        self.enable_eplb = enable_eplb
        redundant = (get_current_vllm_config().parallel_config.eplb_config.
                     num_redundant_experts)
        self.n_redundant_experts = redundant
        self.n_logical_experts = num_experts
        physical_total = num_experts + redundant
        self.n_physical_experts = physical_total
        local_count = physical_total // self.ep_size
        self.n_local_physical_experts = local_count
        self.physical_expert_start = self.ep_rank * local_count
        self.physical_expert_end = self.physical_expert_start + local_count

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func="sigmoid",
            # Kept at 1.0 so FusedMoE itself does not apply the factor.
            routed_scaling_factor=1.0,
            e_score_correction_bias=correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=redundant)

        if config.n_shared_experts is not None:
            # All shared experts are merged into one MLP of proportionally
            # larger intermediate size; reduce_results is left at the
            # MergedGateUpMLPSiluL2 default.
            self.shared_experts = MergedGateUpMLPSiluL2(
                hidden_size=config.hidden_size,
                intermediate_size=(config.moe_intermediate_size *
                                   config.n_shared_experts),
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.shared_experts",
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply router, shared experts and routed experts in one fused call.

        Args:
            hidden_states: Activations of any leading shape; the last
                dimension is the hidden size.

        Returns:
            A tensor viewed back to the input's shape.
        """
        input_shape = hidden_states.shape
        assert self.n_shared_experts is not None, 'n_shared_experts must be set'
        # The fused SUPA path takes (router weight, shared gate_up weight,
        # shared down weight) through the ``router_logits`` argument instead
        # of precomputed logits.
        shared = self.shared_experts
        packed_weights = (self.gate.weight, shared.gate_up_proj.weight,
                          shared.down_proj.weight)
        tokens = hidden_states.view(-1, input_shape[-1])
        out = self.experts(hidden_states=tokens, router_logits=packed_weights)
        if hasattr(out, 'all_reduced'):
            # FusedMoE tagged its output as already reduced: drop the tag
            # and skip the tensor-parallel all-reduce.
            delattr(out, 'all_reduced')
        elif self.tp_size > 1:
            out = self.experts.maybe_all_reduce_tensor_model_parallel(out)
        return out.view(input_shape)


# Install the SUPA implementation in place of vLLM's Glm4MoE.
vllm.model_executor.models.glm4_moe.Glm4MoE = Glm4MoE
def Glm4MoeAttention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """SUPA replacement for ``Glm4MoeAttention.forward``.

    One of two kernel paths is chosen from the token count
    ``hidden_states.shape[-2]``:

    * at most 512 tokens: a single ``torch_br.br_qwen3_prefix_attn_infer``
      call receives the raw hidden states, the (possibly quantized) QKV
      weight and scales, the q/k norm weights, the rotary sin/cos caches,
      this layer's KV cache and the slot mapping, and returns q, k, v.
      ``do_cache`` is then cleared on the attention metadata (presumably
      because the kernel already wrote K/V into the cache -- confirm).
    * more than 512 tokens: ``qkv_proj`` followed by
      ``torch_br.br_fused_split_rms_rope_infer`` (split + q/k norm + rotary,
      judging by its arguments), with ``do_cache`` set so the attention
      backend handles K/V caching itself.

    Args:
        positions: Token positions passed to the rotary-embedding kernels.
        hidden_states: Input activations; the token count is read from
            dimension -2.

    Returns:
        The ``o_proj`` output. ``hidden_states`` is returned unchanged when
        there is no attention metadata (dummy run) or, on the short path,
        when this layer's KV cache entry is ``None``.
    """
    # Largest token count routed to the fused short-sequence kernel.
    decode_seql = 512

    forward_context: ForwardContext = get_forward_context()
    attn_metadata: SUPAFlashAttentionMetadata = forward_context.attn_metadata
    if attn_metadata is None:
        # Dummy run: no metadata, skip attention entirely.
        return hidden_states
    # Resolve this layer's entry BEFORE choosing a path. Both paths update
    # ``do_cache`` on it, and ``hasattr(dict, 'do_cache')`` is always False,
    # so leaving the dict in place would silently skip that update.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[self.attn.layer_name]

    seq_len = hidden_states.shape[-2]
    if seq_len > decode_seql:
        # Long path: unfused projection, then fused split/norm/rope.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = torch_br.br_fused_split_rms_rope_infer(
            qkv, [self.q_size, self.kv_size, self.kv_size],
            self.head_dim,
            self.q_norm.variance_epsilon,
            self.q_norm.weight,
            self.k_norm.weight,
            self.rotary_emb.sin_cache,
            self.rotary_emb.cos_cache,
            positions,
            rotary_dim=self.rotary_emb.rotary_dim)
        if hasattr(attn_metadata, 'do_cache'):
            attn_metadata.do_cache = True
    else:
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
        if kv_cache is None:
            return hidden_states
        # Pick the weight/scale tensors matching how qkv_proj is stored:
        # GPTQ/AWQ-style (qweight/scales), compressed-tensors-style
        # (weight_packed/weight_scale), or unquantized (weight).
        qkv_proj = self.qkv_proj
        if hasattr(qkv_proj, "qweight"):
            qkv_weight = qkv_proj.qweight.data
            qkv_scales = qkv_proj.scales.data
        elif hasattr(qkv_proj, "weight_packed"):
            qkv_weight = qkv_proj.weight_packed.data
            qkv_scales = qkv_proj.weight_scale.data
        else:
            qkv_weight = qkv_proj.weight
            qkv_scales = None
        q, k, v = torch_br.br_qwen3_prefix_attn_infer(
            hidden_states,
            qkv_weight, [self.q_size, self.kv_size, self.kv_size],
            self.head_dim,
            self.q_norm.variance_epsilon,
            self.q_norm.weight,
            self.k_norm.weight,
            self.rotary_emb.sin_cache,
            self.rotary_emb.cos_cache,
            kv_cache,
            positions,
            attn_metadata.slot_mapping,
            rotary_dim=self.rotary_emb.rotary_dim,
            bias=qkv_proj.bias,
            scales=qkv_scales)
        if hasattr(attn_metadata, 'do_cache'):
            attn_metadata.do_cache = False

    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output


# Install the SUPA attention forward on vLLM's Glm4MoeAttention.
vllm.model_executor.models.glm4_moe.Glm4MoeAttention.forward = Glm4MoeAttention_forward
def Glm4MoeDecoderLayer__init__(
    self,
    config: Glm4MoeConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    enable_eplb: bool = False,
) -> None:
    """SUPA replacement for ``Glm4MoeDecoderLayer.__init__``.

    Builds the attention block, the MLP and the two RMSNorms. Layers whose
    index is at least ``config.first_k_dense_replace`` (and only when
    ``config.n_routed_experts`` is set) get the ``Glm4MoE`` defined in this
    module; the others get a dense ``MergedGateUpMLPSiluL2``.

    Args:
        config: HuggingFace GLM-4.5 MoE config.
        cache_config: KV-cache configuration forwarded to attention.
        quant_config: Quantization configuration forwarded to submodules.
        prefix: Module prefix; its last dot-separated segment must be the
            integer layer index.
        enable_eplb: Forwarded to ``Glm4MoE`` on MoE layers.
    """
    # Patched onto the class as a free function, hence two-argument super().
    super(Glm4MoeDecoderLayer, self).__init__()
    self.hidden_size = config.hidden_size

    # ``make_layers`` passes a prefix ending in ".<layer index>".
    layer_idx = int(prefix.rsplit('.', 1)[-1])
    self.layer_idx = layer_idx

    self.self_attn = Glm4MoeAttention(
        config=config,
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        num_kv_heads=config.num_key_value_heads,
        rope_theta=getattr(config, "rope_theta", 10000),
        rope_scaling=getattr(config, "rope_scaling", None),
        max_position_embeddings=getattr(config, "max_position_embeddings",
                                        131072),
        head_dim=config.head_dim,
        rms_norm_eps=config.rms_norm_eps,
        qkv_bias=config.attention_bias,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.self_attn",
        use_qk_norm=config.use_qk_norm,
    )

    mlp_prefix = f"{prefix}.mlp"
    is_moe_layer = (config.n_routed_experts is not None
                    and layer_idx >= config.first_k_dense_replace)
    if is_moe_layer:
        self.mlp = Glm4MoE(
            config=config,
            quant_config=quant_config,
            prefix=mlp_prefix,
            enable_eplb=enable_eplb,
        )
    else:
        # Leading dense layers use a plain gate/up + down MLP.
        self.mlp = MergedGateUpMLPSiluL2(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=mlp_prefix)

    norm_eps = config.rms_norm_eps
    self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
    self.routed_scaling_factor = config.routed_scaling_factor


# Install the SUPA constructor on vLLM's Glm4MoeDecoderLayer.
vllm.model_executor.models.glm4_moe.Glm4MoeDecoderLayer.__init__ = Glm4MoeDecoderLayer__init__
def Glm4MoeModel_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """SUPA replacement for ``Glm4MoeModel.forward``.

    The decoder layers run on 3-D activations: a leading size-1 dimension is
    added (``unsqueeze(0)``) before the layer loop and removed
    (``squeeze(0)``) before anything is returned. Non-last pipeline ranks
    return ``IntermediateTensors``; the last rank returns the final-normed
    hidden states.
    """
    pp_group = get_pp_group()

    if pp_group.is_first_rank:
        residual = None
        hidden_states = (inputs_embeds if inputs_embeds is not None else
                         self.get_input_embeddings(input_ids))
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        # NOTE: SUPA wants 3D input
        residual = intermediate_tensors["residual"].unsqueeze(0)
    hidden_states = hidden_states.unsqueeze(0)

    for layer_idx in range(self.start_layer, self.end_layer):
        hidden_states, residual = self.layers[layer_idx](positions,
                                                         hidden_states,
                                                         residual)

    if pp_group.is_last_rank:
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states.squeeze(0)

    # Hand 2-D tensors to the next pipeline stage.
    return IntermediateTensors({
        "hidden_states":
        None if hidden_states is None else hidden_states.squeeze(0),
        "residual":
        None if residual is None else residual.squeeze(0),
    })


# Install the SUPA model forward on vLLM's Glm4MoeModel.
vllm.model_executor.models.glm4_moe.Glm4MoeModel.forward = Glm4MoeModel_forward
def Glm4MoeModel_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """SUPA replacement for ``Glm4MoeModel.load_weights``.

    Loads ``(checkpoint_name, tensor)`` pairs into this model's parameters.
    Each checkpoint entry is tried, in order, as:

    1. a shard of a stacked parameter (q/k/v -> ``qkv_proj``,
       gate/up -> ``gate_up_proj``);
    2. a routed-expert weight, via ``self.get_expert_mapping()``;
    3. a plain parameter (after FP8 kv-scale name remapping).

    Entries for which ``get_spec_layer_idx_from_weight_name`` returns a layer
    index are skipped. After a load, parameters whose name contains
    ``norm.weight`` or ``e_score_correction_bias`` are cast to float32,
    names containing ``gate.weight`` (path 3 only) are cast to bfloat16, and
    ``torch.supa.empty_cache()`` is called.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The set of parameter names that were actually loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    expert_params_mapping = self.get_expert_mapping()
    for name, loaded_weight in weights:
        # Skip weights belonging to a speculative layer (presumably the MTP
        # layer(s) -- confirm against get_spec_layer_idx_from_weight_name).
        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is not None:
            continue
        # Path 1: shards of stacked (fused) parameters.
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if (("mlp.experts." in name) and name not in params_dict):
                continue
            name = name.replace(weight_name, param_name)
            # NOTE(review): ``name`` is rewritten from here on. Each
            # ``continue`` below carries the rewritten name into the next
            # mapping (e.g. "qkv_proj" itself contains "v_proj") and, if no
            # mapping breaks, into the for-else branch below.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer: keep norm weights and the router score
            # correction bias in float32 (likely never matches qkv_proj /
            # gate_up_proj names here).
            if name.find("norm.weight") != -1 or name.find(
                    "e_score_correction_bias") != -1:
                param.data = param.data.to(torch.float32)
            # Presumably releases the SUPA caching allocator's unused blocks
            # (cf. torch.cuda.empty_cache) to limit peak memory during load.
            torch.supa.empty_cache()
            break
        else:
            # Path 2: routed-expert weights.
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                # Anyway, this is an expert weight and should not be
                # attempted to load as other weights later
                is_expert_weight = True
                # Do not modify `name` since the loop may continue here
                # Instead, create a new variable
                name_mapped = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name_mapped, self):
                    continue
                if name_mapped not in params_dict:
                    continue
                param = params_dict[name_mapped]
                # We should ask the weight loader to return success or not
                # here since otherwise we may skip experts with other
                # available replicas.
                weight_loader = typing.cast(Callable[..., bool],
                                            param.weight_loader)
                success = weight_loader(param,
                                        loaded_weight,
                                        name_mapped,
                                        shard_id=shard_id,
                                        expert_id=expert_id,
                                        return_success=True)
                # weight layout infer
                # NOTE(review): this tests the checkpoint ``name``, not
                # ``name_mapped``; expert names are unlikely to contain either
                # substring, so this cast is probably dead code -- confirm.
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                torch.supa.empty_cache()
                if success:
                    name = name_mapped
                    break
            else:
                # Path 3: every other parameter.
                if is_expert_weight:
                    # We've checked that this is an expert weight
                    # However it's not mapped locally to this rank
                    # So we simply skip it
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                # Names containing "gate.weight" (e.g. the MoE router
                # mlp.gate.weight, created with params_dtype=torch.float32 in
                # Glm4MoE) are cast to bfloat16 after loading; presumably the
                # fused SUPA MoE kernel expects bf16 -- confirm.
                if name.find("gate.weight") != -1:
                    param.data = param.data.to(torch.bfloat16)
                torch.supa.empty_cache()
        loaded_params.add(name)
    return loaded_params


# Install the SUPA weight loader on vLLM's Glm4MoeModel.
vllm.model_executor.models.glm4_moe.Glm4MoeModel.load_weights = Glm4MoeModel_load_weights