diff --git a/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py index c392f8e77..dd8504fd9 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py @@ -33,7 +33,11 @@ def get_model_config(model_name: str, tp_size: int): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size - elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: + elif config.architectures[0] in [ + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "Glm4MoeForCausalLM", + ]: E = ( config.n_routed_experts + 1 if config.architectures[0] in ["DeepseekV3ForCausalLM"] diff --git a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py index 390d33f56..6afd7f354 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py @@ -42,7 +42,11 @@ def get_model_config(model_name: str, tp_size: int): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size - elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: + elif config.architectures[0] in [ + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "Glm4MoeForCausalLM", + ]: E = ( config.n_routed_experts + 1 if config.architectures[0] in ["DeepseekV3ForCausalLM"] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index c2d1d1415..f31970622 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -127,6 +127,9 @@ class ModelConfig: ): self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" + if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM": + self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN" + if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM": self.hf_config.architectures[0] = "MiMoMTP" # Check model type diff --git a/python/sglang/srt/function_call/ebnf_composer.py b/python/sglang/srt/function_call/ebnf_composer.py index 85d6039bb..1db7da6d8 100644 --- a/python/sglang/srt/function_call/ebnf_composer.py +++ b/python/sglang/srt/function_call/ebnf_composer.py @@ -165,6 +165,7 @@ class EBNFComposer: tool_call_separator: Optional[str] = None, call_rule_fmt: Optional[str] = None, key_value_rule_fmt: Optional[str] = None, + key_value_separator: str = ",", ): """ Generalized EBNF builder for all detectors. @@ -279,7 +280,11 @@ class EBNFComposer: # Add required properties joined by commas if required: - rule_parts.append(' "," '.join(prop_kv_pairs[k] for k in required)) + rule_parts.append( + f' "{key_value_separator}" '.join( + prop_kv_pairs[k] for k in required + ) + ) # Add optional properties with flexible ordering if optional: @@ -292,13 +297,15 @@ class EBNFComposer: if j == i: opt_parts.append(prop_kv_pairs[optional[j]]) else: - opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?') + opt_parts.append( + f' ( "{key_value_separator}" {prop_kv_pairs[optional[j]]} )?' 
+                            )
                     opt_alternatives.append("".join(opt_parts))
 
         # Wrap with appropriate comma handling based on whether we have required properties
         if required:
             # Required properties exist, so optional group needs outer comma
-            rule_parts.append(' ( "," ( ')
+            rule_parts.append(f' ( "{key_value_separator}" ( ')
             rule_parts.append(" | ".join(opt_alternatives))
             rule_parts.append(" ) )?")
         else:
diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py
index fde00f303..bf6a3d959 100644
--- a/python/sglang/srt/function_call/function_call_parser.py
+++ b/python/sglang/srt/function_call/function_call_parser.py
@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import (
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
+from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
 from sglang.srt.function_call.kimik2_detector import KimiK2Detector
 from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
@@ -37,6 +38,7 @@ class FunctionCallParser:
         "pythonic": PythonicDetector,
         "kimi_k2": KimiK2Detector,
         "qwen3_coder": Qwen3CoderDetector,
+        "glm45": Glm4MoeDetector,
     }
 
     def __init__(self, tools: List[Tool], tool_call_parser: str):
diff --git a/python/sglang/srt/function_call/glm4_moe_detector.py b/python/sglang/srt/function_call/glm4_moe_detector.py
new file mode 100644
index 000000000..ace32d938
--- /dev/null
+++ b/python/sglang/srt/function_call/glm4_moe_detector.py
@@ -0,0 +1,165 @@
+import ast
+import json
+import logging
+import re
+from typing import List
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    StructureInfo,
+    _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+
+logger = logging.getLogger(__name__)
+
+
+def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
+    name2tool = {tool.function.name: tool for tool in defined_tools}
+    if func_name not in name2tool:
+        return None
+    tool = name2tool[func_name]
+    if arg_key not in tool.function.parameters["properties"]:
+        return None
+    return tool.function.parameters["properties"][arg_key].get("type", None)
+
+
+def parse_arguments(json_value):
+    try:
+        try:
+            parsed_value = json.loads(json_value)
+        except:
+            parsed_value = ast.literal_eval(json_value)
+        return parsed_value, True
+    except:
+        return json_value, False
+
+
+class Glm4MoeDetector(BaseFormatDetector):
+    """
+    Detector for GLM-4.5 models.
+    Assumes function call format:
+        <tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.bot_token = "<tool_call>"
+        self.eot_token = "</tool_call>"
+        self.func_call_regex = r"<tool_call>.*?</tool_call>"
+        self.func_detail_regex = r"<tool_call>([^\n]*)\n(.*)</tool_call>"
+        self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a GLM-4.5 format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=normal_text, calls=[])
+        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
+        calls = []
+        try:
+            for match_result in match_result_list:
+                # Get function name
+                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
+                func_name = func_detail.group(1)
+                func_args = func_detail.group(2)
+                pairs = re.findall(
+                    r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
+                    func_args,
+                    re.DOTALL,
+                )
+                arguments = {}
+                for arg_key, arg_value in pairs:
+                    arg_key = arg_key.strip()
+                    arg_value = arg_value.strip()
+                    arg_type = get_argument_type(func_name, arg_key, tools)
+                    if arg_type != "string":
+                        arg_value, is_good_json = parse_arguments(arg_value)
+                    arguments[arg_key] = arg_value
+                # construct match_result for parse_base_json
+                match_result = {"name": func_name, "parameters": arguments}
+                calls.extend(self.parse_base_json(match_result, tools))
+            return StreamingParseResult(normal_text=normal_text, calls=calls)
+        except Exception as e:
+            logger.error(f"Error in detect_and_parse: {e}")
+            # return the normal text if parsing fails
+            return StreamingParseResult(normal_text=text)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing of tool calls for the GLM-4.5 format.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        start = current_text.find(self.bot_token)
+        if start == -1:
+            self._buffer = ""
+            if self.current_tool_id > 0:
+                current_text = ""
+            return StreamingParseResult(normal_text=current_text)
+        # `find` returns the first self.eot_token, so there is at most one tool call in current_text[: end + len(self.eot_token)]
+        end = current_text.find(self.eot_token)
+        if end != -1:
+            # Initialize state if this is the first tool call
+            if self.current_tool_id == -1:
+                self.current_tool_id = 0
+                self.prev_tool_call_arr = []
+                self.streamed_args_for_tool = [""]
+            # Ensure we have enough entries in our tracking arrays
+            while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                self.prev_tool_call_arr.append({})
+            while len(self.streamed_args_for_tool) <= self.current_tool_id:
+                self.streamed_args_for_tool.append("")
+            result = self.detect_and_parse(
+                current_text[: end + len(self.eot_token)], tools=tools
+            )
+            if result.calls:
+                self.prev_tool_call_arr[self.current_tool_id] = {
+                    "name": result.calls[0].name,
+                    "arguments": json.loads(result.calls[0].parameters),
+                }
+                self.streamed_args_for_tool[self.current_tool_id] = result.calls[
+                    0
+                ].parameters
+                result.calls[0].tool_index = self.current_tool_id
+                self.current_tool_id += 1
+            self._buffer = current_text[end + len(self.eot_token) :]
+            return result
+        normal_text = current_text[:start]
+        self._buffer = current_text[start:]
+        return StreamingParseResult(normal_text=normal_text)
+
+    def supports_structural_tag(self) -> bool:
+        return False
+
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError()
+
+    def build_ebnf(self, tools: List[Tool]):
+        return EBNFComposer.build_ebnf(
+            tools,
+            individual_call_start_token=self.bot_token,
+            individual_call_end_token=self.eot_token,
+            # GLM4Moe is not compatible with multiple tool_calls under tool_choice condition: it will output unlimited tool_calls...
+ # tool_call_separator="\\n", + function_format="xml", + call_rule_fmt='"{name}" "\\n" {arguments_rule} "\\n"', + key_value_rule_fmt='"{key}" "\\n" "" {valrule} ""', + key_value_separator="\\n", + ) diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py new file mode 100644 index 000000000..9716557f4 --- /dev/null +++ b/python/sglang/srt/models/glm4_moe.py @@ -0,0 +1,1034 @@ +# Copyright 2025-2026 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Inference-only GLM-4.5 model compatible with HuggingFace weights""" + +import logging +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + parallel_state, + tensor_model_parallel_all_reduce, +) +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.amx_utils import PackWeightMethod +from sglang.srt.layers.communicator import ( + LayerCommunicator, + LayerScatterModes, + enable_moe_dense_fully_dp, +) +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, + get_local_attention_dp_size, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.layer import ( + DeepEPMoE, + get_moe_impl_class, + use_flashinfer_trtllm_moe, +) +from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8_kernel import ( + is_fp8_fnuz, + per_tensor_quant_mla_fp8, + per_token_group_quant_mla_deep_gemm_masked_fp8, +) +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.deepseek_v2 import ( + DeepseekV2DecoderLayer, + DeepseekV2ForCausalLM, + DeepseekV2Model, + DeepseekV2MoE, +) +from sglang.srt.two_batch_overlap import ( + MaybeTboDeepEPDispatcher, + model_forward_maybe_tbo, +) +from sglang.srt.utils import ( + BumpAllocator, + DeepEPMode, + LazyValue, + add_prefix, + bind_or_assign, + cpu_has_amx_support, + get_bool_env_var, + get_device_sm, + get_int_env_var, + is_cpu, + is_cuda, + is_flashinfer_available, + is_hip, + 
is_non_idle_and_non_empty, + log_info_on_rank0, + use_intel_amx_backend, +) + +_is_hip = is_hip() +_is_cuda = is_cuda() +_is_fp8_fnuz = is_fp8_fnuz() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() +_device_sm = get_device_sm() + +if _is_cuda: + from sgl_kernel import dsv3_router_gemm +elif _is_cpu and _is_cpu_amx_available: + pass + +logger = logging.getLogger(__name__) + + +class Glm4MoeMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + ) -> None: + super().__init__() + self.tp_size = tp_size + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + tp_rank=tp_rank, + tp_size=tp_size, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=add_prefix("down_proj", prefix), + tp_rank=tp_rank, + tp_size=tp_size, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False): + if (self.tp_size == 1) and x.shape[0] == 0: + return x + + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce) + return x + + +class Glm4MoeAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + partial_rotary_factor: float = 0.5, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-05, + attention_bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + use_qk_norm: bool = False, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() + + self.total_num_heads = num_heads + assert self.total_num_heads % attn_tp_size == 0 + self.num_heads = self.total_num_heads // attn_tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= attn_tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % attn_tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert attn_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.use_qk_norm = use_qk_norm + self.max_position_embeddings = max_position_embeddings + self.tp_rank = get_tensor_model_parallel_rank() + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=attention_bias, + quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + prefix=add_prefix("qkv_proj", prefix), + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + reduce_results=False, + prefix=add_prefix("o_proj", prefix), + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + partial_rotary_factor=partial_rotary_factor, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + prefix=add_prefix("attn", prefix), + ) + + if self.use_qk_norm: + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.alt_stream = alt_stream + + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # overlap qk norm + if self.alt_stream is not None and get_is_capture_mode(): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.q_norm(q_by_head) + with torch.cuda.stream(self.alt_stream): + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + current_stream.wait_stream(self.alt_stream) + else: + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.q_norm(q_by_head) + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + q = q_by_head.view(q.shape) + k = k_by_head.view(k.shape) + return q, k + + def op_prepare(self, state): + state.attn_intermediate_state = self.forward_prepare( + positions=state.positions, + hidden_states=state.pop("hidden_states_after_comm_pre_attn"), + forward_batch=state.forward_batch, + ) + + def op_core(self, state): + state.hidden_states_after_attn = self.forward_core( + state.pop("attn_intermediate_state") + ) + + def forward_prepare( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ): + if hidden_states.shape[0] == 0: + return hidden_states, forward_batch, None + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_qk_norm: + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + inner_state = q, k, v, forward_batch + return None, forward_batch, inner_state + + def forward_core(self, intermediate_state): + hidden_states, forward_batch, inner_state = intermediate_state + if inner_state is None: + return hidden_states + attn_output = self.attn(*inner_state) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: 
ForwardBatch, + ) -> torch.Tensor: + s = self.forward_prepare( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + return self.forward_core(s) + + +class Glm4MoeGate(nn.Module): + def __init__( + self, + config, + prefix: str = "", + is_nextn: bool = False, + ): + super().__init__() + self.is_nextn = is_nextn + self.weight = nn.Parameter( + torch.empty((config.n_routed_experts, config.hidden_size)) + ) + self.e_score_correction_bias = nn.Parameter( + torch.empty((config.n_routed_experts)) + ) + if _is_cpu and _is_cpu_amx_available: + self.quant_method = PackWeightMethod(weight_names=["weight"]) + + def forward(self, hidden_states): + if use_intel_amx_backend(self): + return torch.ops.sgl_kernel.weight_packed_linear( + hidden_states, + self.weight, + None, # bias + True, # is_vnni + ) + + # NOTE: For some unknown reason, router_gemm seems degrade accept length. + if ( + _is_cuda + and not self.is_nextn + and hidden_states.shape[0] < 4 + and hidden_states.shape[1] == 7168 + and self.weight.shape[0] == 256 + and _device_sm >= 90 + ): + logits = dsv3_router_gemm(hidden_states, self.weight).to( + hidden_states.dtype + ) + else: + logits = F.linear(hidden_states, self.weight, None) + + return logits + + +class Glm4MoeSparseMoeBlock(DeepseekV2MoE): + def __init__( + self, + config: PretrainedConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + is_nextn: bool = False, + ): + nn.Module.__init__(self) + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + self.num_fused_shared_experts = ( + 0 + if global_server_args_dict["disable_shared_experts_fusion"] + else config.n_shared_experts + ) + self.config = config + self.layer_id = layer_id + self.alt_stream = alt_stream + + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}." + ) + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) + + self.gate = Glm4MoeGate( + config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn + ) + + self.topk = ( + TopK( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, + num_fused_shared_experts=self.num_fused_shared_experts, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + routed_scaling_factor=self.routed_scaling_factor, + ) + if not use_flashinfer_trtllm_moe + else None + ) + + self.experts = get_moe_impl_class()( + num_experts=config.n_routed_experts + + self.num_fused_shared_experts + + global_server_args_dict["ep_num_redundant_experts"], + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + layer_id=self.layer_id, + quant_config=quant_config, + routed_scaling_factor=self.routed_scaling_factor, + prefix=add_prefix("experts", prefix), + **( + dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]]) + if global_server_args_dict["enable_deepep_moe"] + else {} + ), + # Additional args for FusedMoE + **( + dict( + enable_flashinfer_cutlass_moe=True, + enable_ep_moe=global_server_args_dict["enable_ep_moe"], + ) + if global_server_args_dict["enable_flashinfer_cutlass_moe"] + else {} + ), + **( + dict( + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, + num_fused_shared_experts=self.num_fused_shared_experts, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + ) + if use_flashinfer_trtllm_moe + else {} + ), + ) + + self.shared_experts_is_int8 = False + self.shared_experts_is_fp8 = False + # self.shared_experts_weight_block_size = None + if config.n_shared_experts is not None and self.num_fused_shared_experts == 0: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=add_prefix("shared_experts", prefix), + **( + dict(tp_rank=0, tp_size=1) + if global_server_args_dict["enable_deepep_moe"] + else {} + ), + ) + is_packed_weight = hasattr( + self.shared_experts.gate_up_proj.quant_method, "quant_config" + ) + self.shared_experts_is_int8 = ( + not is_packed_weight + and self.shared_experts.gate_up_proj.weight.dtype == torch.int8 + ) + self.shared_experts_is_fp8 = ( + not is_packed_weight + and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn + ) + + self.top_k = config.num_experts_per_tok + + if global_server_args_dict["enable_deepep_moe"]: + # TODO: we will support tp < ep in the future + self.ep_size = get_tensor_model_parallel_world_size() + self.num_experts = ( + config.n_routed_experts + + global_server_args_dict["ep_num_redundant_experts"] + ) + self.renormalize = config.norm_topk_prob + self.topk_group = config.topk_group + self.num_expert_group = config.n_group + self.correction_bias = ( + self.gate.e_score_correction_bias.data + if self.gate.e_score_correction_bias is not None + else None + ) + + self.deepep_dispatcher = MaybeTboDeepEPDispatcher( + group=parallel_state.get_tp_group().device_group, + router_topk=self.top_k, + permute_fusion=True, + num_experts=self.num_experts, + num_local_experts=config.n_routed_experts // self.tp_size, + hidden_size=config.hidden_size, + 
params_dtype=config.torch_dtype, + deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], + async_finish=True, + return_recv_hook=True, + ) + + self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"] + + +class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer): + def __init__( + self, + config: PretrainedConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + is_nextn: bool = False, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + self.config = config + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + rms_norm_eps = config.rms_norm_eps + attention_bias = config.attention_bias + self.enable_dp_attention = global_server_args_dict["enable_dp_attention"] + self.layer_id = layer_id + self.self_attn = Glm4MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + max_position_embeddings=max_position_embeddings, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + attention_bias=attention_bias, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + use_qk_norm=config.use_qk_norm, + ) + + self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn) + is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False) + + num_layers = 1 if is_nextn else config.num_hidden_layers + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=num_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + ) + + if self.is_layer_sparse: + self.mlp = Glm4MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + layer_id=self.layer_id, + ) + else: + if enable_moe_dense_fully_dp(): + mlp_tp_rank, mlp_tp_size = 0, 1 + else: + mlp_tp_rank, mlp_tp_size = None, None + self.mlp = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + tp_rank=mlp_tp_rank, + tp_size=mlp_tp_size, + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + zero_allocator: BumpAllocator, + ) -> torch.Tensor: + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch + ) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + hidden_states, residual = self.layer_communicator.prepare_mlp( + hidden_states, residual, forward_batch + ) + + 
hidden_states = self.mlp(hidden_states, forward_batch) + + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) + + return hidden_states, residual + + +class Glm4MoeModel(DeepseekV2Model): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.padding_id = config.pad_token_id + self.vocab_size = config.vocab_size + self.first_k_dense_replace = config.first_k_dense_replace + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not global_server_args_dict["enable_dp_attention"], + ) + self.alt_stream = torch.cuda.Stream() if _is_cuda else None + self.layers = nn.ModuleList( + [ + Glm4MoeDecoderLayer( + config, + layer_id, + quant_config=quant_config, + prefix=add_prefix(f"layers.{layer_id}", prefix), + alt_stream=self.alt_stream, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.dp_size = get_local_attention_dp_size() + + +class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + config.moe_layer_freq = 1 + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.quant_config = quant_config + self.determine_num_fused_shared_experts("Glm4MoeForCausalLM") + self.model = Glm4MoeModel( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + self.logits_processor = LogitsProcessor(config) + self.dp_size = get_local_attention_dp_size() + + self._routed_experts_weights_of_layer = LazyValue( + lambda: { + layer_id: layer.mlp.get_moe_weights() + for layer_id, layer in enumerate(self.model.layers) + if isinstance(layer.mlp, DeepseekV2MoE) + } + ) + + def determine_num_fused_shared_experts( + self, architecture: str = "DeepseekV3ForCausalLM" + ): + self.num_fused_shared_experts = 0 + if global_server_args_dict["disable_shared_experts_fusion"]: + return + + # Only Deepseek V3/R1 can use shared experts fusion optimization now. + disable_reason = None + if ( + not _is_cuda + or torch.cuda.get_device_capability("cuda") < (8, 0) + or self.config.architectures[0] != architecture + or self.config.n_routed_experts != 128 + or self.config.n_shared_experts != 1 + ): + disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization." + elif ( + global_server_args_dict["enable_deepep_moe"] + or global_server_args_dict["enable_ep_moe"] + ): + disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode." 
+ + if disable_reason is not None: + global_server_args_dict["disable_shared_experts_fusion"] = True + log_info_on_rank0( + logger, + f"{disable_reason} Shared experts fusion optimization is disabled.", + ) + return + + self.num_fused_shared_experts = self.config.n_shared_experts + + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False): + + if is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + assert num_nextn_layers == 1, "Only 1 nextn layer is supported" + # compatible with old design + nextn_layer_id = ( + 0 + if self.config.num_hidden_layers == 1 + else self.config.num_hidden_layers + ) + else: + raise ValueError("num_nextn_predict_layers is not in the config") + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + if self.num_fused_shared_experts > 0: + assert self.num_fused_shared_experts == 1 + weights_list = list(weights) + weights_dict = dict(weights_list) + if self.quant_config is not None: + if self.quant_config.get_name() == "w8a8_int8": + suffix_list = [ + "down_proj.weight", + "down_proj.weight_scale", + "gate_proj.weight", + "gate_proj.weight_scale", + "up_proj.weight", + "up_proj.weight_scale", + ] + elif ( + self.quant_config.get_name() == "fp8" + or self.quant_config.get_name() == "blockwise_int8" + ): + suffix_list = [ + "down_proj.weight", + "down_proj.weight_scale", + "gate_proj.weight", + "gate_proj.weight_scale", + "up_proj.weight", + "up_proj.weight_scale", + ] + elif self.quant_config.get_name() == "awq": + suffix_list = [ + "down_proj.qweight", + "down_proj.qzeros", + "down_proj.scales", + "gate_proj.qweight", + "gate_proj.qzeros", + "gate_proj.scales", + "up_proj.qweight", + "up_proj.qzeros", + "up_proj.scales", + ] + elif self.quant_config.get_name() == "modelopt_fp4": + suffix_list = [ + "down_proj.weight", + "down_proj.weight_scale", + "down_proj.weight_scale_2", + "down_proj.input_scale", + "gate_proj.weight", + "gate_proj.weight_scale", + "gate_proj.weight_scale_2", + "gate_proj.input_scale", + "up_proj.weight", + "up_proj.weight_scale", + "up_proj.weight_scale_2", + "up_proj.input_scale", + ] + else: + raise ValueError( + f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}." + ) + else: + suffix_list = [ + "down_proj.weight", + "gate_proj.weight", + "up_proj.weight", + ] + names_to_remove = [] + + moe_layers = ( + range( + self.config.first_k_dense_replace, + self.config.num_hidden_layers, + self.config.moe_layer_freq, + ) + if not is_nextn + else [nextn_layer_id] + ) + + for moe_layer in moe_layers: + for suffix in suffix_list: + shared_expert_weight_name = ( + f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}" + ) + # online fp8 quantization does not load weight_scale + if shared_expert_weight_name not in weights_dict: + continue + weights_list.append( + ( + f"model.layers.{moe_layer}." + f"mlp.experts." 
+ f"{self.config.n_routed_experts + 0}" + f".{suffix}", + weights_dict[shared_expert_weight_name], + ) + ) + names_to_remove += [shared_expert_weight_name] + weights = [w for w in weights_list if w[0] not in names_to_remove] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, + ) + + # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None + fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and ( + self.config.q_lora_rank is not None + ) + cached_a_proj = {} if fuse_qkv_a_proj else None + + if is_nextn: + nextn_layer_prefix = f"model.layers.{nextn_layer_id}" + nextn_spec_weight_names = [ + "shared_head.norm", + "eh_proj", + "enorm", + "hnorm", + ] + + params_dict = dict(self.named_parameters()) + weight_names = [] + for name, loaded_weight in weights: + weight_names.append(name) + + if not is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + if num_nextn_layers > 0 and name.startswith("model.layers"): + name_list = name.split(".") + if ( + len(name_list) >= 3 + and int(name_list[2]) >= self.config.num_hidden_layers + ): + continue + else: + if not name.startswith(nextn_layer_prefix): + continue + + # Use shared head and embed weights from target model + if "shared_head.head" in name or "embed_tokens" in name: + continue + + is_decoder = True + # For nextn specific weights + for weight_name in nextn_spec_weight_names: + if weight_name in name: + name = name.replace(nextn_layer_prefix, "model") + is_decoder = False + break + # For decoder layer weights + if is_decoder: + name = name.replace(nextn_layer_prefix, "model.decoder") + + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if fuse_qkv_a_proj and ( + "q_a_proj" in name or "kv_a_proj_with_mqa" in name + ): + cached_a_proj[name] = loaded_weight + q_a_proj_name = ( + name + if "q_a_proj" in name + else name.replace("kv_a_proj_with_mqa", "q_a_proj") + ) + kv_a_proj_name = ( + name + if "kv_a_proj_with_mqa" in name + else name.replace("q_a_proj", "kv_a_proj_with_mqa") + ) + + # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter + if ( + q_a_proj_name in cached_a_proj + and kv_a_proj_name in cached_a_proj + ): + q_a_proj_weight = cached_a_proj[q_a_proj_name] + kv_a_proj_weight = cached_a_proj[kv_a_proj_name] + fused_weight = torch.cat( + [q_a_proj_weight, kv_a_proj_weight], dim=0 + ) + param_name = ( + name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa") + if "q_a_proj" in name + else name.replace( + "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa" + ) + ) + param = params_dict[param_name] + + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, fused_weight) + cached_a_proj.pop(q_a_proj_name) + cached_a_proj.pop(kv_a_proj_name) + else: + if ( + "k_scale" in name or "v_scale" in name + ) and name not in params_dict: + # modelopt attn kv scale is named differently + if any(scale in name for scale in ["k_scale", "v_scale"]): + name = name.replace("_proj", "attn_mqa") + else: + logger.warning( + f"Unknown scale found in checkpoint: {name}" + ) + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + +EntryClass = [Glm4MoeForCausalLM] diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py new file mode 100644 index 000000000..1a0793d8a --- /dev/null +++ b/python/sglang/srt/models/glm4_moe_nextn.py @@ -0,0 +1,167 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Inference-only GLM-4.5 NextN Speculative Decoding.""" +import logging +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM +from sglang.srt.utils import BumpAllocator, add_prefix + +logger = logging.getLogger(__name__) + + +class Glm4MoeModelNextN(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if quant_config is not None and quant_config.get_name() == "modelopt_fp4": + logger.warning( + "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model." + ) + quant_config = None + + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not global_server_args_dict["enable_dp_attention"], + prefix=add_prefix("embed_tokens", prefix), + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) + + self.decoder = Glm4MoeDecoderLayer( + config, + 0, + quant_config=quant_config, + is_nextn=True, + prefix=add_prefix("decoder", prefix), + ) + + self.shared_head = nn.Module() + self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + zero_allocator = BumpAllocator( + buffer_size=2, + dtype=torch.float32, + device=( + input_embeds.device if input_embeds is not None else input_ids.device + ), + ) + + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + if hidden_states.shape[0] > 0: + hidden_states = self.eh_proj( + torch.cat( + ( + self.enorm(hidden_states), + self.hnorm(forward_batch.spec_info.hidden_states), + ), + dim=-1, + ) + ) + + residual = None + with get_global_expert_distribution_recorder().disable_this_region(): + hidden_states, residual = self.decoder( + positions, hidden_states, forward_batch, residual, zero_allocator + ) + + if not forward_batch.forward_mode.is_idle(): + if residual is not None: + hidden_states, _ = self.shared_head.norm(hidden_states, residual) + else: + hidden_states = self.shared_head.norm(hidden_states) + + return hidden_states + + +class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.config = config + self.tp_size = 
get_tensor_model_parallel_world_size() + self.quant_config = quant_config + self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN") + + self.model = Glm4MoeModelNextN( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("model.shared_head.head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, forward_batch) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + super().load_weights(weights, is_nextn=True) + + +EntryClass = [Glm4MoeForCausalLMNextN] diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py index e51ca5b61..b5b737856 100644 --- a/python/sglang/srt/reasoning_parser.py +++ b/python/sglang/srt/reasoning_parser.py @@ -231,6 +231,7 @@ class ReasoningParser: "deepseek-r1": DeepSeekR1Detector, "qwen3": Qwen3Detector, "qwen3-thinking": Qwen3ThinkingDetector, + "glm45": Qwen3Detector, "kimi": KimiDetector, } diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b0e6fbab3..54dc76ed7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -513,7 +513,7 @@ class ServerArgs: ) model_arch = self.get_hf_config().architectures[0] - if model_arch == "DeepseekV3ForCausalLM": + if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]: # Auto set draft_model_path DeepSeek-V3/R1 if self.speculative_draft_model_path is None: self.speculative_draft_model_path = self.model_path @@ -1108,6 +1108,7 @@ class ServerArgs: "pythonic", "kimi_k2", "qwen3_coder", + "glm45", ], default=ServerArgs.tool_call_parser, help="Specify the parser for handling tool-call interactions. 
Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.", diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 29bb18b08..f824a006a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2343,6 +2343,7 @@ def is_fa3_default_architecture(hf_config): "Gemma3ForConditionalGeneration", "Qwen3ForCausalLM", "Qwen3MoeForCausalLM", + "Glm4MoeForCausalLM", } return architectures[0] in default_archs diff --git a/test/srt/openai_server/features/test_enable_thinking.py b/test/srt/openai_server/features/test_enable_thinking.py index 37fb6ca7c..78354673c 100644 --- a/test/srt/openai_server/features/test_enable_thinking.py +++ b/test/srt/openai_server/features/test_enable_thinking.py @@ -43,6 +43,7 @@ class TestEnableThinking(CustomTestCase): "qwen3", ], ) + cls.additional_chat_kwargs = {} @classmethod def tearDownClass(cls): @@ -59,6 +60,7 @@ class TestEnableThinking(CustomTestCase): "temperature": 0, "separate_reasoning": True, "chat_template_kwargs": {"enable_thinking": True}, + **self.additional_chat_kwargs, }, ) @@ -82,6 +84,7 @@ class TestEnableThinking(CustomTestCase): "temperature": 0, "separate_reasoning": True, "chat_template_kwargs": {"enable_thinking": False}, + **self.additional_chat_kwargs, }, ) @@ -107,6 +110,7 @@ class TestEnableThinking(CustomTestCase): "separate_reasoning": True, "stream": True, "chat_template_kwargs": {"enable_thinking": True}, + **self.additional_chat_kwargs, }, stream=True, ) @@ -151,6 +155,7 @@ class TestEnableThinking(CustomTestCase): "separate_reasoning": True, "stream": True, "chat_template_kwargs": {"enable_thinking": False}, + **self.additional_chat_kwargs, }, stream=True, ) @@ -184,5 +189,55 @@ class TestEnableThinking(CustomTestCase): ) +## Skip for ci test +# class TestGLM45EnableThinking(TestEnableThinking): +# @classmethod +# def setUpClass(cls): +# # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST +# cls.model = "THUDM/GLM-4.5" +# cls.base_url = DEFAULT_URL_FOR_TEST +# cls.api_key = "sk-1234" +# cls.process = popen_launch_server( +# cls.model, +# cls.base_url, +# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, +# api_key=cls.api_key, +# other_args=[ +# "--tool-call-parser", +# "glm45", +# "--reasoning-parser", +# "glm45", +# "--tp-size", +# "8" +# ], +# ) + +# # Validate whether enable-thinking conflict with tool_calls +# cls.additional_chat_kwargs = { +# "tools": [ +# { +# "type": "function", +# "function": { +# "name": "add", +# "description": "Compute the sum of two numbers", +# "parameters": { +# "type": "object", +# "properties": { +# "a": { +# "type": "int", +# "description": "A number", +# }, +# "b": { +# "type": "int", +# "description": "A number", +# }, +# }, +# "required": ["a", "b"], +# }, +# }, +# } +# ] +# } + if __name__ == "__main__": unittest.main() diff --git a/test/srt/openai_server/function_call/test_openai_function_calling.py b/test/srt/openai_server/function_call/test_openai_function_calling.py index 2486cc050..1d687eb7f 100644 --- a/test/srt/openai_server/function_call/test_openai_function_calling.py +++ b/test/srt/openai_server/function_call/test_openai_function_calling.py @@ -223,7 +223,10 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): messages = [ {"role": "system", "content": self.SYSTEM_MESSAGE}, - {"role": "user", "content": "What is the temperature in Paris?"}, + { + "role": "user", + "content": "What is the temperature in Paris in celsius??", + }, ] response_stream = 
client.chat.completions.create( @@ -910,5 +913,40 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase): ) +## Skip for ci test +# class TestGLM45ServerFunctionCalling(TestOpenAIServerFunctionCalling): +# @classmethod +# def setUpClass(cls): +# # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST +# cls.model = "THUDM/GLM-4.5" +# cls.base_url = DEFAULT_URL_FOR_TEST +# cls.api_key = "sk-123456" + +# # Start the local OpenAI Server. If necessary, you can add other parameters such as --enable-tools. +# cls.process = popen_launch_server( +# cls.model, +# cls.base_url, +# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, +# api_key=cls.api_key, +# other_args=[ +# # If your server needs extra parameters to test function calling, please add them here. +# "--tool-call-parser", +# "glm45", +# "--reasoning-parser", +# "glm45", +# "--tp-size", +# "8" +# ], +# ) +# cls.base_url += "/v1" +# cls.tokenizer = get_tokenizer(cls.model) + +# # This test is too difficult for GLM4-moe. Skip it from the UT +# def test_function_call_required(self): +# pass + +# def test_function_calling_multiturn(self): +# self._test_function_calling_multiturn() + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index 511020651..32b7e4a5b 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -6,6 +6,7 @@ from xgrammar import GrammarCompiler, TokenizerInfo from sglang.srt.entrypoints.openai.protocol import Function, Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.kimik2_detector import KimiK2Detector from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mistral_detector import MistralDetector @@ -510,6 +511,7 @@ class TestEBNFGeneration(unittest.TestCase): self.qwen25_detector = Qwen25Detector() self.qwen3_coder_detector = Qwen3CoderDetector() self.kimik2_detector = KimiK2Detector() + self.glm45_detector = Glm4MoeDetector() def test_pythonic_detector_ebnf(self): """Test that the PythonicDetector generates valid EBNF.""" @@ -622,6 +624,29 @@ class TestEBNFGeneration(unittest.TestCase): except RuntimeError as e: self.fail(f"Failed to compile EBNF: {e}") + def test_glm45_detector_ebnf(self): + """Test that the Glm4MoeDetector generates valid EBNF.""" + ebnf = self.glm45_detector.build_ebnf(self.tools) + self.assertIsNotNone(ebnf) + # Check that the EBNF contains expected patterns for XML format + self.assertIn('"" function_call ""', ebnf) + self.assertIn('"get_weather" "\\n" arguments_get_weather', ebnf) + self.assertIn( + '"location" "\\n" "" xml_text "" ( "\\n" ( "unit" "\\n" "" ("celsius" | "fahrenheit") "" ) )?', + ebnf, + ) + self.assertIn('"search" "\\n" arguments_search', ebnf) + self.assertIn( + '"query" "\\n" "" xml_text ""', + ebnf, + ) + # Validate that the EBNF can be compiled by GrammarCompiler + try: + ctx = self.grammar_compiler.compile_grammar(ebnf) + self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully") + except RuntimeError as e: + self.fail(f"Failed to compile EBNF: {e}") + def test_qwen3_coder_detector_ebnf(self): """Test that the Qwen3CoderDetector generates valid EBNF.""" ebnf = self.qwen3_coder_detector.build_ebnf(self.tools) @@ -1919,5 +1944,164 @@ circle 
self.assertEqual(params2["dimensions"], {"radius": 5}) +class TestGlm4MoeDetector(unittest.TestCase): + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date"}, + }, + "required": ["city", "date"], + }, + ), + ), + ] + self.detector = Glm4MoeDetector() + + def test_single_tool_call(self): + text = ( + "get_weather\n" + "city\nBeijing\n" + "date\n2024-06-27\n" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(result.normal_text, "") + + def test_multiple_tool_calls(self): + text = ( + "get_weather\n" + "city\nBeijing\n" + "date\n2024-06-27\n" + "" + "get_weather\n" + "city\nShanghai\n" + "date\n2024-06-28\n" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(result.calls[1].name, "get_weather") + self.assertEqual( + result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}' + ) + self.assertEqual(result.normal_text, "") + + def test_streaming_tool_call(self): + """Test streaming incremental parsing of a tool call.""" + chunks = [ + "get_weather\n", + "city\nBeijing\n", + "date\n2024-06-27\n", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": {}}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] = tool_call_chunk.parameters + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + + def test_streaming_multiple_tool_calls(self): + """Test streaming incremental parsing of multiple tool calls.""" + chunks = [ + "get_weather\n", + "city\nBeijing\n", + "date\n2024-06-27\n", + "get_weather\n", + "city\nShanghai\n", + "date\n2024-06-28\n", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": {}}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] = tool_call_chunk.parameters + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(tool_calls[1]["name"], "get_weather") 
+ self.assertEqual( + tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}' + ) + + def test_tool_call_completion(self): + """Test that the buffer and state are reset after a tool call is completed.""" + chunks = [ + "get_weather\n", + "city\nBeijing\n", + "date\n2024-06-27\n", + "", + ] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + self.assertEqual(self.detector.current_tool_id, 1) + + def test_invalid_tool_call(self): + """Test that invalid tool calls are handled correctly.""" + text = "invalid_func\ncity\nBeijing\n" + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 0) + + def test_partial_tool_call(self): + """Test parsing a partial tool call that spans multiple chunks.""" + text1 = "get_weather\ncity\n" + result1 = self.detector.parse_streaming_increment(text1, self.tools) + self.assertEqual(result1.normal_text, "") + self.assertEqual(result1.calls, []) + self.assertEqual(self.detector._buffer, text1) + text2 = "Beijing\ndate\n2024-06-27\n" + result2 = self.detector.parse_streaming_increment(text2, self.tools) + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "get_weather") + self.assertEqual( + result2.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(self.detector._buffer, "") + + if __name__ == "__main__": unittest.main()
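
For reference, a minimal usage sketch mirroring the unit tests above. It assumes the sglang package from this PR is importable; the same detector is what the new "glm45" entry in FunctionCallParser and the `--tool-call-parser glm45` server flag select. The `<tool_call>`/`<arg_key>`/`<arg_value>` markup is the GLM-4.5 tool-call format that the detector's regexes target.

```python
# Sketch: parse one GLM-4.5 style tool call with the new Glm4MoeDetector.
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector

tools = [
    Tool(
        type="function",
        function=Function(
            name="get_weather",
            description="Get weather information",
            parameters={
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "City name"},
                    "date": {"type": "string", "description": "Date"},
                },
                "required": ["city", "date"],
            },
        ),
    ),
]

# GLM-4.5 emits the function name, then key/value pairs, inside <tool_call> tags.
text = (
    "<tool_call>get_weather\n"
    "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n"
    "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n"
    "</tool_call>"
)

detector = Glm4MoeDetector()
result = detector.detect_and_parse(text, tools)
print(result.calls[0].name)        # get_weather
print(result.calls[0].parameters)  # {"city": "Beijing", "date": "2024-06-27"}
```

Server-side, the equivalent behavior is enabled with `--tool-call-parser glm45` and `--reasoning-parser glm45`, as registered in server_args.py and reasoning_parser.py above.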