diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index d5ec94f3d..a674178bc 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -917,6 +917,7 @@ multimodal_model_archs = [
     "Phi4MMForCausalLM",
     "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
+    "POINTSV15ChatModel",
     "DotsVLMForCausalLM",
     "DotsOCRForCausalLM",
     "Sarashina2VisionForCausalLM",
diff --git a/python/sglang/srt/configs/points_v15_chat.py b/python/sglang/srt/configs/points_v15_chat.py
new file mode 100644
index 000000000..758939b27
--- /dev/null
+++ b/python/sglang/srt/configs/points_v15_chat.py
@@ -0,0 +1,29 @@
+from typing import Optional, Union
+
+from transformers import PretrainedConfig, Qwen2Config
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
+
+
+class POINTSV15ChatConfig(PretrainedConfig):
+    model_type = "pointsv1.5_chat"
+
+    def __init__(
+        self,
+        vision_config: Optional[Union[dict, Qwen2VLVisionConfig]] = None,
+        llm_config: Optional[Union[dict, Qwen2Config]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if vision_config is None:
+            vision_config = Qwen2VLVisionConfig()
+        elif isinstance(vision_config, dict):
+            vision_config = Qwen2VLVisionConfig(**vision_config)
+        self.vision_config = vision_config
+
+        if llm_config is None:
+            llm_config = Qwen2Config()
+        elif isinstance(llm_config, dict):
+            llm_config = Qwen2Config(**llm_config)
+
+        self.llm_config = llm_config
+        self.hidden_size = self.llm_config.hidden_size
diff --git a/python/sglang/srt/models/points_v15_chat.py b/python/sglang/srt/models/points_v15_chat.py
new file mode 100644
index 000000000..79a74ca2c
--- /dev/null
+++ b/python/sglang/srt/models/points_v15_chat.py
@@ -0,0 +1,186 @@
+import copy
+from typing import Iterable, List, Optional, Set, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from sglang.srt.configs.points_v15_chat import POINTSV15ChatConfig
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen2_vl import Qwen2VisionPatchMerger, Qwen2VisionTransformer
+from sglang.srt.utils import add_prefix
+
+
+class Qwen2VisionTransformerForNavitPOINTS(Qwen2VisionTransformer):
+    def __init__(
+        self,
+        vision_config: POINTSV15ChatConfig,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            vision_config,
+            norm_eps=norm_eps,
+            quant_config=quant_config,
+            prefix=prefix,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        grid_thw: torch.Tensor,
+    ) -> torch.Tensor:
+        # patchify
+        x = x.to(device=self.device, dtype=self.dtype)
+        x = self.patch_embed(x)
+
+        # compute position embedding
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
+
+        # compute cu_seqlens
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+
+        # transformers
+        x = x.unsqueeze(1)
+        for blk in self.blocks:
+            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+
+        return x
+
+
+class POINTSV15ChatModel(nn.Module):
+    def __init__(
+        self,
+        config: POINTSV15ChatConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        config.llm_config._attn_implementation = "flash_attention_2"
+        config._attn_implementation_autoset = False
+        self.config = config
+        self.quant_config = quant_config
+
+        llm_config = copy.deepcopy(config.llm_config)
+        llm_config.architectures = ["Qwen2ForCausalLM"]
+        self.llm = Qwen2ForCausalLM(
+            config=llm_config,
+            quant_config=quant_config,
+            prefix=add_prefix("llm", prefix),
+        )
+
+        self.vision_encoder = Qwen2VisionTransformerForNavitPOINTS(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_encoder", prefix),
+        )
+
+        self.vision_projector = Qwen2VisionPatchMerger(
+            d_model=config.llm_config.hidden_size,
+            context_dim=1280,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_projector", prefix),
+        )
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.vision_encoder.dtype
+        )
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
+
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+
+        image_features = self.vision_encoder(pixel_values, grid_thw=image_grid_thw)
+        image_features = self.vision_projector(image_features)
+        return image_features
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "vision_encoder" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [POINTSV15ChatModel]
diff --git a/python/sglang/srt/multimodal/processors/points_v15_chat.py b/python/sglang/srt/multimodal/processors/points_v15_chat.py
new file mode 100644
index 000000000..c5c674bda
--- /dev/null
+++ b/python/sglang/srt/multimodal/processors/points_v15_chat.py
@@ -0,0 +1,52 @@
+# Copy from qwen_vl.py, adapted for points-v15-chat
+
+import asyncio
+from typing import List, Union
+
+from PIL import Image
+
+from sglang.srt.models.points_v15_chat import POINTSV15ChatModel
+from sglang.srt.multimodal.processors.qwen_vl import (
+    Qwen2_5VLImageProcessor,
+    resize_image_async,
+)
+
+
+class POINTSV15ChatProcessor(Qwen2_5VLImageProcessor):
+    models = [POINTSV15ChatModel]
+
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        # Compatible with POINTSV15Chat
+        hf_config.vision_start_token_id = None
+        hf_config.vision_end_token_id = None
+        hf_config.video_token_id = None
+
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=self.mm_tokens,
+        )
+
+        if base_output.images and isinstance(base_output.images[0], Image.Image):
+            resize_tasks = [resize_image_async(image) for image in base_output.images]
+            base_output.images = await asyncio.gather(*resize_tasks)
+
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
+
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_token_id": self.mm_tokens.image_token_id,
+        }
diff --git a/python/sglang/srt/parser/conversation.py b/python/sglang/srt/parser/conversation.py
index 8a2fe4e7f..2d03e1bfa 100644
--- a/python/sglang/srt/parser/conversation.py
+++ b/python/sglang/srt/parser/conversation.py
@@ -960,6 +960,19 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="points-v15-chat",
+        system_message="",
+        system_template="",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=["<|im_end|>"],
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+        video_token="<|vision_start|><|video_pad|><|vision_end|>",
+    )
+)
 
 MODEL_TYPE_TO_TEMPLATE = {
     "internvl_chat": "internvl-2-5",
@@ -971,6 +984,12 @@ MODEL_TYPE_TO_TEMPLATE = {
 }
 
 
+@register_conv_template_matching_function
+def match_points_v15_chat(model_path: str):
+    if re.search(r"points", model_path, re.IGNORECASE):
+        return "points-v15-chat"
+
+
 def get_model_type(model_path: str) -> Optional[str]:
     config_path = os.path.join(model_path, "config.json")
     if not os.path.exists(config_path):
diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py
index b20fcd605..d4e8b8562 100644
--- a/python/sglang/srt/utils/hf_transformers_utils.py
+++ b/python/sglang/srt/utils/hf_transformers_utils.py
@@ -111,6 +111,12 @@ def get_hf_text_config(config: PretrainedConfig):
         # if transformers config doesn't align with this assumption.
         assert hasattr(config.text_config, "num_attention_heads")
         return config.text_config
+
+    if hasattr(config, "llm_config"):
+        # PointsV1.5 Chat Model
+        assert hasattr(config.llm_config, "num_attention_heads")
+        return config.llm_config
+
     if hasattr(config, "language_config"):
         return config.language_config
     if hasattr(config, "thinker_config"):
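For quick reference, the sketch below hand-assembles the prompt string that the new `points-v15-chat` conversation template should roughly produce for a single image-plus-text user turn. It is built only from the template fields registered above (roles, `sep="<|im_end|>\n"`, `image_token`); the actual rendering is done by sglang's `Conversation` machinery, so treat this as an approximation, and the question text is just a placeholder.

```python
# Approximation assembled by hand from the points-v15-chat template fields
# registered in conversation.py above; not generated by sglang itself.
IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"

user_turn = "<|im_start|>user\n" + IMAGE_TOKEN + "Describe this image.<|im_end|>\n"
assistant_stub = "<|im_start|>assistant\n"  # generation stops at <|im_end|>

print(user_turn + assistant_stub)
```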