Support glm4.1v and glm4.5v (#8798)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com> Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com> Co-authored-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com> Co-authored-by: zRzRzRzRzRzRzR <2448370773@qq.com> Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com> Co-authored-by: Chang Su <csu272@usc.edu>
This commit is contained in:
@@ -27,6 +27,18 @@ python -m sglang.launch_server --model-path microsoft/Phi-4-multimodal-instruct
|
|||||||
python -m benchmark/mmmu/bench_sglang.py --concurrency 8 --lora-path vision
|
python -m benchmark/mmmu/bench_sglang.py --concurrency 8 --lora-path vision
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can use `--response-answer-regex` to specify how to extract the answer from the response string. E.g.,
|
||||||
|
```
|
||||||
|
python3 -m sglang.launch_server --model-path zai-org/GLM-4.1V-9B-Thinking --reasoning-parser glm45
|
||||||
|
|
||||||
|
python3 bench_sglang.py --response-answer-regex "<\|begin_of_box\|>(.*)<\|end_of_box\|>" --concurrency 64
|
||||||
|
```
|
||||||
|
|
||||||
|
You can use `--extra-request-body` to specify additional OpenAI request parameters. E.g.,
|
||||||
|
```
|
||||||
|
python3 bench_sglang.py --extra-request-body '{"max_new_tokens": 128, "temperature": 0.01}'
|
||||||
|
```
|
||||||
|
|
||||||
### Evaluate hf
|
### Evaluate hf
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ The eval output will be logged
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
@@ -145,7 +146,17 @@ async def eval_mmmu(args) -> None:
|
|||||||
_, response = await process_sample(
|
_, response = await process_sample(
|
||||||
client, sample, sampling_params, lora_path
|
client, sample, sampling_params, lora_path
|
||||||
)
|
)
|
||||||
process_result(response, sample, answer_dict, out_samples)
|
answer = (
|
||||||
|
re.search(args.response_answer_regex, response)
|
||||||
|
if response is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
process_result(
|
||||||
|
answer.group(1) if answer else response,
|
||||||
|
sample,
|
||||||
|
answer_dict,
|
||||||
|
out_samples,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
semaphore = asyncio.Semaphore(args.concurrency)
|
semaphore = asyncio.Semaphore(args.concurrency)
|
||||||
tasks = [
|
tasks = [
|
||||||
@@ -157,7 +168,17 @@ async def eval_mmmu(args) -> None:
|
|||||||
|
|
||||||
for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
|
for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
|
||||||
sample, response = await coro
|
sample, response = await coro
|
||||||
process_result(response, sample, answer_dict, out_samples)
|
answer = (
|
||||||
|
re.search(args.response_answer_regex, response)
|
||||||
|
if response is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
process_result(
|
||||||
|
answer.group(1) if answer else response,
|
||||||
|
sample,
|
||||||
|
answer_dict,
|
||||||
|
out_samples,
|
||||||
|
)
|
||||||
|
|
||||||
if args.profile:
|
if args.profile:
|
||||||
print("Stopping profiler...")
|
print("Stopping profiler...")
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class EvalArgs:
|
|||||||
profile: bool = False
|
profile: bool = False
|
||||||
profile_number: int = 5
|
profile_number: int = 5
|
||||||
concurrency: int = 1
|
concurrency: int = 1
|
||||||
|
response_answer_regex: str = "(.*)"
|
||||||
lora_path: Optional[str] = None
|
lora_path: Optional[str] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -92,6 +93,12 @@ class EvalArgs:
|
|||||||
default=EvalArgs.concurrency,
|
default=EvalArgs.concurrency,
|
||||||
help="Number of concurrent requests to make during evaluation. Default is 1, which means no concurrency.",
|
help="Number of concurrent requests to make during evaluation. Default is 1, which means no concurrency.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--response-answer-regex",
|
||||||
|
type=str,
|
||||||
|
default=EvalArgs.response_answer_regex,
|
||||||
|
help="Specific regex to capture the answer from the response, string",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-path",
|
"--lora-path",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -39,3 +39,4 @@ in the GitHub search bar.
|
|||||||
| **Mistral-Small-3.1-24B** | `mistralai/Mistral-Small-3.1-24B-Instruct-2503` | `mistral` | Mistral 3.1 is a multimodal model that can generate text from text or images input. It also supports tool calling and structured output. |
|
| **Mistral-Small-3.1-24B** | `mistralai/Mistral-Small-3.1-24B-Instruct-2503` | `mistral` | Mistral 3.1 is a multimodal model that can generate text from text or images input. It also supports tool calling and structured output. |
|
||||||
| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | `phi-4-mm` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. It supports text, vision and audio modalities in SGLang. |
|
| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | `phi-4-mm` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. It supports text, vision and audio modalities in SGLang. |
|
||||||
| **MiMo-VL** (7B) | `XiaomiMiMo/MiMo-VL-7B-RL` | `mimo-vl` | Xiaomi's compact yet powerful vision-language model featuring a native resolution ViT encoder for fine-grained visual details, an MLP projector for cross-modal alignment, and the MiMo-7B language model optimized for complex reasoning tasks. |
|
| **MiMo-VL** (7B) | `XiaomiMiMo/MiMo-VL-7B-RL` | `mimo-vl` | Xiaomi's compact yet powerful vision-language model featuring a native resolution ViT encoder for fine-grained visual details, an MLP projector for cross-modal alignment, and the MiMo-7B language model optimized for complex reasoning tasks. |
|
||||||
|
| **GLM-4.5V** (106B) / **GLM-4.1V**(9B) | `zai-org/GLM-4.5V` | `glm-4v` | GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning |
|
||||||
|
|||||||
@@ -505,6 +505,22 @@ register_chat_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example
|
||||||
|
register_chat_template(
|
||||||
|
ChatTemplate(
|
||||||
|
name="glm-4v",
|
||||||
|
default_system_prompt=None,
|
||||||
|
role_prefix_and_suffix={
|
||||||
|
"system": ("<|system|>\n", "\n"),
|
||||||
|
"user": ("<|user|>\n", "\n"),
|
||||||
|
"assistant": ("<|assistant|>\n", "\n"),
|
||||||
|
},
|
||||||
|
style=ChatTemplateStyle.PLAIN,
|
||||||
|
stop_str=["<|user|>", "<|endoftext|>", "<|observation|>"],
|
||||||
|
image_token="<|image|>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_chat_template_matching_function
|
@register_chat_template_matching_function
|
||||||
def match_deepseek(model_path: str):
|
def match_deepseek(model_path: str):
|
||||||
@@ -562,6 +578,8 @@ def match_chat_ml(model_path: str):
|
|||||||
return "chatml"
|
return "chatml"
|
||||||
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
|
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
|
||||||
return "qwen2-vl"
|
return "qwen2-vl"
|
||||||
|
if re.search(r"glm[-_]?4(\.\d+)?v", model_path, re.IGNORECASE):
|
||||||
|
return "glm-4v"
|
||||||
if re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE) and not re.search(
|
if re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE) and not re.search(
|
||||||
r"llava", model_path, re.IGNORECASE
|
r"llava", model_path, re.IGNORECASE
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -659,6 +659,8 @@ multimodal_model_archs = [
|
|||||||
"DeepseekVL2ForCausalLM",
|
"DeepseekVL2ForCausalLM",
|
||||||
"Gemma3ForConditionalGeneration",
|
"Gemma3ForConditionalGeneration",
|
||||||
"Gemma3nForConditionalGeneration",
|
"Gemma3nForConditionalGeneration",
|
||||||
|
"Glm4vForConditionalGeneration",
|
||||||
|
"Glm4vMoeForConditionalGeneration",
|
||||||
"Grok1VForCausalLM",
|
"Grok1VForCausalLM",
|
||||||
"Grok1AForCausalLM",
|
"Grok1AForCausalLM",
|
||||||
"LlavaLlamaForCausalLM",
|
"LlavaLlamaForCausalLM",
|
||||||
|
|||||||
@@ -316,6 +316,7 @@ class EBNFComposer:
|
|||||||
|
|
||||||
combined_args = "".join(rule_parts)
|
combined_args = "".join(rule_parts)
|
||||||
arguments_rule = args_template.format(arg_rules=combined_args)
|
arguments_rule = args_template.format(arg_rules=combined_args)
|
||||||
|
arguments_rule = arguments_rule or '""'
|
||||||
|
|
||||||
# Add the function call rule and its arguments rule
|
# Add the function call rule and its arguments rule
|
||||||
ebnf_lines.append(
|
ebnf_lines.append(
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ class Glm4MoeDetector(BaseFormatDetector):
|
|||||||
individual_call_end_token=self.eot_token,
|
individual_call_end_token=self.eot_token,
|
||||||
tool_call_separator="\\n",
|
tool_call_separator="\\n",
|
||||||
function_format="xml",
|
function_format="xml",
|
||||||
call_rule_fmt='"{name}" "\\n" {arguments_rule} "\\n"',
|
call_rule_fmt='"{name}" "\\n" ( {arguments_rule} "\\n" )?',
|
||||||
key_value_rule_fmt='"<arg_key>{key}</arg_key>" "\\n" "<arg_value>" {valrule} "</arg_value>"',
|
key_value_rule_fmt='"<arg_key>{key}</arg_key>" "\\n" "<arg_value>" {valrule} "</arg_value>"',
|
||||||
key_value_separator="\\n",
|
key_value_separator="\\n",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -102,6 +102,12 @@ def detect_jinja_template_content_format(chat_template: str) -> str:
|
|||||||
if _is_var_or_elems_access(loop_iter, "message", "content"):
|
if _is_var_or_elems_access(loop_iter, "message", "content"):
|
||||||
return "openai" # Found content iteration → openai format
|
return "openai" # Found content iteration → openai format
|
||||||
|
|
||||||
|
# Also check for patterns like: {%- for item in msg.content -%} or {%- for item in m.content -%}
|
||||||
|
if _is_var_or_elems_access(
|
||||||
|
loop_iter, "msg", "content"
|
||||||
|
) or _is_var_or_elems_access(loop_iter, "m", "content"):
|
||||||
|
return "openai" # Found content iteration → openai format (glm4v)
|
||||||
|
|
||||||
return "string" # No content loops found → string format
|
return "string" # No content loops found → string format
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error when parsing AST of Jinja template: {e}")
|
logger.debug(f"Error when parsing AST of Jinja template: {e}")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py
|
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py
|
||||||
|
|
||||||
"""Rotary Positional Embeddings."""
|
"""Rotary Positional Embeddings."""
|
||||||
|
import itertools
|
||||||
import math
|
import math
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -946,7 +947,37 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
|
|
||||||
self.mrope_section = mrope_section
|
self.mrope_section = mrope_section
|
||||||
if self.mrope_section:
|
if self.mrope_section:
|
||||||
assert sum(self.mrope_section) == rotary_dim // 2
|
expected_sum = rotary_dim // 2
|
||||||
|
actual_sum = sum(self.mrope_section)
|
||||||
|
if actual_sum != expected_sum:
|
||||||
|
print(
|
||||||
|
f"MRoPE section sum mismatch: expected {expected_sum}, got {actual_sum}. "
|
||||||
|
f"Adjusting mrope_section to match rotary_dim // 2 = {expected_sum}"
|
||||||
|
)
|
||||||
|
# Auto-correct by scaling the mrope_section proportionally
|
||||||
|
if actual_sum > 0:
|
||||||
|
scale_factor = expected_sum / actual_sum
|
||||||
|
self.mrope_section = [
|
||||||
|
max(1, int(section * scale_factor))
|
||||||
|
for section in self.mrope_section
|
||||||
|
]
|
||||||
|
# Ensure the sum exactly matches by adjusting the last element
|
||||||
|
current_sum = sum(self.mrope_section)
|
||||||
|
if current_sum != expected_sum:
|
||||||
|
self.mrope_section[-1] += expected_sum - current_sum
|
||||||
|
else:
|
||||||
|
# If all sections are 0, create a default distribution
|
||||||
|
self.mrope_section = [
|
||||||
|
expected_sum // len(self.mrope_section)
|
||||||
|
] * len(self.mrope_section)
|
||||||
|
# Handle remainder
|
||||||
|
remainder = expected_sum % len(self.mrope_section)
|
||||||
|
for i in range(remainder):
|
||||||
|
self.mrope_section[i] += 1
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1153,6 +1184,204 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
mrope_position_deltas = max_position_ids + 1 - s
|
mrope_position_deltas = max_position_ids + 1 - s
|
||||||
return position_ids, mrope_position_deltas
|
return position_ids, mrope_position_deltas
|
||||||
|
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
|
||||||
|
@staticmethod
|
||||||
|
def get_rope_index_glm4v(
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
hf_config: Any,
|
||||||
|
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||||
|
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Get mrope input positions and delta value for GLM4V."""
|
||||||
|
image_token_id = hf_config.image_token_id
|
||||||
|
video_start_token_id = hf_config.video_start_token_id
|
||||||
|
video_end_token_id = hf_config.video_end_token_id
|
||||||
|
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||||
|
|
||||||
|
mrope_position_deltas = []
|
||||||
|
if input_ids is not None and (
|
||||||
|
image_grid_thw is not None or video_grid_thw is not None
|
||||||
|
):
|
||||||
|
total_input_ids = input_ids
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones_like(total_input_ids)
|
||||||
|
position_ids = torch.ones(
|
||||||
|
3,
|
||||||
|
input_ids.shape[0],
|
||||||
|
input_ids.shape[1],
|
||||||
|
dtype=input_ids.dtype,
|
||||||
|
device=input_ids.device,
|
||||||
|
)
|
||||||
|
image_index, video_index = 0, 0
|
||||||
|
video_group_index = 0
|
||||||
|
attention_mask = attention_mask.to(total_input_ids.device)
|
||||||
|
for i, input_ids in enumerate(total_input_ids):
|
||||||
|
input_ids = input_ids[attention_mask[i] == 1]
|
||||||
|
input_tokens = input_ids.tolist()
|
||||||
|
|
||||||
|
input_token_type = []
|
||||||
|
video_check_flg = False
|
||||||
|
for token in input_tokens:
|
||||||
|
if token == video_start_token_id:
|
||||||
|
video_check_flg = True
|
||||||
|
elif token == video_end_token_id:
|
||||||
|
video_check_flg = False
|
||||||
|
|
||||||
|
if token == image_token_id and not video_check_flg:
|
||||||
|
input_token_type.append("image")
|
||||||
|
elif token == image_token_id and video_check_flg:
|
||||||
|
input_token_type.append("video")
|
||||||
|
else:
|
||||||
|
input_token_type.append("text")
|
||||||
|
|
||||||
|
input_type_group = []
|
||||||
|
for key, group in itertools.groupby(
|
||||||
|
enumerate(input_token_type), lambda x: x[1]
|
||||||
|
):
|
||||||
|
group = list(group)
|
||||||
|
start_index = group[0][0]
|
||||||
|
end_index = group[-1][0] + 1
|
||||||
|
input_type_group.append((key, start_index, end_index))
|
||||||
|
|
||||||
|
llm_pos_ids_list = []
|
||||||
|
video_frame_num = 1
|
||||||
|
for modality_type, start_idx, end_idx in input_type_group:
|
||||||
|
st_idx = (
|
||||||
|
llm_pos_ids_list[-1].max() + 1
|
||||||
|
if len(llm_pos_ids_list) > 0
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
if modality_type == "image":
|
||||||
|
t, h, w = (
|
||||||
|
image_grid_thw[image_index][0],
|
||||||
|
image_grid_thw[image_index][1],
|
||||||
|
image_grid_thw[image_index][2],
|
||||||
|
)
|
||||||
|
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||||
|
t.item(),
|
||||||
|
h.item() // spatial_merge_size,
|
||||||
|
w.item() // spatial_merge_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
t_index = (
|
||||||
|
torch.arange(llm_grid_t)
|
||||||
|
.view(-1, 1)
|
||||||
|
.expand(-1, llm_grid_h * llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
h_index = (
|
||||||
|
torch.arange(llm_grid_h)
|
||||||
|
.view(1, -1, 1)
|
||||||
|
.expand(llm_grid_t, -1, llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
w_index = (
|
||||||
|
torch.arange(llm_grid_w)
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.expand(llm_grid_t, llm_grid_h, -1)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
image_index += 1
|
||||||
|
video_frame_num = 1
|
||||||
|
|
||||||
|
elif modality_type == "video":
|
||||||
|
t, h, w = (
|
||||||
|
video_frame_num,
|
||||||
|
video_grid_thw[video_index][1],
|
||||||
|
video_grid_thw[video_index][2],
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||||
|
t,
|
||||||
|
h.item() // spatial_merge_size,
|
||||||
|
w.item() // spatial_merge_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
for t_idx in range(llm_grid_t):
|
||||||
|
t_index = (
|
||||||
|
torch.tensor(t_idx)
|
||||||
|
.view(-1, 1)
|
||||||
|
.expand(-1, llm_grid_h * llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
|
||||||
|
h_index = (
|
||||||
|
torch.arange(llm_grid_h)
|
||||||
|
.view(1, -1, 1)
|
||||||
|
.expand(1, -1, llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
w_index = (
|
||||||
|
torch.arange(llm_grid_w)
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.expand(1, llm_grid_h, -1)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
video_group_index += 1
|
||||||
|
|
||||||
|
if video_group_index >= video_grid_thw[video_index][0]:
|
||||||
|
video_index += 1
|
||||||
|
video_group_index = 0
|
||||||
|
|
||||||
|
video_frame_num += 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
text_len = end_idx - start_idx
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
video_frame_num = 1
|
||||||
|
|
||||||
|
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||||
|
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
|
||||||
|
position_ids.device
|
||||||
|
)
|
||||||
|
mrope_position_deltas.append(
|
||||||
|
llm_positions.max() + 1 - len(total_input_ids[i])
|
||||||
|
)
|
||||||
|
mrope_position_deltas = torch.tensor(
|
||||||
|
mrope_position_deltas, device=input_ids.device
|
||||||
|
).unsqueeze(1)
|
||||||
|
return position_ids, mrope_position_deltas
|
||||||
|
else:
|
||||||
|
if attention_mask is not None:
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
position_ids = (
|
||||||
|
position_ids.unsqueeze(0)
|
||||||
|
.expand(3, -1, -1)
|
||||||
|
.to(attention_mask.device)
|
||||||
|
)
|
||||||
|
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
|
||||||
|
-1, keepdim=True
|
||||||
|
)[0]
|
||||||
|
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
|
||||||
|
else:
|
||||||
|
position_ids = (
|
||||||
|
torch.arange(input_ids.shape[1], device=input_ids.device)
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.expand(3, input_ids.shape[0], -1)
|
||||||
|
)
|
||||||
|
mrope_position_deltas = torch.zeros(
|
||||||
|
[input_ids.shape[0], 1],
|
||||||
|
device=input_ids.device,
|
||||||
|
dtype=input_ids.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
return position_ids, mrope_position_deltas
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_next_input_positions(
|
def get_next_input_positions(
|
||||||
mrope_position_delta: int,
|
mrope_position_delta: int,
|
||||||
|
|||||||
@@ -218,6 +218,12 @@ class Glm4Model(nn.Module):
|
|||||||
|
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def dtype(self) -> torch.dtype:
|
||||||
|
return next(self.parameters()).dtype
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
589
python/sglang/srt/models/glm4v.py
Normal file
589
python/sglang/srt/models/glm4v.py
Normal file
@@ -0,0 +1,589 @@
|
|||||||
|
import logging
|
||||||
|
from functools import lru_cache, partial
|
||||||
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
|
from sglang.srt.layers.linear import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
MergedColumnParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from sglang.srt.managers.schedule_batch import MultimodalDataItem
|
||||||
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.models.glm4 import Glm4Model
|
||||||
|
from sglang.srt.models.qwen2_5_vl import (
|
||||||
|
Qwen2_5_VisionBlock,
|
||||||
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import add_prefix
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
cached_get_processor = lru_cache(get_processor)
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vRMSNorm(RMSNorm):
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
original_shape = x.shape
|
||||||
|
x_2d = x.contiguous().reshape(-1, original_shape[-1])
|
||||||
|
x_2d = super().forward(x_2d)
|
||||||
|
x = x_2d.reshape(original_shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vVisionMLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
hidden_features: int,
|
||||||
|
bias: bool = False,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
input_size=in_features,
|
||||||
|
output_sizes=[hidden_features] * 2,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("gate_up_proj", prefix),
|
||||||
|
)
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
hidden_features,
|
||||||
|
in_features,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("down_proj", prefix),
|
||||||
|
)
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vVisionBlock(Qwen2_5_VisionBlock):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Glm4vVisionConfig,
|
||||||
|
norm_layer: Optional[nn.Module] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
dim=config.hidden_size,
|
||||||
|
intermediate_dim=config.out_hidden_size,
|
||||||
|
num_heads=config.num_heads,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
|
)
|
||||||
|
self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
self.mlp = Glm4vVisionMLP(
|
||||||
|
config.hidden_size,
|
||||||
|
config.out_hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("mlp", prefix),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vVisionPatchEmbed(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size: int = 14,
|
||||||
|
temporal_patch_size: int = 2,
|
||||||
|
in_channels: int = 3,
|
||||||
|
hidden_size: int = 1536,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.temporal_patch_size = temporal_patch_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
kernel_size = (temporal_patch_size, patch_size, patch_size)
|
||||||
|
self.proj = nn.Conv3d(
|
||||||
|
in_channels,
|
||||||
|
hidden_size,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=kernel_size,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = x.view(
|
||||||
|
-1,
|
||||||
|
self.in_channels,
|
||||||
|
self.temporal_patch_size,
|
||||||
|
self.patch_size,
|
||||||
|
self.patch_size,
|
||||||
|
)
|
||||||
|
x = self.proj(x).view(-1, self.hidden_size)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vPatchMerger(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
context_dim: int,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
bias: bool = False,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = d_model
|
||||||
|
self.proj = ColumnParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("proj", prefix),
|
||||||
|
gather_output=True,
|
||||||
|
)
|
||||||
|
self.post_projection_norm = nn.LayerNorm(self.hidden_size)
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
input_size=self.hidden_size,
|
||||||
|
output_sizes=[context_dim] * 2,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("gate_up_proj", prefix),
|
||||||
|
)
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
context_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("down_proj", prefix),
|
||||||
|
)
|
||||||
|
self.extra_activation_func = nn.GELU()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x, _ = self.proj(x)
|
||||||
|
x = self.extra_activation_func(self.post_projection_norm(x))
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
gate, up = gate_up.chunk(2, dim=-1)
|
||||||
|
x = F.silu(gate) * up
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vVisionEmbeddings(nn.Module):
|
||||||
|
def __init__(self, config: Glm4vVisionConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.image_size = config.image_size
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
|
||||||
|
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||||
|
self.num_positions = self.num_patches
|
||||||
|
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||||
|
self.register_buffer(
|
||||||
|
"position_ids",
|
||||||
|
torch.arange(self.num_positions).expand((1, -1)),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, embeddings, lengths, image_shapes, h_coords, w_coords
|
||||||
|
) -> torch.Tensor:
|
||||||
|
pos_embed_weight = self.position_embedding.weight
|
||||||
|
hidden_size = pos_embed_weight.shape[1]
|
||||||
|
total_seq = h_coords.shape[0]
|
||||||
|
device = pos_embed_weight.device
|
||||||
|
|
||||||
|
# Move coordinates to correct device
|
||||||
|
h_coords, w_coords = h_coords.to(device), w_coords.to(device)
|
||||||
|
|
||||||
|
# Handle empty sequence case
|
||||||
|
if total_seq == 0:
|
||||||
|
adapted_pos_embed = torch.empty(
|
||||||
|
0, hidden_size, device=device, dtype=pos_embed_weight.dtype
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Convert inputs to tensors if needed
|
||||||
|
if isinstance(lengths, list):
|
||||||
|
lengths = torch.tensor(lengths, device=device, dtype=torch.long)
|
||||||
|
if not isinstance(image_shapes, torch.Tensor):
|
||||||
|
image_shapes = torch.tensor(
|
||||||
|
image_shapes, device=device, dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare 2D position embedding
|
||||||
|
orig_size_sq = pos_embed_weight.shape[0]
|
||||||
|
orig_size = int(orig_size_sq**0.5)
|
||||||
|
pos_embed_2d = (
|
||||||
|
pos_embed_weight.view(orig_size, orig_size, hidden_size)
|
||||||
|
.permute(2, 0, 1)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.to(device=device, dtype=torch.float32)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate target dimensions for each patch
|
||||||
|
target_h = torch.cat(
|
||||||
|
[image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]
|
||||||
|
).to(device=device, dtype=torch.float32)
|
||||||
|
target_w = torch.cat(
|
||||||
|
[image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]
|
||||||
|
).to(device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Normalize coordinates to [-1, 1] range for grid_sample
|
||||||
|
h_coords = h_coords.to(device=device, dtype=torch.float32)
|
||||||
|
w_coords = w_coords.to(device=device, dtype=torch.float32)
|
||||||
|
norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
|
||||||
|
norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
|
||||||
|
|
||||||
|
# Create sampling grid
|
||||||
|
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
|
||||||
|
|
||||||
|
# Perform bicubic interpolation
|
||||||
|
interpolated_embed_fp32 = F.grid_sample(
|
||||||
|
pos_embed_2d,
|
||||||
|
grid,
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
padding_mode="border",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape and convert back to original dtype
|
||||||
|
adapted_pos_embed_fp32 = (
|
||||||
|
interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
|
||||||
|
)
|
||||||
|
adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(
|
||||||
|
embeddings.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add adapted position encoding to embeddings
|
||||||
|
embeddings = embeddings + adapted_pos_embed
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vVisionRotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
self._seq_len_cached = 0
|
||||||
|
self._freqs_cached = None
|
||||||
|
|
||||||
|
def update_freqs_cache(self, seqlen: int) -> None:
|
||||||
|
if seqlen > self._seq_len_cached:
|
||||||
|
seqlen *= 2
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
self.inv_freq = 1.0 / (
|
||||||
|
self.theta
|
||||||
|
** (
|
||||||
|
torch.arange(
|
||||||
|
0,
|
||||||
|
self.dim,
|
||||||
|
2,
|
||||||
|
dtype=torch.float,
|
||||||
|
device=self.inv_freq.device,
|
||||||
|
)
|
||||||
|
/ self.dim
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seq = torch.arange(
|
||||||
|
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
|
||||||
|
)
|
||||||
|
freqs = torch.outer(seq, self.inv_freq)
|
||||||
|
self._freqs_cached = freqs
|
||||||
|
|
||||||
|
def forward(self, seqlen: int) -> torch.Tensor:
|
||||||
|
self.update_freqs_cache(seqlen)
|
||||||
|
return self._freqs_cached[:seqlen]
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vVisionModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vision_config: Glm4vVisionConfig,
|
||||||
|
norm_eps: float = 1e-6,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
patch_size = vision_config.patch_size
|
||||||
|
temporal_patch_size = vision_config.temporal_patch_size
|
||||||
|
in_channels = vision_config.in_channels
|
||||||
|
depth = vision_config.depth
|
||||||
|
self.hidden_size = vision_config.hidden_size
|
||||||
|
self.num_heads = vision_config.num_heads
|
||||||
|
|
||||||
|
self.patch_size = vision_config.patch_size
|
||||||
|
self.spatial_merge_size = vision_config.spatial_merge_size
|
||||||
|
self.out_hidden_size = vision_config.out_hidden_size
|
||||||
|
|
||||||
|
self.patch_embed = Glm4vVisionPatchEmbed(
|
||||||
|
patch_size=patch_size,
|
||||||
|
temporal_patch_size=temporal_patch_size,
|
||||||
|
in_channels=in_channels,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
|
||||||
|
head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Glm4vVisionBlock(
|
||||||
|
config=vision_config,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix(f"blocks.{layer_idx}", prefix),
|
||||||
|
)
|
||||||
|
for layer_idx in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.merger = Glm4vPatchMerger(
|
||||||
|
d_model=vision_config.out_hidden_size,
|
||||||
|
context_dim=vision_config.intermediate_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
bias=False,
|
||||||
|
prefix=add_prefix("merger", prefix),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embeddings = Glm4vVisionEmbeddings(vision_config)
|
||||||
|
|
||||||
|
self.post_conv_layernorm = Glm4vRMSNorm(
|
||||||
|
vision_config.hidden_size, eps=vision_config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.downsample = nn.Conv2d(
|
||||||
|
in_channels=vision_config.hidden_size,
|
||||||
|
out_channels=vision_config.out_hidden_size,
|
||||||
|
kernel_size=vision_config.spatial_merge_size,
|
||||||
|
stride=vision_config.spatial_merge_size,
|
||||||
|
)
|
||||||
|
self.post_layernorm = Glm4vRMSNorm(
|
||||||
|
vision_config.hidden_size, eps=vision_config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> torch.dtype:
|
||||||
|
return self.patch_embed.proj.weight.dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return self.patch_embed.proj.weight.device
|
||||||
|
|
||||||
|
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||||
|
pos_ids = []
|
||||||
|
for t, h, w in grid_thw:
|
||||||
|
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||||
|
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||||
|
hpos_ids = (
|
||||||
|
hpos_ids.reshape(
|
||||||
|
h // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
w // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
wpos_ids = (
|
||||||
|
wpos_ids.reshape(
|
||||||
|
h // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
w // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||||
|
pos_ids = torch.cat(pos_ids, dim=0)
|
||||||
|
max_grid_size = grid_thw[:, 1:].max()
|
||||||
|
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||||
|
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||||
|
return rotary_pos_emb, pos_ids
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||||
|
# patchify
|
||||||
|
x = x.to(device=self.device, dtype=self.dtype)
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
x = self.post_conv_layernorm(x)
|
||||||
|
|
||||||
|
# compute position embedding
|
||||||
|
rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
|
||||||
|
# compute cu_seqlens
|
||||||
|
cu_seqlens = torch.repeat_interleave(
|
||||||
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
|
).cumsum(dim=0, dtype=torch.int32)
|
||||||
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||||
|
|
||||||
|
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
|
x = self.embeddings(
|
||||||
|
x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
|
||||||
|
)
|
||||||
|
|
||||||
|
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||||
|
rotary_pos_emb_tuple = (emb.cos(), emb.sin())
|
||||||
|
|
||||||
|
# x.shape: (s, b, d) where b=1 for vision processing
|
||||||
|
# transformers
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
for blk in self.blocks:
|
||||||
|
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=rotary_pos_emb_tuple)
|
||||||
|
|
||||||
|
# adapter
|
||||||
|
x = self.post_layernorm(x)
|
||||||
|
x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
|
||||||
|
x = x.permute(0, 3, 1, 2)
|
||||||
|
x = self.downsample(x).view(-1, self.out_hidden_size)
|
||||||
|
x = self.merger(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Glm4vConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.model = Glm4Model(
|
||||||
|
config,
|
||||||
|
quant_config,
|
||||||
|
prefix=add_prefix("model", prefix),
|
||||||
|
)
|
||||||
|
self.visual = Glm4vVisionModel(
|
||||||
|
config.vision_config,
|
||||||
|
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("visual", prefix),
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("lm_head", prefix),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
||||||
|
|
||||||
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
|
pixel_values = torch.cat(
|
||||||
|
[item.feature.squeeze(0) for item in items], dim=0
|
||||||
|
).type(self.visual.dtype)
|
||||||
|
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
||||||
|
# For multi-image, pixel_values is [num_of_images, L, C] shape
|
||||||
|
# assert pixel_values.dim() == 2, pixel_values.dim()
|
||||||
|
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||||
|
split_sizes = (
|
||||||
|
image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
|
||||||
|
).tolist()
|
||||||
|
image_embeds = torch.split(image_embeds, split_sizes)
|
||||||
|
return torch.cat(image_embeds)
|
||||||
|
|
||||||
|
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
|
pixel_values_videos = torch.cat(
|
||||||
|
[item.feature.squeeze(0) for item in items], dim=0
|
||||||
|
).type(self.visual.dtype)
|
||||||
|
video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
|
||||||
|
# For multi-video, pixel_values_videos is [num_of_videos, L, C] shape
|
||||||
|
# assert pixel_values_videos.dim() == 2, pixel_values_videos.dim()
|
||||||
|
assert video_grid_thw.dim() == 2, video_grid_thw.dim()
|
||||||
|
|
||||||
|
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
|
||||||
|
temp_frames_hw = []
|
||||||
|
for t, h, w in video_grid_thw:
|
||||||
|
repeated_row = (
|
||||||
|
torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
|
||||||
|
)
|
||||||
|
temp_frames_hw.append(repeated_row)
|
||||||
|
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
|
||||||
|
video_embeds = self.visual(
|
||||||
|
pixel_values_videos, grid_thw=flattened_video_grid_thw
|
||||||
|
)
|
||||||
|
split_sizes = (
|
||||||
|
video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
|
||||||
|
).tolist()
|
||||||
|
video_embeds = torch.split(video_embeds, split_sizes)
|
||||||
|
return torch.cat(video_embeds)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "language_model." in name:
|
||||||
|
name = name.replace("language_model.", "")
|
||||||
|
if "model.visual." in name:
|
||||||
|
name = name.replace("model.visual.", "visual.")
|
||||||
|
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if "visual" in name:
|
||||||
|
# adapt to VisionAttention
|
||||||
|
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
except KeyError:
|
||||||
|
print(params_dict.keys())
|
||||||
|
raise
|
||||||
|
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
EntryClass = [Glm4vForConditionalGeneration]
|
||||||
400
python/sglang/srt/models/glm4v_moe.py
Normal file
400
python/sglang/srt/models/glm4v_moe.py
Normal file
@@ -0,0 +1,400 @@
|
|||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
|
||||||
|
|
||||||
|
from sglang.srt.distributed import (
|
||||||
|
get_moe_expert_parallel_world_size,
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
parallel_state,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
get_attention_tp_rank,
|
||||||
|
get_attention_tp_size,
|
||||||
|
get_local_attention_dp_size,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.models.glm4_moe import Glm4MoeModel
|
||||||
|
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
|
||||||
|
from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
|
||||||
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
cached_get_processor = lru_cache(get_processor)
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Glm4vMoeConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
|
||||||
|
config.moe_layer_freq = 1
|
||||||
|
self.config = config
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.dp_size = get_local_attention_dp_size()
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
|
||||||
|
self.num_fused_shared_experts = (
|
||||||
|
0
|
||||||
|
if global_server_args_dict["disable_shared_experts_fusion"]
|
||||||
|
else config.n_shared_experts
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = Glm4MoeModel(
|
||||||
|
config,
|
||||||
|
quant_config,
|
||||||
|
prefix=add_prefix("language_model", prefix),
|
||||||
|
)
|
||||||
|
self.visual = Glm4vVisionModel(
|
||||||
|
config.vision_config,
|
||||||
|
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("visual", prefix),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("lm_head", prefix),
|
||||||
|
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||||
|
)
|
||||||
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
||||||
|
|
||||||
|
def determine_num_fused_shared_experts(
|
||||||
|
self, architecture: str = "Glm4MoeForCausalLM"
|
||||||
|
):
|
||||||
|
self.num_fused_shared_experts = 0
|
||||||
|
if global_server_args_dict["disable_shared_experts_fusion"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
||||||
|
disable_reason = None
|
||||||
|
if (
|
||||||
|
not _is_cuda
|
||||||
|
or torch.cuda.get_device_capability("cuda") < (8, 0)
|
||||||
|
or self.config.architectures[0] != architecture
|
||||||
|
or self.config.n_shared_experts != 1
|
||||||
|
):
|
||||||
|
disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
|
||||||
|
elif get_moe_expert_parallel_world_size() > 1:
|
||||||
|
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
|
||||||
|
|
||||||
|
if disable_reason is not None:
|
||||||
|
global_server_args_dict["disable_shared_experts_fusion"] = True
|
||||||
|
self.num_fused_shared_experts = 0
|
||||||
|
log_info_on_rank0(
|
||||||
|
logger,
|
||||||
|
f"{disable_reason} Shared experts fusion optimization is disabled.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.num_fused_shared_experts = self.config.n_shared_experts
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
||||||
|
|
||||||
|
if is_nextn:
|
||||||
|
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||||
|
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||||
|
assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
|
||||||
|
# compatible with old design
|
||||||
|
nextn_layer_id = (
|
||||||
|
0
|
||||||
|
if self.config.num_hidden_layers == 1
|
||||||
|
else self.config.num_hidden_layers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("num_nextn_predict_layers is not in the config")
|
||||||
|
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
if self.num_fused_shared_experts > 0:
|
||||||
|
assert self.num_fused_shared_experts == 1
|
||||||
|
weights_list = list(weights)
|
||||||
|
weights_dict = dict(weights_list)
|
||||||
|
if self.quant_config is not None:
|
||||||
|
if self.quant_config.get_name() == "w8a8_int8":
|
||||||
|
suffix_list = [
|
||||||
|
"down_proj.weight",
|
||||||
|
"down_proj.weight_scale",
|
||||||
|
"gate_proj.weight",
|
||||||
|
"gate_proj.weight_scale",
|
||||||
|
"up_proj.weight",
|
||||||
|
"up_proj.weight_scale",
|
||||||
|
]
|
||||||
|
elif (
|
||||||
|
self.quant_config.get_name() == "fp8"
|
||||||
|
or self.quant_config.get_name() == "blockwise_int8"
|
||||||
|
or self.quant_config.get_name() == "compressed_tensors"
|
||||||
|
):
|
||||||
|
suffix_list = [
|
||||||
|
"down_proj.weight",
|
||||||
|
"down_proj.weight_scale",
|
||||||
|
"gate_proj.weight",
|
||||||
|
"gate_proj.weight_scale",
|
||||||
|
"up_proj.weight",
|
||||||
|
"up_proj.weight_scale",
|
||||||
|
]
|
||||||
|
elif self.quant_config.get_name() == "awq":
|
||||||
|
suffix_list = [
|
||||||
|
"down_proj.qweight",
|
||||||
|
"down_proj.qzeros",
|
||||||
|
"down_proj.scales",
|
||||||
|
"gate_proj.qweight",
|
||||||
|
"gate_proj.qzeros",
|
||||||
|
"gate_proj.scales",
|
||||||
|
"up_proj.qweight",
|
||||||
|
"up_proj.qzeros",
|
||||||
|
"up_proj.scales",
|
||||||
|
]
|
||||||
|
elif self.quant_config.get_name() == "modelopt_fp4":
|
||||||
|
suffix_list = [
|
||||||
|
"down_proj.weight",
|
||||||
|
"down_proj.weight_scale",
|
||||||
|
"down_proj.weight_scale_2",
|
||||||
|
"down_proj.input_scale",
|
||||||
|
"gate_proj.weight",
|
||||||
|
"gate_proj.weight_scale",
|
||||||
|
"gate_proj.weight_scale_2",
|
||||||
|
"gate_proj.input_scale",
|
||||||
|
"up_proj.weight",
|
||||||
|
"up_proj.weight_scale",
|
||||||
|
"up_proj.weight_scale_2",
|
||||||
|
"up_proj.input_scale",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
suffix_list = [
|
||||||
|
"down_proj.weight",
|
||||||
|
"gate_proj.weight",
|
||||||
|
"up_proj.weight",
|
||||||
|
]
|
||||||
|
names_to_remove = []
|
||||||
|
|
||||||
|
moe_layers = (
|
||||||
|
range(
|
||||||
|
self.config.first_k_dense_replace,
|
||||||
|
self.config.num_hidden_layers,
|
||||||
|
self.config.moe_layer_freq,
|
||||||
|
)
|
||||||
|
if not is_nextn
|
||||||
|
else [nextn_layer_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
for moe_layer in moe_layers:
|
||||||
|
for suffix in suffix_list:
|
||||||
|
shared_expert_weight_name = (
|
||||||
|
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
|
||||||
|
)
|
||||||
|
# online fp8 quantization does not load weight_scale
|
||||||
|
if shared_expert_weight_name not in weights_dict:
|
||||||
|
continue
|
||||||
|
weights_list.append(
|
||||||
|
(
|
||||||
|
f"model.layers.{moe_layer}."
|
||||||
|
f"mlp.experts."
|
||||||
|
f"{self.config.n_routed_experts + 0}"
|
||||||
|
f".{suffix}",
|
||||||
|
weights_dict[shared_expert_weight_name],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
names_to_remove += [shared_expert_weight_name]
|
||||||
|
weights = [w for w in weights_list if w[0] not in names_to_remove]
|
||||||
|
|
||||||
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
|
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||||
|
ckpt_gate_proj_name="gate_proj",
|
||||||
|
ckpt_down_proj_name="down_proj",
|
||||||
|
ckpt_up_proj_name="up_proj",
|
||||||
|
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
||||||
|
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
|
||||||
|
self.config.q_lora_rank is not None
|
||||||
|
)
|
||||||
|
cached_a_proj = {} if fuse_qkv_a_proj else None
|
||||||
|
|
||||||
|
if is_nextn:
|
||||||
|
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
|
||||||
|
nextn_spec_weight_names = [
|
||||||
|
"shared_head.norm",
|
||||||
|
"eh_proj",
|
||||||
|
"enorm",
|
||||||
|
"hnorm",
|
||||||
|
]
|
||||||
|
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
weight_names = []
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
weight_names.append(name)
|
||||||
|
|
||||||
|
if not is_nextn:
|
||||||
|
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||||
|
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||||
|
if num_nextn_layers > 0 and name.startswith("model.layers"):
|
||||||
|
name_list = name.split(".")
|
||||||
|
if (
|
||||||
|
len(name_list) >= 3
|
||||||
|
and int(name_list[2]) >= self.config.num_hidden_layers
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if not name.startswith(nextn_layer_prefix):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Use shared head and embed weights from target model
|
||||||
|
if "shared_head.head" in name or "embed_tokens" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_decoder = True
|
||||||
|
# For nextn specific weights
|
||||||
|
for weight_name in nextn_spec_weight_names:
|
||||||
|
if weight_name in name:
|
||||||
|
name = name.replace(nextn_layer_prefix, "model")
|
||||||
|
is_decoder = False
|
||||||
|
break
|
||||||
|
# For decoder layer weights
|
||||||
|
if is_decoder:
|
||||||
|
name = name.replace(nextn_layer_prefix, "model.decoder")
|
||||||
|
|
||||||
|
if "language_model." in name:
|
||||||
|
name = name.replace("language_model.", "")
|
||||||
|
if "model.visual." in name:
|
||||||
|
name = name.replace("model.visual.", "visual.")
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
# Skip non-stacked layers and experts (experts handled below).
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||||
|
# Since we handle the experts below in expert_params_mapping,
|
||||||
|
# we need to skip here BEFORE we update the name, otherwise
|
||||||
|
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||||
|
# will then be updated below in expert_params_mapping
|
||||||
|
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||||
|
if ("mlp.experts." in name) and name not in params_dict:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
for mapping in expert_params_mapping:
|
||||||
|
param_name, weight_name, expert_id, shard_id = mapping
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(
|
||||||
|
param,
|
||||||
|
loaded_weight,
|
||||||
|
name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=expert_id,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if "visual" in name:
|
||||||
|
# adapt to VisionAttention
|
||||||
|
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
||||||
|
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if fuse_qkv_a_proj and (
|
||||||
|
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
||||||
|
):
|
||||||
|
cached_a_proj[name] = loaded_weight
|
||||||
|
q_a_proj_name = (
|
||||||
|
name
|
||||||
|
if "q_a_proj" in name
|
||||||
|
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
|
||||||
|
)
|
||||||
|
kv_a_proj_name = (
|
||||||
|
name
|
||||||
|
if "kv_a_proj_with_mqa" in name
|
||||||
|
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
|
||||||
|
)
|
||||||
|
|
||||||
|
# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
|
||||||
|
if (
|
||||||
|
q_a_proj_name in cached_a_proj
|
||||||
|
and kv_a_proj_name in cached_a_proj
|
||||||
|
):
|
||||||
|
q_a_proj_weight = cached_a_proj[q_a_proj_name]
|
||||||
|
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
|
||||||
|
fused_weight = torch.cat(
|
||||||
|
[q_a_proj_weight, kv_a_proj_weight], dim=0
|
||||||
|
)
|
||||||
|
param_name = (
|
||||||
|
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
|
||||||
|
if "q_a_proj" in name
|
||||||
|
else name.replace(
|
||||||
|
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
param = params_dict[param_name]
|
||||||
|
|
||||||
|
weight_loader = getattr(
|
||||||
|
param, "weight_loader", default_weight_loader
|
||||||
|
)
|
||||||
|
weight_loader(param, fused_weight)
|
||||||
|
cached_a_proj.pop(q_a_proj_name)
|
||||||
|
cached_a_proj.pop(kv_a_proj_name)
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
"k_scale" in name or "v_scale" in name
|
||||||
|
) and name not in params_dict:
|
||||||
|
# modelopt attn kv scale is named differently
|
||||||
|
if any(scale in name for scale in ["k_scale", "v_scale"]):
|
||||||
|
name = name.replace("_proj", "attn_mqa")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Unknown scale found in checkpoint: {name}"
|
||||||
|
)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(
|
||||||
|
param, "weight_loader", default_weight_loader
|
||||||
|
)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
EntryClass = [Glm4vMoeForConditionalGeneration]
|
||||||
@@ -22,13 +22,19 @@ class BaseMultiModalProcessorOutput:
|
|||||||
input_text: str
|
input_text: str
|
||||||
|
|
||||||
# frames loaded from image, in given order
|
# frames loaded from image, in given order
|
||||||
images: Optional[list[Union[Image.Image, dict]]] = None
|
images: Optional[list[Union[Image.Image, dict]]] = dataclasses.field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
|
||||||
# videos
|
# videos
|
||||||
videos: Optional[list[Union[torch.Tensor, dict]]] = None
|
videos: Optional[list[Union[torch.Tensor, dict]]] = dataclasses.field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
|
||||||
# audios
|
# audios
|
||||||
audios: Optional[list[Union[np.ndarray, dict]]] = None
|
audios: Optional[list[Union[np.ndarray, dict]]] = dataclasses.field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
|
||||||
def organize_results(self) -> List[Tuple[Modality, Any]]:
|
def organize_results(self) -> List[Tuple[Modality, Any]]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
132
python/sglang/srt/multimodal/processors/glm4v.py
Normal file
132
python/sglang/srt/multimodal/processors/glm4v.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
import re
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from decord import VideoReader
|
||||||
|
from transformers.video_utils import VideoMetadata
|
||||||
|
|
||||||
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||||
|
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration
|
||||||
|
from sglang.srt.models.glm4v_moe import Glm4vMoeForConditionalGeneration
|
||||||
|
from sglang.srt.multimodal.processors.base_processor import (
|
||||||
|
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||||
|
)
|
||||||
|
from sglang.srt.multimodal.processors.base_processor import (
|
||||||
|
BaseMultiModalProcessorOutput,
|
||||||
|
MultimodalSpecialTokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4vImageProcessor(SGLangBaseProcessor):
|
||||||
|
models = [Glm4vForConditionalGeneration, Glm4vMoeForConditionalGeneration]
|
||||||
|
|
||||||
|
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
||||||
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
|
|
||||||
|
# GLM-4.1V and GLM-4.5V specific tokens
|
||||||
|
self.IMAGE_TOKEN = "<|image|>"
|
||||||
|
self.VIDEO_TOKEN = "<|video|>"
|
||||||
|
self.IMAGE_START_TOKEN = "<|begin_of_image|>"
|
||||||
|
self.IMAGE_END_TOKEN = "<|end_of_image|>"
|
||||||
|
self.VIDEO_START_TOKEN = "<|begin_of_video|>"
|
||||||
|
self.VIDEO_END_TOKEN = "<|end_of_video|>"
|
||||||
|
|
||||||
|
# Token IDs
|
||||||
|
self.IM_TOKEN_ID = hf_config.image_token_id
|
||||||
|
self.VIDEO_TOKEN_ID = hf_config.video_token_id
|
||||||
|
self.IMAGE_START_TOKEN_ID = hf_config.image_start_token_id
|
||||||
|
self.IMAGE_END_TOKEN_ID = hf_config.image_end_token_id
|
||||||
|
self.VIDEO_START_TOKEN_ID = hf_config.video_start_token_id
|
||||||
|
self.VIDEO_END_TOKEN_ID = hf_config.video_end_token_id
|
||||||
|
|
||||||
|
# Vision config
|
||||||
|
self.IMAGE_FACTOR = 28
|
||||||
|
self.MIN_PIXELS = 112 * 112
|
||||||
|
self.MAX_PIXELS = 30000 * 28 * 28 * 2
|
||||||
|
|
||||||
|
self.mm_tokens = MultimodalSpecialTokens(
|
||||||
|
image_token=self.IMAGE_TOKEN,
|
||||||
|
image_token_id=self.IM_TOKEN_ID,
|
||||||
|
video_token=self.VIDEO_TOKEN,
|
||||||
|
# Note: For GLM4v videos, it uses the video token before tokenization but uses image token after tokenization
|
||||||
|
video_token_id=self.IM_TOKEN_ID,
|
||||||
|
).build(_processor)
|
||||||
|
|
||||||
|
# adapted from https://github.com/huggingface/transformers/blob/369c99d0cea403b77bd0aef818527106453fd9fc/src/transformers/video_utils.py#L312
|
||||||
|
async def preprocess_video(self, vr: VideoReader):
|
||||||
|
"""
|
||||||
|
Preprocess video using VideoReader from Decord backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vr (VideoReader): VideoReader object from decord
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: A tuple containing processed frames and metadata
|
||||||
|
"""
|
||||||
|
video_fps = vr.get_avg_fps()
|
||||||
|
total_num_frames = len(vr)
|
||||||
|
duration = total_num_frames / video_fps if video_fps else 0
|
||||||
|
|
||||||
|
metadata = VideoMetadata(
|
||||||
|
total_num_frames=int(total_num_frames),
|
||||||
|
fps=float(video_fps),
|
||||||
|
duration=float(duration),
|
||||||
|
video_backend="decord",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract all frames
|
||||||
|
indices = list(range(total_num_frames))
|
||||||
|
frames = vr.get_batch(indices).asnumpy()
|
||||||
|
metadata.frames_indices = indices
|
||||||
|
|
||||||
|
return frames, metadata
|
||||||
|
|
||||||
|
async def process_mm_data_async(
|
||||||
|
self,
|
||||||
|
image_data: List[Union[str, bytes]],
|
||||||
|
input_text,
|
||||||
|
request_obj,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
base_output = self.load_mm_data(
|
||||||
|
prompt=input_text,
|
||||||
|
image_data=image_data,
|
||||||
|
video_data=request_obj.video_data,
|
||||||
|
multimodal_tokens=self.mm_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
video_metadata = None
|
||||||
|
|
||||||
|
if base_output.videos:
|
||||||
|
videos_processed = [
|
||||||
|
await self.preprocess_video(video) for video in base_output.videos
|
||||||
|
]
|
||||||
|
base_output.videos, video_metadata = map(list, zip(*videos_processed))
|
||||||
|
# transformer requires the video inputs to be under this format
|
||||||
|
base_output.videos = [base_output.videos]
|
||||||
|
video_metadata = [video_metadata]
|
||||||
|
|
||||||
|
mm_items, input_ids, ret = self.process_and_combine_mm_data(
|
||||||
|
base_output, self.mm_tokens, video_metadata=video_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = input_ids.flatten()
|
||||||
|
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index_glm4v(
|
||||||
|
input_ids=input_ids.unsqueeze(0),
|
||||||
|
hf_config=self.hf_config,
|
||||||
|
image_grid_thw=getattr(ret, "image_grid_thw", None),
|
||||||
|
video_grid_thw=getattr(ret, "video_grid_thw", None),
|
||||||
|
attention_mask=getattr(ret, "attention_mask", None),
|
||||||
|
)
|
||||||
|
mrope_positions = mrope_positions.squeeze(1)
|
||||||
|
|
||||||
|
mm_inputs = {
|
||||||
|
"input_ids": input_ids.tolist(),
|
||||||
|
"mm_items": mm_items,
|
||||||
|
"im_token_id": self.mm_tokens.image_token_id,
|
||||||
|
"video_token_id": self.mm_tokens.video_token_id,
|
||||||
|
"mrope_positions": mrope_positions,
|
||||||
|
"mrope_position_delta": mrope_position_delta,
|
||||||
|
}
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
@@ -815,7 +815,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
|||||||
vr = VideoReader(tmp_file.name, ctx=ctx)
|
vr = VideoReader(tmp_file.name, ctx=ctx)
|
||||||
elif video_file.startswith("data:"):
|
elif video_file.startswith("data:"):
|
||||||
_, encoded = video_file.split(",", 1)
|
_, encoded = video_file.split(",", 1)
|
||||||
video_bytes = base64.b64decode(encoded)
|
video_bytes = pybase64.b64decode(encoded)
|
||||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||||
tmp_file.write(video_bytes)
|
tmp_file.write(video_bytes)
|
||||||
tmp_file.close()
|
tmp_file.close()
|
||||||
@@ -823,7 +823,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
|||||||
elif os.path.isfile(video_file):
|
elif os.path.isfile(video_file):
|
||||||
vr = VideoReader(video_file, ctx=ctx)
|
vr = VideoReader(video_file, ctx=ctx)
|
||||||
else:
|
else:
|
||||||
video_bytes = base64.b64decode(video_file)
|
video_bytes = pybase64.b64decode(video_file)
|
||||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||||
tmp_file.write(video_bytes)
|
tmp_file.write(video_bytes)
|
||||||
tmp_file.close()
|
tmp_file.close()
|
||||||
|
|||||||
@@ -948,5 +948,6 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
|
|||||||
# def test_function_calling_multiturn(self):
|
# def test_function_calling_multiturn(self):
|
||||||
# self._test_function_calling_multiturn()
|
# self._test_function_calling_multiturn()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -497,6 +497,17 @@ class TestEBNFGeneration(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="empty_param_func",
|
||||||
|
description="Function with empty parameters",
|
||||||
|
parameters={
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
self.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
self.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||||
@@ -630,16 +641,21 @@ class TestEBNFGeneration(unittest.TestCase):
|
|||||||
self.assertIsNotNone(ebnf)
|
self.assertIsNotNone(ebnf)
|
||||||
# Check that the EBNF contains expected patterns for XML format
|
# Check that the EBNF contains expected patterns for XML format
|
||||||
self.assertIn('"<tool_call>" function_call "</tool_call>"', ebnf)
|
self.assertIn('"<tool_call>" function_call "</tool_call>"', ebnf)
|
||||||
self.assertIn('"get_weather" "\\n" arguments_get_weather', ebnf)
|
self.assertIn('"get_weather" "\\n" ( arguments_get_weather "\\n" )?', ebnf)
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
'"<arg_key>location</arg_key>" "\\n" "<arg_value>" xml_text "</arg_value>" ( "\\n" ( "<arg_key>unit</arg_key>" "\\n" "<arg_value>" ("celsius" | "fahrenheit") "</arg_value>" ) )?',
|
'"<arg_key>location</arg_key>" "\\n" "<arg_value>" xml_text "</arg_value>" ( "\\n" ( "<arg_key>unit</arg_key>" "\\n" "<arg_value>" ("celsius" | "fahrenheit") "</arg_value>" ) )?',
|
||||||
ebnf,
|
ebnf,
|
||||||
)
|
)
|
||||||
self.assertIn('"search" "\\n" arguments_search', ebnf)
|
self.assertIn('"search" "\\n" ( arguments_search "\\n" )?', ebnf)
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
'"<arg_key>query</arg_key>" "\\n" "<arg_value>" xml_text "</arg_value>"',
|
'"<arg_key>query</arg_key>" "\\n" "<arg_value>" xml_text "</arg_value>"',
|
||||||
ebnf,
|
ebnf,
|
||||||
)
|
)
|
||||||
|
self.assertIn(
|
||||||
|
'"empty_param_func" "\\n" ( arguments_empty_param_func "\\n" )?', ebnf
|
||||||
|
)
|
||||||
|
self.assertIn('arguments_empty_param_func ::= ""', ebnf)
|
||||||
|
|
||||||
# Validate that the EBNF can be compiled by GrammarCompiler
|
# Validate that the EBNF can be compiled by GrammarCompiler
|
||||||
try:
|
try:
|
||||||
ctx = self.grammar_compiler.compile_grammar(ebnf)
|
ctx = self.grammar_compiler.compile_grammar(ebnf)
|
||||||
|
|||||||
@@ -60,6 +60,86 @@ class TestTemplateContentFormatDetection(CustomTestCase):
|
|||||||
result = detect_jinja_template_content_format("")
|
result = detect_jinja_template_content_format("")
|
||||||
self.assertEqual(result, "string")
|
self.assertEqual(result, "string")
|
||||||
|
|
||||||
|
def test_detect_msg_content_pattern(self):
|
||||||
|
"""Test detection of template with msg.content pattern (should be 'openai' format)."""
|
||||||
|
msg_content_pattern = """
|
||||||
|
[gMASK]<sop>
|
||||||
|
{%- for msg in messages %}
|
||||||
|
{%- if msg.role == 'system' %}
|
||||||
|
<|system|>
|
||||||
|
{{ msg.content }}
|
||||||
|
{%- elif msg.role == 'user' %}
|
||||||
|
<|user|>{{ '\n' }}
|
||||||
|
{%- if msg.content is string %}
|
||||||
|
{{ msg.content }}
|
||||||
|
{%- else %}
|
||||||
|
{%- for item in msg.content %}
|
||||||
|
{%- if item.type == 'video' or 'video' in item %}
|
||||||
|
<|begin_of_video|><|video|><|end_of_video|>
|
||||||
|
{%- elif item.type == 'image' or 'image' in item %}
|
||||||
|
<|begin_of_image|><|image|><|end_of_image|>
|
||||||
|
{%- elif item.type == 'text' %}
|
||||||
|
{{ item.text }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- elif msg.role == 'assistant' %}
|
||||||
|
{%- if msg.metadata %}
|
||||||
|
<|assistant|>{{ msg.metadata }}
|
||||||
|
{{ msg.content }}
|
||||||
|
{%- else %}
|
||||||
|
<|assistant|>
|
||||||
|
{{ msg.content }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{% if add_generation_prompt %}<|assistant|>
|
||||||
|
{% endif %}
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = detect_jinja_template_content_format(msg_content_pattern)
|
||||||
|
self.assertEqual(result, "openai")
|
||||||
|
|
||||||
|
def test_detect_m_content_pattern(self):
|
||||||
|
"""Test detection of template with m.content pattern (should be 'openai' format)."""
|
||||||
|
msg_content_pattern = """
|
||||||
|
[gMASK]<sop>
|
||||||
|
{%- for m in messages %}
|
||||||
|
{%- if m.role == 'system' %}
|
||||||
|
<|system|>
|
||||||
|
{{ m.content }}
|
||||||
|
{%- elif m.role == 'user' %}
|
||||||
|
<|user|>{{ '\n' }}
|
||||||
|
{%- if m.content is string %}
|
||||||
|
{{ m.content }}
|
||||||
|
{%- else %}
|
||||||
|
{%- for item in m.content %}
|
||||||
|
{%- if item.type == 'video' or 'video' in item %}
|
||||||
|
<|begin_of_video|><|video|><|end_of_video|>
|
||||||
|
{%- elif item.type == 'image' or 'image' in item %}
|
||||||
|
<|begin_of_image|><|image|><|end_of_image|>
|
||||||
|
{%- elif item.type == 'text' %}
|
||||||
|
{{ item.text }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- elif m.role == 'assistant' %}
|
||||||
|
{%- if m.metadata %}
|
||||||
|
<|assistant|>{{ m.metadata }}
|
||||||
|
{{ m.content }}
|
||||||
|
{%- else %}
|
||||||
|
<|assistant|>
|
||||||
|
{{ m.content }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{% if add_generation_prompt %}<|assistant|>
|
||||||
|
{% endif %}
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = detect_jinja_template_content_format(msg_content_pattern)
|
||||||
|
self.assertEqual(result, "openai")
|
||||||
|
|
||||||
def test_process_content_openai_format(self):
|
def test_process_content_openai_format(self):
|
||||||
"""Test content processing for openai format."""
|
"""Test content processing for openai format."""
|
||||||
msg_dict = {
|
msg_dict = {
|
||||||
|
|||||||
@@ -348,6 +348,33 @@ class TestVILAServer(TestOpenAIVisionServer):
|
|||||||
cls.base_url += "/v1"
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
|
||||||
|
# Skip for ci test
|
||||||
|
# class TestGLM41VServer(TestOpenAIVisionServer):
|
||||||
|
# @classmethod
|
||||||
|
# def setUpClass(cls):
|
||||||
|
# cls.model = "zai-org/GLM-4.1V-9B-Thinking"
|
||||||
|
# cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
# cls.api_key = "sk-123456"
|
||||||
|
# cls.process = popen_launch_server(
|
||||||
|
# cls.model,
|
||||||
|
# cls.base_url,
|
||||||
|
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
# other_args=[
|
||||||
|
# "--trust-remote-code",
|
||||||
|
# "--mem-fraction-static",
|
||||||
|
# "0.68",
|
||||||
|
# "--cuda-graph-max-bs",
|
||||||
|
# "4",
|
||||||
|
# "--reasoning-parser",
|
||||||
|
# "glm45",
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
# cls.base_url += "/v1"
|
||||||
|
|
||||||
|
# def test_video_chat_completion(self):
|
||||||
|
# self._test_video_chat_completion()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
del TestOpenAIVisionServer
|
del TestOpenAIVisionServer
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -96,8 +96,13 @@ class TestOpenAIVisionServer(CustomTestCase):
|
|||||||
), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
|
), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
|
||||||
# MiniCPMO fails to recognize `iron`, but `hanging`
|
# MiniCPMO fails to recognize `iron`, but `hanging`
|
||||||
assert (
|
assert (
|
||||||
"iron" in text or "hang" in text or "cloth" in text or "holding" in text
|
"iron" in text
|
||||||
), f"text: {text}, should contain iron, hang, cloth or holding"
|
or "hang" in text
|
||||||
|
or "cloth" in text
|
||||||
|
or "coat" in text
|
||||||
|
or "holding" in text
|
||||||
|
or "outfit" in text
|
||||||
|
), f"text: {text}, should contain iron, hang, cloth, coat or holding or outfit"
|
||||||
assert response.id
|
assert response.id
|
||||||
assert response.created
|
assert response.created
|
||||||
assert response.usage.prompt_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
@@ -193,11 +198,15 @@ class TestOpenAIVisionServer(CustomTestCase):
|
|||||||
print(f"Multi images response:\n{text}")
|
print(f"Multi images response:\n{text}")
|
||||||
print("-" * 30)
|
print("-" * 30)
|
||||||
assert (
|
assert (
|
||||||
"man" in text or "cab" in text or "SUV" in text or "taxi" in text
|
"man" in text
|
||||||
), f"text: {text}, should contain man, cab, SUV or taxi"
|
or "cab" in text
|
||||||
|
or "SUV" in text
|
||||||
|
or "taxi" in text
|
||||||
|
or "car" in text
|
||||||
|
), f"text: {text}, should contain man, cab, SUV, taxi or car"
|
||||||
assert (
|
assert (
|
||||||
"logo" in text or '"S"' in text or "SG" in text
|
"logo" in text or '"S"' in text or "SG" in text or "graphic" in text
|
||||||
), f"text: {text}, should contain logo, S or SG"
|
), f"text: {text}, should contain logo, S or SG or graphic"
|
||||||
assert response.id
|
assert response.id
|
||||||
assert response.created
|
assert response.created
|
||||||
assert response.usage.prompt_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
@@ -320,11 +329,12 @@ class TestOpenAIVisionServer(CustomTestCase):
|
|||||||
or "individual" in video_response
|
or "individual" in video_response
|
||||||
or "speaker" in video_response
|
or "speaker" in video_response
|
||||||
or "Steve" in video_response
|
or "Steve" in video_response
|
||||||
|
or "hand" in video_response
|
||||||
), f"""
|
), f"""
|
||||||
====================== video_response =====================
|
====================== video_response =====================
|
||||||
{video_response}
|
{video_response}
|
||||||
===========================================================
|
===========================================================
|
||||||
should contain 'man' or 'person' or 'individual' or 'speaker'
|
should contain 'man' or 'person' or 'individual' or 'speaker' or 'hand'
|
||||||
"""
|
"""
|
||||||
assert (
|
assert (
|
||||||
"present" in video_response
|
"present" in video_response
|
||||||
@@ -375,7 +385,8 @@ class TestOpenAIVisionServer(CustomTestCase):
|
|||||||
or "person" in video_response
|
or "person" in video_response
|
||||||
or "individual" in video_response
|
or "individual" in video_response
|
||||||
or "speaker" in video_response
|
or "speaker" in video_response
|
||||||
), f"video_response: {video_response}, should either have 'man' in video_response, or 'person' in video_response, or 'individual' in video_response or 'speaker' in video_response"
|
or "hand" in video_response
|
||||||
|
), f"video_response: {video_response}, should either have 'man' in video_response, or 'person' in video_response, or 'individual' in video_response, or 'speaker' in video_response or 'hand' in video_response"
|
||||||
assert (
|
assert (
|
||||||
"present" in video_response
|
"present" in video_response
|
||||||
or "examine" in video_response
|
or "examine" in video_response
|
||||||
|
|||||||
Reference in New Issue
Block a user