[model] Support POINTSV15Chat (#9651)
Co-authored-by: josephyou <josephyou@tencent.com> Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Co-authored-by: root <root@TENCENT64.site>
This commit is contained in:
@@ -917,6 +917,7 @@ multimodal_model_archs = [
|
|||||||
"Phi4MMForCausalLM",
|
"Phi4MMForCausalLM",
|
||||||
"VILAForConditionalGeneration",
|
"VILAForConditionalGeneration",
|
||||||
"Step3VLForConditionalGeneration",
|
"Step3VLForConditionalGeneration",
|
||||||
|
"POINTSV15ChatModel",
|
||||||
"DotsVLMForCausalLM",
|
"DotsVLMForCausalLM",
|
||||||
"DotsOCRForCausalLM",
|
"DotsOCRForCausalLM",
|
||||||
"Sarashina2VisionForCausalLM",
|
"Sarashina2VisionForCausalLM",
|
||||||
|
|||||||
29
python/sglang/srt/configs/points_v15_chat.py
Normal file
29
python/sglang/srt/configs/points_v15_chat.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig, Qwen2Config
|
||||||
|
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
|
||||||
|
|
||||||
|
|
||||||
|
class POINTSV15ChatConfig(PretrainedConfig):
    """Configuration for the POINTS v1.5 chat multimodal model.

    Composes a Qwen2-VL vision tower config (``vision_config``) with a
    Qwen2 language-model config (``llm_config``). Either may be given as a
    plain dict (as parsed from ``config.json``) or as an already-built
    config object; missing configs fall back to the library defaults.
    """

    model_type = "pointsv1.5_chat"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Qwen2VLVisionConfig]] = None,
        llm_config: Optional[Union[dict, Qwen2Config]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Vision tower: default Qwen2-VL settings when absent; promote a
        # raw dict from config.json into a config object.
        if vision_config is None:
            vision_config = Qwen2VLVisionConfig()
        elif isinstance(vision_config, dict):
            vision_config = Qwen2VLVisionConfig(**vision_config)
        self.vision_config = vision_config

        # Language model: same None/dict promotion as above.
        if llm_config is None:
            llm_config = Qwen2Config()
        elif isinstance(llm_config, dict):
            llm_config = Qwen2Config(**llm_config)

        self.llm_config = llm_config
        # Mirror the LLM hidden size at the top level so generic code that
        # reads `config.hidden_size` keeps working.
        self.hidden_size = self.llm_config.hidden_size
|
||||||
186
python/sglang/srt/models/points_v15_chat.py
Normal file
186
python/sglang/srt/models/points_v15_chat.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
import copy
|
||||||
|
from typing import Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from sglang.srt.configs.points_v15_chat import POINTSV15ChatConfig
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.managers.mm_utils import (
|
||||||
|
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||||
|
general_mm_embed_routine,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
Modality,
|
||||||
|
MultimodalDataItem,
|
||||||
|
MultimodalInputs,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
|
from sglang.srt.models.qwen2_vl import Qwen2VisionPatchMerger, Qwen2VisionTransformer
|
||||||
|
from sglang.srt.utils import add_prefix
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VisionTransformerForNavitPOINTS(Qwen2VisionTransformer):
    """Qwen2-VL vision transformer variant used by POINTS v1.5.

    Construction is identical to the parent; only ``forward`` differs: it
    returns the per-patch hidden states *without* applying the parent's
    final patch merger, because POINTSV15ChatModel applies its own external
    ``Qwen2VisionPatchMerger`` (``vision_projector``) afterwards.
    """

    def __init__(
        self,
        vision_config: POINTSV15ChatConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            vision_config,
            norm_eps=norm_eps,
            quant_config=quant_config,
            prefix=prefix,
        )

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # patchify: move input onto the tower's device/dtype and project
        # flattened pixel patches into the embedding dimension
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute rotary position embedding from the (t, h, w) grids;
        # duplicate along the last dim to cover the full head dimension
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # compute cu_seqlens: cumulative patch counts marking image
        # boundaries for varlen attention — each image contributes t
        # repetitions of h*w patches; prepend 0 as the first offset
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformer blocks; unsqueeze to (seq_len, 1, hidden) layout
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

        # NOTE: the parent's patch merger is intentionally skipped here;
        # the caller applies its own projector to these raw patch features.
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class POINTSV15ChatModel(nn.Module):
    """POINTS v1.5 chat model: Qwen2-VL vision tower + Qwen2 language model.

    The vision encoder produces raw patch features, a separate
    ``Qwen2VisionPatchMerger`` (``vision_projector``) maps them into the LLM
    embedding space, and ``general_mm_embed_routine`` splices the resulting
    image embeddings into the token stream at image-placeholder positions.
    """

    def __init__(
        self,
        config: POINTSV15ChatConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        # Pin the HF-side attention implementation and disable HF's
        # autoset so it cannot override the choice later.
        config.llm_config._attn_implementation = "flash_attention_2"
        config._attn_implementation_autoset = False
        self.config = config
        self.quant_config = quant_config

        # Deep-copy so mutating `architectures` below does not leak into
        # the caller's config object.
        llm_config = copy.deepcopy(config.llm_config)
        llm_config.architectures = ["Qwen2ForCausalLM"]
        self.llm = Qwen2ForCausalLM(
            config=llm_config,
            quant_config=quant_config,
            prefix=add_prefix("llm", prefix),
        )

        self.vision_encoder = Qwen2VisionTransformerForNavitPOINTS(
            config.vision_config,
            quant_config=quant_config,
            prefix=add_prefix("vision_encoder", prefix),
        )

        # Projects merged patch features to the LLM hidden size.
        # NOTE(review): context_dim=1280 is hard-coded — presumably the
        # Qwen2-VL vision hidden size; confirm against config.vision_config.
        self.vision_projector = Qwen2VisionPatchMerger(
            d_model=config.llm_config.hidden_size,
            context_dim=1280,
            quant_config=quant_config,
            prefix=add_prefix("vision_projector", prefix),
        )

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Expand multimodal placeholder tokens to their padded lengths."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Encode all items' image patches and project into the LLM space."""
        # Concatenate patch features across items and cast to tower dtype.
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.vision_encoder.dtype
        )
        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)

        # Expected ranks: pixel_values (num_patches, patch_dim),
        # image_grid_thw (num_images, 3) — TODO confirm upstream producer.
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert image_grid_thw.dim() == 2, image_grid_thw.dim()

        image_features = self.vision_encoder(pixel_values, grid_thw=image_grid_thw)
        image_features = self.vision_projector(image_features)
        return image_features

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run the LLM with image embeddings spliced in at image tokens."""
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
            },
            positions=positions,
        )

        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, fusing q/k/v and gate/up projections."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        # NOTE(review): loaded_params is never populated or returned below.
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            # Rotary inverse frequencies are recomputed at runtime.
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Non-stacked weight path.
                if "vision_encoder" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                try:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                except KeyError:
                    # Dump known parameter names to aid debugging, then
                    # re-raise the original KeyError.
                    print(params_dict.keys())
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


# Architectures exported to the sglang model registry.
EntryClass = [POINTSV15ChatModel]
|
||||||
52
python/sglang/srt/multimodal/processors/points_v15_chat.py
Normal file
52
python/sglang/srt/multimodal/processors/points_v15_chat.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# Copy from qwen_vl.py, adapted for points-v15-chat
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from sglang.srt.models.points_v15_chat import POINTSV15ChatModel
|
||||||
|
from sglang.srt.multimodal.processors.qwen_vl import (
|
||||||
|
Qwen2_5VLImageProcessor,
|
||||||
|
resize_image_async,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class POINTSV15ChatProcessor(Qwen2_5VLImageProcessor):
    """Multimodal input processor for POINTS v1.5 chat.

    Reuses the Qwen2.5-VL image processor; only the config shim in
    ``__init__`` and the async processing entry point differ.
    """

    models = [POINTSV15ChatModel]

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        # Compatible with POINTSV15Chat: null out the Qwen2.5-VL-specific
        # token ids before the parent reads them from the config.
        hf_config.vision_start_token_id = None
        hf_config.vision_end_token_id = None
        hf_config.video_token_id = None

        super().__init__(hf_config, server_args, _processor, *args, **kwargs)

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        """Load, resize, and tokenize images alongside the prompt.

        Returns a dict with the padded ``input_ids``, the processed
        ``mm_items``, and the image placeholder token id.
        """
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=self.mm_tokens,
        )

        # Resize PIL images concurrently before feature extraction.
        if base_output.images and isinstance(base_output.images[0], Image.Image):
            resize_tasks = [resize_image_async(image) for image in base_output.images]
            base_output.images = await asyncio.gather(*resize_tasks)

        mm_items, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )

        return {
            "input_ids": input_ids.tolist(),
            "mm_items": mm_items,
            "im_token_id": self.mm_tokens.image_token_id,
        }
|
||||||
@@ -960,6 +960,19 @@ register_conv_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Chat template for POINTS v1.5: ChatML-style roles with Qwen2-VL-style
# vision placeholder tokens and an empty system prompt.
register_conv_template(
    Conversation(
        name="points-v15-chat",
        system_message="",
        system_template="",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep="<|im_end|>\n",
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        stop_str=["<|im_end|>"],
        image_token="<|vision_start|><|image_pad|><|vision_end|>",
        video_token="<|vision_start|><|video_pad|><|vision_end|>",
    )
)
|
||||||
|
|
||||||
MODEL_TYPE_TO_TEMPLATE = {
|
MODEL_TYPE_TO_TEMPLATE = {
|
||||||
"internvl_chat": "internvl-2-5",
|
"internvl_chat": "internvl-2-5",
|
||||||
@@ -971,6 +984,12 @@ MODEL_TYPE_TO_TEMPLATE = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@register_conv_template_matching_function
def match_points_v15_chat(model_path: str):
    """Select the POINTS v1.5 chat template for POINTS model paths.

    Returns ``"points-v15-chat"`` when the path names a POINTS model and
    ``None`` (implicitly) otherwise so other matchers can run.

    Fix: the original pattern ``r"points"`` also matched any path
    containing the word "checkpoints" (e.g. ``/data/checkpoints/llama``),
    silently hijacking the chat template for unrelated models. A negative
    lookbehind excludes that case while still matching every genuine
    POINTS model path.
    """
    if re.search(r"(?<!check)points", model_path, re.IGNORECASE):
        return "points-v15-chat"
|
||||||
|
|
||||||
|
|
||||||
def get_model_type(model_path: str) -> Optional[str]:
|
def get_model_type(model_path: str) -> Optional[str]:
|
||||||
config_path = os.path.join(model_path, "config.json")
|
config_path = os.path.join(model_path, "config.json")
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
|
|||||||
@@ -111,6 +111,12 @@ def get_hf_text_config(config: PretrainedConfig):
|
|||||||
# if transformers config doesn't align with this assumption.
|
# if transformers config doesn't align with this assumption.
|
||||||
assert hasattr(config.text_config, "num_attention_heads")
|
assert hasattr(config.text_config, "num_attention_heads")
|
||||||
return config.text_config
|
return config.text_config
|
||||||
|
|
||||||
|
if hasattr(config, "llm_config"):
|
||||||
|
# PointsV1.5 Chat Model
|
||||||
|
assert hasattr(config.llm_config, "num_attention_heads")
|
||||||
|
return config.llm_config
|
||||||
|
|
||||||
if hasattr(config, "language_config"):
|
if hasattr(config, "language_config"):
|
||||||
return config.language_config
|
return config.language_config
|
||||||
if hasattr(config, "thinker_config"):
|
if hasattr(config, "thinker_config"):
|
||||||
|
|||||||
Reference in New Issue
Block a user