[model] Support POINTSV15Chat (#9651)
Co-authored-by: josephyou <josephyou@tencent.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: root <root@TENCENT64.site>
@@ -917,6 +917,7 @@ multimodal_model_archs = [
     "Phi4MMForCausalLM",
     "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
+    "POINTSV15ChatModel",
     "DotsVLMForCausalLM",
     "DotsOCRForCausalLM",
     "Sarashina2VisionForCausalLM",

python/sglang/srt/configs/points_v15_chat.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from typing import Optional, Union

from transformers import PretrainedConfig, Qwen2Config
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig


class POINTSV15ChatConfig(PretrainedConfig):
    model_type = "pointsv1.5_chat"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Qwen2VLVisionConfig]] = None,
        llm_config: Optional[Union[dict, Qwen2Config]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if vision_config is None:
            vision_config = Qwen2VLVisionConfig()
        elif isinstance(vision_config, dict):
            vision_config = Qwen2VLVisionConfig(**vision_config)
        self.vision_config = vision_config

        if llm_config is None:
            llm_config = Qwen2Config()
        elif isinstance(llm_config, dict):
            llm_config = Qwen2Config(**llm_config)

        self.llm_config = llm_config
        self.hidden_size = self.llm_config.hidden_size
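
A quick hand-written sketch of constructing this config (the nested values are placeholders, not numbers from a real POINTS-V1.5 checkpoint):

    from sglang.srt.configs.points_v15_chat import POINTSV15ChatConfig

    cfg = POINTSV15ChatConfig(
        vision_config={"depth": 32, "embed_dim": 1280},  # coerced into Qwen2VLVisionConfig
        llm_config={"hidden_size": 4096, "num_attention_heads": 32},  # coerced into Qwen2Config
    )
    assert cfg.hidden_size == cfg.llm_config.hidden_size  # mirrored for generic callers
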
python/sglang/srt/models/points_v15_chat.py (new file, 186 lines)
@@ -0,0 +1,186 @@
import copy
from typing import Iterable, List, Optional, Set, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from sglang.srt.configs.points_v15_chat import POINTSV15ChatConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.qwen2_vl import Qwen2VisionPatchMerger, Qwen2VisionTransformer
from sglang.srt.utils import add_prefix


class Qwen2VisionTransformerForNavitPOINTS(Qwen2VisionTransformer):
    def __init__(
        self,
        vision_config: POINTSV15ChatConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            vision_config,
            norm_eps=norm_eps,
            quant_config=quant_config,
            prefix=prefix,
        )

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformer blocks
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

        return x
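
To make the cu_seqlens step concrete, a standalone worked example with invented grid values: one 4x6-patch image plus a two-frame 4x4-patch video gives per-frame lengths h*w repeated t times, and the padded cumulative sum yields the flattened attention-window boundaries:

    import torch
    import torch.nn.functional as F

    grid_thw = torch.tensor([[1, 4, 6], [2, 4, 4]])  # rows are (t, h, w)
    seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
    cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), "constant", 0)
    print(cu_seqlens.tolist())  # [0, 24, 40, 56]
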
class POINTSV15ChatModel(nn.Module):
    def __init__(
        self,
        config: POINTSV15ChatConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        config.llm_config._attn_implementation = "flash_attention_2"
        config._attn_implementation_autoset = False
        self.config = config
        self.quant_config = quant_config

        llm_config = copy.deepcopy(config.llm_config)
        llm_config.architectures = ["Qwen2ForCausalLM"]
        self.llm = Qwen2ForCausalLM(
            config=llm_config,
            quant_config=quant_config,
            prefix=add_prefix("llm", prefix),
        )

        self.vision_encoder = Qwen2VisionTransformerForNavitPOINTS(
            config.vision_config,
            quant_config=quant_config,
            prefix=add_prefix("vision_encoder", prefix),
        )

        self.vision_projector = Qwen2VisionPatchMerger(
            d_model=config.llm_config.hidden_size,
            context_dim=1280,
            quant_config=quant_config,
            prefix=add_prefix("vision_projector", prefix),
        )

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.vision_encoder.dtype
        )
        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)

        assert pixel_values.dim() == 2, pixel_values.dim()
        assert image_grid_thw.dim() == 2, image_grid_thw.dim()

        image_features = self.vision_encoder(pixel_values, grid_thw=image_grid_thw)
        image_features = self.vision_projector(image_features)
        return image_features

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
            },
            positions=positions,
        )

        return hidden_states
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "vision_encoder" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                try:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                except KeyError:
                    print(params_dict.keys())
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = [POINTSV15ChatModel]
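
For readers unfamiliar with the fused-projection loading pattern, a small standalone sketch of the rename step (the checkpoint name is invented): per-projection HF weights are redirected into the fused qkv_proj parameter, and shard_id tells the weight loader which slice of the fused tensor to fill.

    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]
    name = "llm.model.layers.0.self_attn.k_proj.weight"  # invented example name
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            print(name.replace(weight_name, param_name), shard_id)
            # -> llm.model.layers.0.self_attn.qkv_proj.weight k
            break
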
python/sglang/srt/multimodal/processors/points_v15_chat.py (new file, 52 lines)
@@ -0,0 +1,52 @@
# Copied from qwen_vl.py and adapted for points-v15-chat

import asyncio
from typing import List, Union

from PIL import Image

from sglang.srt.models.points_v15_chat import POINTSV15ChatModel
from sglang.srt.multimodal.processors.qwen_vl import (
    Qwen2_5VLImageProcessor,
    resize_image_async,
)


class POINTSV15ChatProcessor(Qwen2_5VLImageProcessor):
    models = [POINTSV15ChatModel]

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        # Compatible with POINTSV15Chat
        hf_config.vision_start_token_id = None
        hf_config.vision_end_token_id = None
        hf_config.video_token_id = None

        super().__init__(hf_config, server_args, _processor, *args, **kwargs)

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=self.mm_tokens,
        )

        if base_output.images and isinstance(base_output.images[0], Image.Image):
            resize_tasks = [resize_image_async(image) for image in base_output.images]
            base_output.images = await asyncio.gather(*resize_tasks)

        mm_items, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )

        return {
            "input_ids": input_ids.tolist(),
            "mm_items": mm_items,
            "im_token_id": self.mm_tokens.image_token_id,
        }
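
One detail worth noting in process_mm_data_async: image resizing is fanned out with asyncio.gather rather than awaited one image at a time. A standalone sketch of the same pattern, reusing the resize_image_async helper imported above (the inline images are synthetic placeholders):

    import asyncio
    from PIL import Image
    from sglang.srt.multimodal.processors.qwen_vl import resize_image_async

    async def resize_all(images):
        tasks = [resize_image_async(img) for img in images]  # schedule every resize
        return await asyncio.gather(*tasks)  # run the batch concurrently

    resized = asyncio.run(resize_all([Image.new("RGB", (640, 480)) for _ in range(3)]))
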
@@ -960,6 +960,19 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="points-v15-chat",
+        system_message="",
+        system_template="",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=["<|im_end|>"],
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+        video_token="<|vision_start|><|video_pad|><|vision_end|>",
+    )
+)

 MODEL_TYPE_TO_TEMPLATE = {
     "internvl_chat": "internvl-2-5",
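
To show what this template produces, here is a hand-rendered single-turn prompt under the fields above (my own rendering of the ADD_NEW_LINE_SINGLE style, not output captured from SGLang): each role tag is followed by a newline, turns end with the <|im_end|> separator, and generation stops on <|im_end|>.

    <|im_start|>user
    <|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>
    <|im_start|>assistant
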
@@ -971,6 +984,12 @@ MODEL_TYPE_TO_TEMPLATE = {
 }

+
+@register_conv_template_matching_function
+def match_points_v15_chat(model_path: str):
+    if re.search(r"points", model_path, re.IGNORECASE):
+        return "points-v15-chat"
+

 def get_model_type(model_path: str) -> Optional[str]:
     config_path = os.path.join(model_path, "config.json")
     if not os.path.exists(config_path):
@@ -111,6 +111,12 @@ def get_hf_text_config(config: PretrainedConfig):
         # if transformers config doesn't align with this assumption.
         assert hasattr(config.text_config, "num_attention_heads")
         return config.text_config
+
+    if hasattr(config, "llm_config"):
+        # PointsV1.5 Chat Model
+        assert hasattr(config.llm_config, "num_attention_heads")
+        return config.llm_config
+
     if hasattr(config, "language_config"):
         return config.language_config
     if hasattr(config, "thinker_config"):
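
A hand-written check of what this branch buys (not code from this commit; the module path of get_hf_text_config is an assumption): without it, a POINTS checkpoint would fall past the text_config branch even though its usable text config lives under llm_config.

    # Hypothetical check; the import path of get_hf_text_config is assumed.
    from sglang.srt.configs.points_v15_chat import POINTSV15ChatConfig
    from sglang.srt.hf_transformers_utils import get_hf_text_config

    cfg = POINTSV15ChatConfig()
    text_cfg = get_hf_text_config(cfg)
    assert text_cfg is cfg.llm_config
    assert hasattr(text_cfg, "num_attention_heads")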