Support glm4.1v and glm4.5v (#8798)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
Co-authored-by: Chang Su <csu272@usc.edu>
Binyao Jiang
2025-08-09 00:59:13 -07:00
committed by GitHub
parent faa25df1ae
commit f29aba8c6e
21 changed files with 1584 additions and 19 deletions

View File

@@ -27,6 +27,18 @@ python -m sglang.launch_server --model-path microsoft/Phi-4-multimodal-instruct
python benchmark/mmmu/bench_sglang.py --concurrency 8 --lora-path vision
```
You can use `--response-answer-regex` to specify how to extract the answer from the response string. E.g.,
```
python3 -m sglang.launch_server --model-path zai-org/GLM-4.1V-9B-Thinking --reasoning-parser glm45
python3 bench_sglang.py --response-answer-regex "<\|begin_of_box\|>(.*)<\|end_of_box\|>" --concurrency 64
```
You can use `--extra-request-body` to specify additional OpenAI request parameters. E.g.,
```
python3 bench_sglang.py --extra-request-body '{"max_new_tokens": 128, "temperature": 0.01}'
```
### Evaluate hf
```

View File

@@ -11,6 +11,7 @@ The eval output will be logged
import argparse
import asyncio
import re
import sys
import time
import traceback
@@ -145,7 +146,17 @@ async def eval_mmmu(args) -> None:
_, response = await process_sample(
client, sample, sampling_params, lora_path
)
process_result(response, sample, answer_dict, out_samples)
answer = (
re.search(args.response_answer_regex, response)
if response is not None
else None
)
process_result(
answer.group(1) if answer else response,
sample,
answer_dict,
out_samples,
)
else:
semaphore = asyncio.Semaphore(args.concurrency)
tasks = [
@@ -157,7 +168,17 @@ async def eval_mmmu(args) -> None:
for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
sample, response = await coro
process_result(response, sample, answer_dict, out_samples)
answer = (
re.search(args.response_answer_regex, response)
if response is not None
else None
)
process_result(
answer.group(1) if answer else response,
sample,
answer_dict,
out_samples,
)
if args.profile:
print("Stopping profiler...")

View File

@@ -35,6 +35,7 @@ class EvalArgs:
profile: bool = False
profile_number: int = 5
concurrency: int = 1
response_answer_regex: str = "(.*)"
lora_path: Optional[str] = None
@staticmethod
@@ -92,6 +93,12 @@ class EvalArgs:
default=EvalArgs.concurrency,
help="Number of concurrent requests to make during evaluation. Default is 1, which means no concurrency.",
)
parser.add_argument(
"--response-answer-regex",
type=str,
default=EvalArgs.response_answer_regex,
help="Specific regex to capture the answer from the response, string",
)
parser.add_argument(
"--lora-path",
type=str,

View File

@@ -39,3 +39,4 @@ in the GitHub search bar.
| **Mistral-Small-3.1-24B** | `mistralai/Mistral-Small-3.1-24B-Instruct-2503` | `mistral` | Mistral 3.1 is a multimodal model that can generate text from text or images input. It also supports tool calling and structured output. |
| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | `phi-4-mm` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. It supports text, vision and audio modalities in SGLang. |
| **MiMo-VL** (7B) | `XiaomiMiMo/MiMo-VL-7B-RL` | `mimo-vl` | Xiaomi's compact yet powerful vision-language model featuring a native resolution ViT encoder for fine-grained visual details, an MLP projector for cross-modal alignment, and the MiMo-7B language model optimized for complex reasoning tasks. |
| **GLM-4.5V** (106B) / **GLM-4.1V** (9B) | `zai-org/GLM-4.5V` | `glm-4v` | GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning. |

View File

@@ -505,6 +505,22 @@ register_chat_template(
)
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example
register_chat_template(
ChatTemplate(
name="glm-4v",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<|system|>\n", "\n"),
"user": ("<|user|>\n", "\n"),
"assistant": ("<|assistant|>\n", "\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=["<|user|>", "<|endoftext|>", "<|observation|>"],
image_token="<|image|>",
)
)
@register_chat_template_matching_function
def match_deepseek(model_path: str):
@@ -562,6 +578,8 @@ def match_chat_ml(model_path: str):
return "chatml"
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
return "qwen2-vl"
if re.search(r"glm[-_]?4(\.\d+)?v", model_path, re.IGNORECASE):
return "glm-4v"
if re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE) and not re.search(
r"llava", model_path, re.IGNORECASE
):
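
A quick sketch of which model paths the new matcher accepts (the paths are illustrative):

```
import re

pattern = r"glm[-_]?4(\.\d+)?v"
for path in [
    "zai-org/GLM-4.1V-9B-Thinking",
    "zai-org/GLM-4.5V",
    "THUDM/glm-4v-9b",
    "zai-org/GLM-4-9B",
]:
    print(path, bool(re.search(pattern, path, re.IGNORECASE)))
# The first three match and map to "glm-4v"; the text-only GLM-4-9B does not.
```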

View File

@@ -659,6 +659,8 @@ multimodal_model_archs = [
"DeepseekVL2ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3nForConditionalGeneration",
"Glm4vForConditionalGeneration",
"Glm4vMoeForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
"LlavaLlamaForCausalLM",

View File

@@ -316,6 +316,7 @@ class EBNFComposer:
combined_args = "".join(rule_parts)
arguments_rule = args_template.format(arg_rules=combined_args)
arguments_rule = arguments_rule or '""'
# Add the function call rule and its arguments rule
ebnf_lines.append(

View File

@@ -158,7 +158,7 @@ class Glm4MoeDetector(BaseFormatDetector):
individual_call_end_token=self.eot_token,
tool_call_separator="\\n",
function_format="xml",
call_rule_fmt='"{name}" "\\n" {arguments_rule} "\\n"',
call_rule_fmt='"{name}" "\\n" ( {arguments_rule} "\\n" )?',
key_value_rule_fmt='"<arg_key>{key}</arg_key>" "\\n" "<arg_value>" {valrule} "</arg_value>"',
key_value_separator="\\n",
)
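
The trailing `( ... )?` makes the arguments block optional, so a tool with an empty parameter schema still yields a valid call rule. A minimal sketch of how the template renders (names taken from the tests below):

```
call_rule_fmt = '"{name}" "\\n" ( {arguments_rule} "\\n" )?'
rule = call_rule_fmt.format(
    name="empty_param_func", arguments_rule="arguments_empty_param_func"
)
print(rule)  # "empty_param_func" "\n" ( arguments_empty_param_func "\n" )?
```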

View File

@@ -102,6 +102,12 @@ def detect_jinja_template_content_format(chat_template: str) -> str:
if _is_var_or_elems_access(loop_iter, "message", "content"):
return "openai" # Found content iteration → openai format
# Also check for patterns like: {%- for item in msg.content -%} or {%- for item in m.content -%}
if _is_var_or_elems_access(
loop_iter, "msg", "content"
) or _is_var_or_elems_access(loop_iter, "m", "content"):
return "openai" # Found content iteration → openai format (glm4v)
return "string" # No content loops found → string format
except Exception as e:
logger.debug(f"Error when parsing AST of Jinja template: {e}")

View File

@@ -1,6 +1,7 @@
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py
"""Rotary Positional Embeddings."""
import itertools
import math
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -946,7 +947,37 @@ class MRotaryEmbedding(RotaryEmbedding):
self.mrope_section = mrope_section
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
expected_sum = rotary_dim // 2
actual_sum = sum(self.mrope_section)
if actual_sum != expected_sum:
print(
f"MRoPE section sum mismatch: expected {expected_sum}, got {actual_sum}. "
f"Adjusting mrope_section to match rotary_dim // 2 = {expected_sum}"
)
# Auto-correct by scaling the mrope_section proportionally
if actual_sum > 0:
scale_factor = expected_sum / actual_sum
self.mrope_section = [
max(1, int(section * scale_factor))
for section in self.mrope_section
]
# Ensure the sum exactly matches by adjusting the last element
current_sum = sum(self.mrope_section)
if current_sum != expected_sum:
self.mrope_section[-1] += expected_sum - current_sum
else:
# If all sections are 0, create a default distribution
self.mrope_section = [
expected_sum // len(self.mrope_section)
] * len(self.mrope_section)
# Handle remainder
remainder = expected_sum % len(self.mrope_section)
for i in range(remainder):
self.mrope_section[i] += 1
print(
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
)
def forward(
self,
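
The correction logic in numbers, as a standalone sketch (the section values are made up):

```
expected_sum = 64                 # rotary_dim // 2 for rotary_dim = 128
mrope_section = [16, 24, 28]      # sums to 68, so it must be rescaled
scale = expected_sum / sum(mrope_section)
mrope_section = [max(1, int(s * scale)) for s in mrope_section]  # [15, 22, 26]
mrope_section[-1] += expected_sum - sum(mrope_section)           # -> [15, 22, 27]
assert sum(mrope_section) == expected_sum
```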
@@ -1153,6 +1184,204 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_deltas = max_position_ids + 1 - s
return position_ids, mrope_position_deltas
# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
@staticmethod
def get_rope_index_glm4v(
input_ids: torch.Tensor,
hf_config: Any,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
attention_mask: torch.Tensor,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Get mrope input positions and delta value for GLM4V."""
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_index, video_index = 0, 0
video_group_index = 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
input_tokens = input_ids.tolist()
input_token_type = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if token == image_token_id and not video_check_flg:
input_token_type.append("image")
elif token == image_token_id and video_check_flg:
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group = []
for key, group in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group = list(group)
start_index = group[0][0]
end_index = group[-1][0] + 1
input_type_group.append((key, start_index, end_index))
llm_pos_ids_list = []
video_frame_num = 1
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
if modality_type == "image":
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
image_index += 1
video_frame_num = 1
elif modality_type == "video":
t, h, w = (
video_frame_num,
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
video_group_index += 1
if video_group_index >= video_grid_thw[video_index][0]:
video_index += 1
video_group_index = 0
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
video_frame_num = 1
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
position_ids.device
)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(total_input_ids[i])
)
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = (
position_ids.unsqueeze(0)
.expand(3, -1, -1)
.to(attention_mask.device)
)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
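
A hedged usage sketch for `get_rope_index_glm4v` with a toy config (all token ids and grid sizes are invented; only `spatial_merge_size` and the id fields the method reads are provided):

```
import torch
from types import SimpleNamespace

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

hf_config = SimpleNamespace(
    image_token_id=9,
    video_start_token_id=10,
    video_end_token_id=11,
    vision_config=SimpleNamespace(spatial_merge_size=2),
)
# One image: t=1, h=4, w=4 -> (4 // 2) * (4 // 2) = 4 placeholder tokens (id 9).
input_ids = torch.tensor([[1, 2, 9, 9, 9, 9, 3]])
image_grid_thw = torch.tensor([[1, 4, 4]])
positions, delta = MRotaryEmbedding.get_rope_index_glm4v(
    input_ids, hf_config, image_grid_thw, video_grid_thw=None, attention_mask=None
)
print(positions.shape)  # torch.Size([3, 1, 7]): (t/h/w, batch, seq)
print(delta)            # tensor([[-2]]): max position + 1 - sequence length
```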
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,

View File

@@ -218,6 +218,12 @@ class Glm4Model(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def dtype(self) -> torch.dtype:
return next(self.parameters()).dtype
@torch.no_grad()
def forward(
self,

View File

@@ -0,0 +1,589 @@
import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4 import Glm4Model
from sglang.srt.models.qwen2_5_vl import (
Qwen2_5_VisionBlock,
Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
cached_get_processor = lru_cache(get_processor)
class Glm4vRMSNorm(RMSNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
original_shape = x.shape
x_2d = x.contiguous().reshape(-1, original_shape[-1])
x_2d = super().forward(x_2d)
x = x_2d.reshape(original_shape)
return x
class Glm4vVisionMLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int,
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Glm4vVisionBlock(Qwen2_5_VisionBlock):
def __init__(
self,
config: Glm4vVisionConfig,
norm_layer: Optional[nn.Module] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(
dim=config.hidden_size,
intermediate_dim=config.out_hidden_size,
num_heads=config.num_heads,
hidden_act=config.hidden_act,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=prefix,
)
self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = Glm4vVisionMLP(
config.hidden_size,
config.out_hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
class Glm4vVisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_channels: int = 3,
hidden_size: int = 1536,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.hidden_size = hidden_size
self.in_channels = in_channels
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.view(
-1,
self.in_channels,
self.temporal_patch_size,
self.patch_size,
self.patch_size,
)
x = self.proj(x).view(-1, self.hidden_size)
return x
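
A shape sketch for the patch embed above, assuming the class is importable (toy batch of flattened patches; the sizes match the constructor defaults):

```
import torch

pe = Glm4vVisionPatchEmbed(
    patch_size=14, temporal_patch_size=2, in_channels=3, hidden_size=1536
)
x = torch.randn(8, 3 * 2 * 14 * 14)  # 8 patches, each flattened to C * T * P * P values
print(pe(x).shape)                   # torch.Size([8, 1536])
```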
class Glm4vPatchMerger(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = d_model
self.proj = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("proj", prefix),
gather_output=True,
)
self.post_projection_norm = nn.LayerNorm(self.hidden_size)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[context_dim] * 2,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
context_dim,
self.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
)
self.extra_activation_func = nn.GELU()
def forward(self, x: torch.Tensor):
x, _ = self.proj(x)
x = self.extra_activation_func(self.post_projection_norm(x))
gate_up, _ = self.gate_up_proj(x)
gate, up = gate_up.chunk(2, dim=-1)
x = F.silu(gate) * up
x, _ = self.down_proj(x)
return x
class Glm4vVisionEmbeddings(nn.Module):
def __init__(self, config: Glm4vVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False,
)
def forward(
self, embeddings, lengths, image_shapes, h_coords, w_coords
) -> torch.Tensor:
pos_embed_weight = self.position_embedding.weight
hidden_size = pos_embed_weight.shape[1]
total_seq = h_coords.shape[0]
device = pos_embed_weight.device
# Move coordinates to correct device
h_coords, w_coords = h_coords.to(device), w_coords.to(device)
# Handle empty sequence case
if total_seq == 0:
adapted_pos_embed = torch.empty(
0, hidden_size, device=device, dtype=pos_embed_weight.dtype
)
else:
# Convert inputs to tensors if needed
if isinstance(lengths, list):
lengths = torch.tensor(lengths, device=device, dtype=torch.long)
if not isinstance(image_shapes, torch.Tensor):
image_shapes = torch.tensor(
image_shapes, device=device, dtype=torch.long
)
# Prepare 2D position embedding
orig_size_sq = pos_embed_weight.shape[0]
orig_size = int(orig_size_sq**0.5)
pos_embed_2d = (
pos_embed_weight.view(orig_size, orig_size, hidden_size)
.permute(2, 0, 1)
.unsqueeze(0)
.to(device=device, dtype=torch.float32)
)
# Calculate target dimensions for each patch
target_h = torch.cat(
[image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]
).to(device=device, dtype=torch.float32)
target_w = torch.cat(
[image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]
).to(device=device, dtype=torch.float32)
# Normalize coordinates to [-1, 1] range for grid_sample
h_coords = h_coords.to(device=device, dtype=torch.float32)
w_coords = w_coords.to(device=device, dtype=torch.float32)
norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
# Create sampling grid
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
# Perform bicubic interpolation
interpolated_embed_fp32 = F.grid_sample(
pos_embed_2d,
grid,
mode="bicubic",
align_corners=False,
padding_mode="border",
)
# Reshape and convert back to original dtype
adapted_pos_embed_fp32 = (
interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
)
adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(
embeddings.device
)
# Add adapted position encoding to embeddings
embeddings = embeddings + adapted_pos_embed
return embeddings
class Glm4vVisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._freqs_cached = None
def update_freqs_cache(self, seqlen: int) -> None:
if seqlen > self._seq_len_cached:
seqlen *= 2
self._seq_len_cached = seqlen
self.inv_freq = 1.0 / (
self.theta
** (
torch.arange(
0,
self.dim,
2,
dtype=torch.float,
device=self.inv_freq.device,
)
/ self.dim
)
)
seq = torch.arange(
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
freqs = torch.outer(seq, self.inv_freq)
self._freqs_cached = freqs
def forward(self, seqlen: int) -> torch.Tensor:
self.update_freqs_cache(seqlen)
return self._freqs_cached[:seqlen]
class Glm4vVisionModel(nn.Module):
def __init__(
self,
vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
in_channels = vision_config.in_channels
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
self.out_hidden_size = vision_config.out_hidden_size
self.patch_embed = Glm4vVisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=in_channels,
hidden_size=self.hidden_size,
)
norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[
Glm4vVisionBlock(
config=vision_config,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=add_prefix(f"blocks.{layer_idx}", prefix),
)
for layer_idx in range(depth)
]
)
self.merger = Glm4vPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=vision_config.intermediate_size,
quant_config=quant_config,
bias=False,
prefix=add_prefix("merger", prefix),
)
self.embeddings = Glm4vVisionEmbeddings(vision_config)
self.post_conv_layernorm = Glm4vRMSNorm(
vision_config.hidden_size, eps=vision_config.rms_norm_eps
)
self.downsample = nn.Conv2d(
in_channels=vision_config.hidden_size,
out_channels=vision_config.out_hidden_size,
kernel_size=vision_config.spatial_merge_size,
stride=vision_config.spatial_merge_size,
)
self.post_layernorm = Glm4vRMSNorm(
vision_config.hidden_size, eps=vision_config.rms_norm_eps
)
@property
def dtype(self) -> torch.dtype:
return self.patch_embed.proj.weight.dtype
@property
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = (
hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
.permute(0, 2, 1, 3)
.flatten()
)
wpos_ids = (
wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
.permute(0, 2, 1, 3)
.flatten()
)
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids
def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
x = self.post_conv_layernorm(x)
# compute position embedding
rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
x = self.embeddings(
x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
rotary_pos_emb_tuple = (emb.cos(), emb.sin())
# x.shape: (s, b, d) where b=1 for vision processing
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=rotary_pos_emb_tuple)
# adapter
x = self.post_layernorm(x)
x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
x = x.permute(0, 3, 1, 2)
x = self.downsample(x).view(-1, self.out_hidden_size)
x = self.merger(x)
return x
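
For intuition, the `cu_seqlens` computation in the forward pass above turns per-item grids into cumulative attention boundaries; a standalone sketch with invented grids:

```
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 4, 4], [2, 2, 2]])  # (t, h, w) per multimodal item
cu_seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
print(F.pad(cu_seqlens, (1, 0), "constant", 0))
# tensor([ 0, 16, 20, 24], dtype=torch.int32)
```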
class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
def __init__(
self,
config: Glm4vConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.config = config
self.model = Glm4Model(
config,
quant_config,
prefix=add_prefix("model", prefix),
)
self.visual = Glm4vVisionModel(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = torch.cat(
[item.feature.squeeze(0) for item in items], dim=0
).type(self.visual.dtype)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
# For multi-image, pixel_values is [num_of_images, L, C] shape
# assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
split_sizes = (
image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
).tolist()
image_embeds = torch.split(image_embeds, split_sizes)
return torch.cat(image_embeds)
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values_videos = torch.cat(
[item.feature.squeeze(0) for item in items], dim=0
).type(self.visual.dtype)
video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
# For multi-video, pixel_values_videos is [num_of_videos, L, C] shape
# assert pixel_values_videos.dim() == 2, pixel_values_videos.dim()
assert video_grid_thw.dim() == 2, video_grid_thw.dim()
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
temp_frames_hw = []
for t, h, w in video_grid_thw:
repeated_row = (
torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
)
temp_frames_hw.append(repeated_row)
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
video_embeds = self.visual(
pixel_values_videos, grid_thw=flattened_video_grid_thw
)
split_sizes = (
video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
).tolist()
video_embeds = torch.split(video_embeds, split_sizes)
return torch.cat(video_embeds)
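
The grid flattening above expands each video's `(t, h, w)` row into `t` single-frame rows; a minimal sketch with an invented grid:

```
import torch

video_grid_thw = torch.tensor([[2, 4, 4]])  # one video: 2 frames of a 4x4 grid
rows = [
    torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(int(t), 1)
    for t, h, w in video_grid_thw
]
print(torch.cat(rows, dim=0))  # tensor([[1, 4, 4], [1, 4, 4]])
```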
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".up_proj", 1),
(".gate_up_proj", ".gate_proj", 0),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "language_model." in name:
name = name.replace("language_model.", "")
if "model.visual." in name:
name = name.replace("model.visual.", "visual.")
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "visual" in name:
# adapt to VisionAttention
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
except KeyError:
print(params_dict.keys())
raise
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = [Glm4vForConditionalGeneration]

View File

@@ -0,0 +1,400 @@
import logging
from functools import lru_cache
from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4_moe import Glm4MoeModel
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
_is_cuda = is_cuda()
logger = logging.getLogger(__name__)
cached_get_processor = lru_cache(get_processor)
class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
def __init__(
self,
config: Glm4vMoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
config.moe_layer_freq = 1
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.dp_size = get_local_attention_dp_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)
self.model = Glm4MoeModel(
config,
quant_config,
prefix=add_prefix("language_model", prefix),
)
self.visual = Glm4vVisionModel(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
def determine_num_fused_shared_experts(
self, architecture: str = "Glm4MoeForCausalLM"
):
self.num_fused_shared_experts = 0
if global_server_args_dict["disable_shared_experts_fusion"]:
return
# Only GLM-4.5 can use the shared experts fusion optimization now.
disable_reason = None
if (
not _is_cuda
or torch.cuda.get_device_capability("cuda") < (8, 0)
or self.config.architectures[0] != architecture
or self.config.n_shared_experts != 1
):
disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
)
return
self.num_fused_shared_experts = self.config.n_shared_experts
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
# compatible with old design
nextn_layer_id = (
0
if self.config.num_hidden_layers == 1
else self.config.num_hidden_layers
)
else:
raise ValueError("num_nextn_predict_layers is not in the config")
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
if self.quant_config.get_name() == "w8a8_int8":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif (
self.quant_config.get_name() == "fp8"
or self.quant_config.get_name() == "blockwise_int8"
or self.quant_config.get_name() == "compressed_tensors"
):
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif self.quant_config.get_name() == "awq":
suffix_list = [
"down_proj.qweight",
"down_proj.qzeros",
"down_proj.scales",
"gate_proj.qweight",
"gate_proj.qzeros",
"gate_proj.scales",
"up_proj.qweight",
"up_proj.qzeros",
"up_proj.scales",
]
elif self.quant_config.get_name() == "modelopt_fp4":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"down_proj.weight_scale_2",
"down_proj.input_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"gate_proj.weight_scale_2",
"gate_proj.input_scale",
"up_proj.weight",
"up_proj.weight_scale",
"up_proj.weight_scale_2",
"up_proj.input_scale",
]
else:
raise ValueError(
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
)
else:
suffix_list = [
"down_proj.weight",
"gate_proj.weight",
"up_proj.weight",
]
names_to_remove = []
moe_layers = (
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
)
if not is_nextn
else [nextn_layer_id]
)
for moe_layer in moe_layers:
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
# online fp8 quantization does not load weight_scale
if shared_expert_weight_name not in weights_dict:
continue
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + 0}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
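The renaming scheme in isolation: each shared-expert weight is re-registered as one extra routed expert with index `n_routed_experts` (the layer index and expert count below are made up):

```
moe_layer, n_routed_experts, suffix = 3, 128, "down_proj.weight"
src = f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
dst = f"model.layers.{moe_layer}.mlp.experts.{n_routed_experts}.{suffix}"
print(f"{src} -> {dst}")
```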
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
self.config.q_lora_rank is not None
)
cached_a_proj = {} if fuse_qkv_a_proj else None
if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
nextn_spec_weight_names = [
"shared_head.norm",
"eh_proj",
"enorm",
"hnorm",
]
params_dict = dict(self.named_parameters())
weight_names = []
for name, loaded_weight in weights:
weight_names.append(name)
if not is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
if num_nextn_layers > 0 and name.startswith("model.layers"):
name_list = name.split(".")
if (
len(name_list) >= 3
and int(name_list[2]) >= self.config.num_hidden_layers
):
continue
else:
if not name.startswith(nextn_layer_prefix):
continue
# Use shared head and embed weights from target model
if "shared_head.head" in name or "embed_tokens" in name:
continue
is_decoder = True
# For nextn specific weights
for weight_name in nextn_spec_weight_names:
if weight_name in name:
name = name.replace(nextn_layer_prefix, "model")
is_decoder = False
break
# For decoder layer weights
if is_decoder:
name = name.replace(nextn_layer_prefix, "model.decoder")
if "language_model." in name:
name = name.replace("language_model.", "")
if "model.visual." in name:
name = name.replace("model.visual.", "visual.")
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
if "visual" in name:
# adapt to VisionAttention
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
)
# When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
)
)
param = params_dict[param_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, fused_weight)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
# modelopt attn kv scale is named differently
if any(scale in name for scale in ["k_scale", "v_scale"]):
name = name.replace("_proj", "attn_mqa")
else:
logger.warning(
f"Unknown scale found in checkpoint: {name}"
)
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
EntryClass = [Glm4vMoeForConditionalGeneration]

View File

@@ -22,13 +22,19 @@ class BaseMultiModalProcessorOutput:
input_text: str
# frames loaded from image, in given order
images: Optional[list[Union[Image.Image, dict]]] = None
images: Optional[list[Union[Image.Image, dict]]] = dataclasses.field(
default_factory=list
)
# videos
videos: Optional[list[Union[torch.Tensor, dict]]] = None
videos: Optional[list[Union[torch.Tensor, dict]]] = dataclasses.field(
default_factory=list
)
# audios
audios: Optional[list[Union[np.ndarray, dict]]] = None
audios: Optional[list[Union[np.ndarray, dict]]] = dataclasses.field(
default_factory=list
)
def organize_results(self) -> List[Tuple[Modality, Any]]:
"""

View File

@@ -0,0 +1,132 @@
import re
from typing import List, Union
from decord import VideoReader
from transformers.video_utils import VideoMetadata
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration
from sglang.srt.models.glm4v_moe import Glm4vMoeForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.multimodal.processors.base_processor import (
BaseMultiModalProcessorOutput,
MultimodalSpecialTokens,
)
class Glm4vImageProcessor(SGLangBaseProcessor):
models = [Glm4vForConditionalGeneration, Glm4vMoeForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
# GLM-4.1V and GLM-4.5V specific tokens
self.IMAGE_TOKEN = "<|image|>"
self.VIDEO_TOKEN = "<|video|>"
self.IMAGE_START_TOKEN = "<|begin_of_image|>"
self.IMAGE_END_TOKEN = "<|end_of_image|>"
self.VIDEO_START_TOKEN = "<|begin_of_video|>"
self.VIDEO_END_TOKEN = "<|end_of_video|>"
# Token IDs
self.IM_TOKEN_ID = hf_config.image_token_id
self.VIDEO_TOKEN_ID = hf_config.video_token_id
self.IMAGE_START_TOKEN_ID = hf_config.image_start_token_id
self.IMAGE_END_TOKEN_ID = hf_config.image_end_token_id
self.VIDEO_START_TOKEN_ID = hf_config.video_start_token_id
self.VIDEO_END_TOKEN_ID = hf_config.video_end_token_id
# Vision config
self.IMAGE_FACTOR = 28
self.MIN_PIXELS = 112 * 112
self.MAX_PIXELS = 30000 * 28 * 28 * 2
self.mm_tokens = MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_id=self.IM_TOKEN_ID,
video_token=self.VIDEO_TOKEN,
# Note: GLM-4V uses the video token before tokenization but the image token after tokenization
video_token_id=self.IM_TOKEN_ID,
).build(_processor)
# adapted from https://github.com/huggingface/transformers/blob/369c99d0cea403b77bd0aef818527106453fd9fc/src/transformers/video_utils.py#L312
async def preprocess_video(self, vr: VideoReader):
"""
Preprocess video using VideoReader from Decord backend.
Args:
vr (VideoReader): VideoReader object from decord
Returns:
tuple: A tuple containing processed frames and metadata
"""
video_fps = vr.get_avg_fps()
total_num_frames = len(vr)
duration = total_num_frames / video_fps if video_fps else 0
metadata = VideoMetadata(
total_num_frames=int(total_num_frames),
fps=float(video_fps),
duration=float(duration),
video_backend="decord",
)
# Extract all frames
indices = list(range(total_num_frames))
frames = vr.get_batch(indices).asnumpy()
metadata.frames_indices = indices
return frames, metadata
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
video_data=request_obj.video_data,
multimodal_tokens=self.mm_tokens,
)
video_metadata = None
if base_output.videos:
videos_processed = [
await self.preprocess_video(video) for video in base_output.videos
]
base_output.videos, video_metadata = map(list, zip(*videos_processed))
# transformers requires the video inputs to be in this format
base_output.videos = [base_output.videos]
video_metadata = [video_metadata]
mm_items, input_ids, ret = self.process_and_combine_mm_data(
base_output, self.mm_tokens, video_metadata=video_metadata
)
input_ids = input_ids.flatten()
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index_glm4v(
input_ids=input_ids.unsqueeze(0),
hf_config=self.hf_config,
image_grid_thw=getattr(ret, "image_grid_thw", None),
video_grid_thw=getattr(ret, "video_grid_thw", None),
attention_mask=getattr(ret, "attention_mask", None),
)
mrope_positions = mrope_positions.squeeze(1)
mm_inputs = {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,
"im_token_id": self.mm_tokens.image_token_id,
"video_token_id": self.mm_tokens.video_token_id,
"mrope_positions": mrope_positions,
"mrope_position_delta": mrope_position_delta,
}
return mm_inputs

View File

@@ -815,7 +815,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
vr = VideoReader(tmp_file.name, ctx=ctx)
elif video_file.startswith("data:"):
_, encoded = video_file.split(",", 1)
video_bytes = base64.b64decode(encoded)
video_bytes = pybase64.b64decode(encoded)
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_bytes)
tmp_file.close()
@@ -823,7 +823,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
elif os.path.isfile(video_file):
vr = VideoReader(video_file, ctx=ctx)
else:
video_bytes = base64.b64decode(video_file)
video_bytes = pybase64.b64decode(video_file)
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_bytes)
tmp_file.close()
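
`pybase64` is a faster drop-in replacement for the stdlib decoder, so these call sites change mechanically; a quick equivalence check:

```
import base64
import pybase64

data = base64.b64encode(b"video-bytes")
assert pybase64.b64decode(data) == base64.b64decode(data)
```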

View File

@@ -948,5 +948,6 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
# def test_function_calling_multiturn(self):
# self._test_function_calling_multiturn()
if __name__ == "__main__":
unittest.main()

View File

@@ -497,6 +497,17 @@ class TestEBNFGeneration(unittest.TestCase):
},
),
),
Tool(
type="function",
function=Function(
name="empty_param_func",
description="Function with empty parameters",
parameters={
"properties": {},
"required": [],
},
),
),
]
self.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
@@ -630,16 +641,21 @@ class TestEBNFGeneration(unittest.TestCase):
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns for XML format
self.assertIn('"<tool_call>" function_call "</tool_call>"', ebnf)
self.assertIn('"get_weather" "\\n" arguments_get_weather', ebnf)
self.assertIn('"get_weather" "\\n" ( arguments_get_weather "\\n" )?', ebnf)
self.assertIn(
'"<arg_key>location</arg_key>" "\\n" "<arg_value>" xml_text "</arg_value>" ( "\\n" ( "<arg_key>unit</arg_key>" "\\n" "<arg_value>" ("celsius" | "fahrenheit") "</arg_value>" ) )?',
ebnf,
)
self.assertIn('"search" "\\n" arguments_search', ebnf)
self.assertIn('"search" "\\n" ( arguments_search "\\n" )?', ebnf)
self.assertIn(
'"<arg_key>query</arg_key>" "\\n" "<arg_value>" xml_text "</arg_value>"',
ebnf,
)
self.assertIn(
'"empty_param_func" "\\n" ( arguments_empty_param_func "\\n" )?', ebnf
)
self.assertIn('arguments_empty_param_func ::= ""', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)

View File

@@ -60,6 +60,86 @@ class TestTemplateContentFormatDetection(CustomTestCase):
result = detect_jinja_template_content_format("")
self.assertEqual(result, "string")
def test_detect_msg_content_pattern(self):
"""Test detection of template with msg.content pattern (should be 'openai' format)."""
msg_content_pattern = """
[gMASK]<sop>
{%- for msg in messages %}
{%- if msg.role == 'system' %}
<|system|>
{{ msg.content }}
{%- elif msg.role == 'user' %}
<|user|>{{ '\n' }}
{%- if msg.content is string %}
{{ msg.content }}
{%- else %}
{%- for item in msg.content %}
{%- if item.type == 'video' or 'video' in item %}
<|begin_of_video|><|video|><|end_of_video|>
{%- elif item.type == 'image' or 'image' in item %}
<|begin_of_image|><|image|><|end_of_image|>
{%- elif item.type == 'text' %}
{{ item.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- elif msg.role == 'assistant' %}
{%- if msg.metadata %}
<|assistant|>{{ msg.metadata }}
{{ msg.content }}
{%- else %}
<|assistant|>
{{ msg.content }}
{%- endif %}
{%- endif %}
{%- endfor %}
{% if add_generation_prompt %}<|assistant|>
{% endif %}
"""
result = detect_jinja_template_content_format(msg_content_pattern)
self.assertEqual(result, "openai")
def test_detect_m_content_pattern(self):
"""Test detection of template with m.content pattern (should be 'openai' format)."""
msg_content_pattern = """
[gMASK]<sop>
{%- for m in messages %}
{%- if m.role == 'system' %}
<|system|>
{{ m.content }}
{%- elif m.role == 'user' %}
<|user|>{{ '\n' }}
{%- if m.content is string %}
{{ m.content }}
{%- else %}
{%- for item in m.content %}
{%- if item.type == 'video' or 'video' in item %}
<|begin_of_video|><|video|><|end_of_video|>
{%- elif item.type == 'image' or 'image' in item %}
<|begin_of_image|><|image|><|end_of_image|>
{%- elif item.type == 'text' %}
{{ item.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- elif m.role == 'assistant' %}
{%- if m.metadata %}
<|assistant|>{{ m.metadata }}
{{ m.content }}
{%- else %}
<|assistant|>
{{ m.content }}
{%- endif %}
{%- endif %}
{%- endfor %}
{% if add_generation_prompt %}<|assistant|>
{% endif %}
"""
result = detect_jinja_template_content_format(msg_content_pattern)
self.assertEqual(result, "openai")
def test_process_content_openai_format(self):
"""Test content processing for openai format."""
msg_dict = {

View File

@@ -348,6 +348,33 @@ class TestVILAServer(TestOpenAIVisionServer):
cls.base_url += "/v1"
# Skip for CI test
# class TestGLM41VServer(TestOpenAIVisionServer):
# @classmethod
# def setUpClass(cls):
# cls.model = "zai-org/GLM-4.1V-9B-Thinking"
# cls.base_url = DEFAULT_URL_FOR_TEST
# cls.api_key = "sk-123456"
# cls.process = popen_launch_server(
# cls.model,
# cls.base_url,
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
# other_args=[
# "--trust-remote-code",
# "--mem-fraction-static",
# "0.68",
# "--cuda-graph-max-bs",
# "4",
# "--reasoning-parser",
# "glm45",
# ],
# )
# cls.base_url += "/v1"
# def test_video_chat_completion(self):
# self._test_video_chat_completion()
if __name__ == "__main__":
del TestOpenAIVisionServer
unittest.main()

View File

@@ -96,8 +96,13 @@ class TestOpenAIVisionServer(CustomTestCase):
), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
# MiniCPMO fails to recognize `iron`, but recognizes `hanging`
assert (
"iron" in text or "hang" in text or "cloth" in text or "holding" in text
), f"text: {text}, should contain iron, hang, cloth or holding"
"iron" in text
or "hang" in text
or "cloth" in text
or "coat" in text
or "holding" in text
or "outfit" in text
), f"text: {text}, should contain iron, hang, cloth, coat or holding or outfit"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
@@ -193,11 +198,15 @@ class TestOpenAIVisionServer(CustomTestCase):
print(f"Multi images response:\n{text}")
print("-" * 30)
assert (
"man" in text or "cab" in text or "SUV" in text or "taxi" in text
), f"text: {text}, should contain man, cab, SUV or taxi"
"man" in text
or "cab" in text
or "SUV" in text
or "taxi" in text
or "car" in text
), f"text: {text}, should contain man, cab, SUV, taxi or car"
assert (
"logo" in text or '"S"' in text or "SG" in text
), f"text: {text}, should contain logo, S or SG"
"logo" in text or '"S"' in text or "SG" in text or "graphic" in text
), f"text: {text}, should contain logo, S or SG or graphic"
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
@@ -320,11 +329,12 @@ class TestOpenAIVisionServer(CustomTestCase):
or "individual" in video_response
or "speaker" in video_response
or "Steve" in video_response
or "hand" in video_response
), f"""
====================== video_response =====================
{video_response}
===========================================================
should contain 'man' or 'person' or 'individual' or 'speaker'
should contain 'man' or 'person' or 'individual' or 'speaker' or 'hand'
"""
assert (
"present" in video_response
@@ -375,7 +385,8 @@ class TestOpenAIVisionServer(CustomTestCase):
or "person" in video_response
or "individual" in video_response
or "speaker" in video_response
), f"video_response: {video_response}, should either have 'man' in video_response, or 'person' in video_response, or 'individual' in video_response or 'speaker' in video_response"
or "hand" in video_response
), f"video_response: {video_response}, should either have 'man' in video_response, or 'person' in video_response, or 'individual' in video_response, or 'speaker' in video_response or 'hand' in video_response"
assert (
"present" in video_response
or "examine" in video_response