GLM-4.5 Model Support (#8224)

Co-authored-by: Lifu Huang <lifu.hlf@gmail.com>
Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
This commit is contained in:
Yuxuan Zhang
2025-07-28 13:54:07 +08:00
committed by GitHub
parent 2fd5c7049f
commit 6d6a8bc278
14 changed files with 1673 additions and 7 deletions

View File

@@ -127,6 +127,9 @@ class ModelConfig:
):
self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
self.hf_config.architectures[0] = "MiMoMTP"
# Check model type

View File

@@ -165,6 +165,7 @@ class EBNFComposer:
tool_call_separator: Optional[str] = None,
call_rule_fmt: Optional[str] = None,
key_value_rule_fmt: Optional[str] = None,
key_value_separator: str = ",",
):
"""
Generalized EBNF builder for all detectors.
@@ -279,7 +280,11 @@ class EBNFComposer:
# Add required properties joined by commas
if required:
rule_parts.append(' "," '.join(prop_kv_pairs[k] for k in required))
rule_parts.append(
f' "{key_value_separator}" '.join(
prop_kv_pairs[k] for k in required
)
)
# Add optional properties with flexible ordering
if optional:
@@ -292,13 +297,15 @@ class EBNFComposer:
if j == i:
opt_parts.append(prop_kv_pairs[optional[j]])
else:
opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
opt_parts.append(
f' ( "{key_value_separator}" {prop_kv_pairs[optional[j]]} )?'
)
opt_alternatives.append("".join(opt_parts))
# Wrap with appropriate comma handling based on whether we have required properties
if required:
# Required properties exist, so optional group needs outer comma
rule_parts.append(' ( "," ( ')
rule_parts.append(f' ( "{key_value_separator}" ( ')
rule_parts.append(" | ".join(opt_alternatives))
rule_parts.append(" ) )?")
else:

View File

@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import (
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
@@ -37,6 +38,7 @@ class FunctionCallParser:
"pythonic": PythonicDetector,
"kimi_k2": KimiK2Detector,
"qwen3_coder": Qwen3CoderDetector,
"glm45": Glm4MoeDetector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):

View File

@@ -0,0 +1,165 @@
import ast
import json
import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
logger = logging.getLogger(__name__)
def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
name2tool = {tool.function.name: tool for tool in defined_tools}
if func_name not in name2tool:
return None
tool = name2tool[func_name]
if arg_key not in tool.function.parameters["properties"]:
return None
return tool.function.parameters["properties"][arg_key].get("type", None)
def parse_arguments(json_value):
try:
try:
parsed_value = json.loads(json_value)
except:
parsed_value = ast.literal_eval(json_value)
return parsed_value, True
except:
return json_value, False
class Glm4MoeDetector(BaseFormatDetector):
"""
Detector for GLM-4.5 models.
Assumes function call format:
<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
"""
def __init__(self):
super().__init__()
self.bot_token = "<tool_call>"
self.eot_token = "</tool_call>"
self.func_call_regex = r"<tool_call>.*?</tool_call>"
self.func_detail_regex = r"<tool_call>([^\n]*)\n(.*)</tool_call>"
self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a glm-4.5 format tool call."""
return self.bot_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
:param text: The complete text to parse.
:param tools: List of available tools.
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
"""
idx = text.find(self.bot_token)
normal_text = text[:idx].strip() if idx != -1 else text
if self.bot_token not in text:
return StreamingParseResult(normal_text=normal_text, calls=[])
match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
calls = []
try:
for match_result in match_result_list:
# Get function name
func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
func_name = func_detail.group(1)
func_args = func_detail.group(2)
pairs = re.findall(
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
func_args,
re.DOTALL,
)
arguments = {}
for arg_key, arg_value in pairs:
arg_key = arg_key.strip()
arg_value = arg_value.strip()
arg_type = get_argument_type(func_name, arg_key, tools)
if arg_type != "string":
arg_value, is_good_json = parse_arguments(arg_value)
arguments[arg_key] = arg_value
# construct match_result for parse_base_json
match_result = {"name": func_name, "parameters": arguments}
calls.extend(self.parse_base_json(match_result, tools))
return StreamingParseResult(normal_text=normal_text, calls=calls)
except Exception as e:
logger.error(f"Error in detect_and_parse: {e}")
# return the normal text if parsing fails
return StreamingParseResult(normal_text=text)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing tool calls for GLM-4.5 format.
"""
self._buffer += new_text
current_text = self._buffer
start = current_text.find(self.bot_token)
if start == -1:
self._buffer = ""
if self.current_tool_id > 0:
current_text = ""
return StreamingParseResult(normal_text=current_text)
# find ensures we find the first self.eot_token so there will be at most one tool_call in current_text[:end+len(self.eot_token)
end = current_text.find(self.eot_token)
if end != -1:
# Initialize state if this is the first tool call
if self.current_tool_id == -1:
self.current_tool_id = 0
self.prev_tool_call_arr = []
self.streamed_args_for_tool = [""]
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
result = self.detect_and_parse(
current_text[: end + len(self.eot_token)], tools=tools
)
if result.calls:
self.prev_tool_call_arr[self.current_tool_id] = {
"name": result.calls[0].name,
"arguments": json.loads(result.calls[0].parameters),
}
self.streamed_args_for_tool[self.current_tool_id] = result.calls[
0
].parameters
result.calls[0].tool_index = self.current_tool_id
self.current_tool_id += 1
self._buffer = current_text[end + len(self.eot_token) :]
return result
normal_text = current_text[:start]
self._buffer = current_text[start:]
return StreamingParseResult(normal_text=normal_text)
def supports_structural_tag(self) -> bool:
return False
def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError()
def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
individual_call_start_token=self.bot_token,
individual_call_end_token=self.eot_token,
# GLM4Moe is not compatible with multiple tool_calls under tool_choice condition: it will output unlimited tool_calls...
# tool_call_separator="\\n",
function_format="xml",
call_rule_fmt='"{name}" "\\n" {arguments_rule} "\\n"',
key_value_rule_fmt='"<arg_key>{key}</arg_key>" "\\n" "<arg_value>" {valrule} "</arg_value>"',
key_value_separator="\\n",
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,167 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only GLM-4.5 NextN Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
from sglang.srt.utils import BumpAllocator, add_prefix
logger = logging.getLogger(__name__)
class Glm4MoeModelNextN(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
logger.warning(
"Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
)
quant_config = None
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
prefix=add_prefix("embed_tokens", prefix),
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
self.decoder = Glm4MoeDecoderLayer(
config,
0,
quant_config=quant_config,
is_nextn=True,
prefix=add_prefix("decoder", prefix),
)
self.shared_head = nn.Module()
self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
buffer_size=2,
dtype=torch.float32,
device=(
input_embeds.device if input_embeds is not None else input_ids.device
),
)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
if hidden_states.shape[0] > 0:
hidden_states = self.eh_proj(
torch.cat(
(
self.enorm(hidden_states),
self.hnorm(forward_batch.spec_info.hidden_states),
),
dim=-1,
)
)
residual = None
with get_global_expert_distribution_recorder().disable_this_region():
hidden_states, residual = self.decoder(
positions, hidden_states, forward_batch, residual, zero_allocator
)
if not forward_batch.forward_mode.is_idle():
if residual is not None:
hidden_states, _ = self.shared_head.norm(hidden_states, residual)
else:
hidden_states = self.shared_head.norm(hidden_states)
return hidden_states
class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
self.model = Glm4MoeModelNextN(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
super().load_weights(weights, is_nextn=True)
EntryClass = [Glm4MoeForCausalLMNextN]

View File

@@ -231,6 +231,7 @@ class ReasoningParser:
"deepseek-r1": DeepSeekR1Detector,
"qwen3": Qwen3Detector,
"qwen3-thinking": Qwen3ThinkingDetector,
"glm45": Qwen3Detector,
"kimi": KimiDetector,
}

View File

@@ -513,7 +513,7 @@ class ServerArgs:
)
model_arch = self.get_hf_config().architectures[0]
if model_arch == "DeepseekV3ForCausalLM":
if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
# Auto set draft_model_path DeepSeek-V3/R1
if self.speculative_draft_model_path is None:
self.speculative_draft_model_path = self.model_path
@@ -1108,6 +1108,7 @@ class ServerArgs:
"pythonic",
"kimi_k2",
"qwen3_coder",
"glm45",
],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",

View File

@@ -2343,6 +2343,7 @@ def is_fa3_default_architecture(hf_config):
"Gemma3ForConditionalGeneration",
"Qwen3ForCausalLM",
"Qwen3MoeForCausalLM",
"Glm4MoeForCausalLM",
}
return architectures[0] in default_archs