[model] support MiniCPM-V 4.0 (#8747)

Signed-off-by: tc-mb <caitianchi@modelbest.cn>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
tc-mb
2025-09-03 06:33:03 +08:00
committed by GitHub
parent 11dcabc545
commit 03dbf1aa8e
4 changed files with 246 additions and 6 deletions

View File

@@ -54,6 +54,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.idefics2 import Idefics2VisionTransformer
from sglang.srt.models.llama import LlamaConfig, LlamaForCausalLM
from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from sglang.srt.utils import add_prefix, flatten_nested_list
@@ -581,7 +582,7 @@ class MiniCPMBaseModel(nn.Module):
def init_llm(
self,
config: Qwen2Config,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
@@ -774,7 +775,168 @@ class MiniCPMV2_6(MiniCPMBaseModel):
return pattern.pad_input_tokens(input_ids, image_inputs)
_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
class MiniCPMV4_0(MiniCPMBaseModel):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
# resampler
"kv_proj",
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
assert self.version == (4, 0)
def init_llm(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
return LlamaForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
model = Idefics2VisionTransformer(
config=config.vision_config, quant_config=quant_config, prefix=prefix
)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
setattr(model, "embed_dim", model.embeddings.embed_dim)
setattr(model, "patch_size", model.embeddings.patch_size)
return model
def init_resampler(
self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
with set_default_torch_dtype(torch.float16):
# The resampler in 2.6 remains consistent with the one in 2.5.
resampler = Resampler2_5(
num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
quant_config=quant_config,
prefix=prefix,
)
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
vision_embedding = self.vpm(
pixel_values,
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
)
return vision_embedding
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# list of tensors
pixel_values = flatten_nested_list([item.feature for item in items])
tgt_sizes = torch.stack(
flatten_nested_list([item.tgt_size for item in items]), dim=0
)
assert len(pixel_values) == tgt_sizes.shape[0]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
]
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0
)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros(
(B, 1, max_patches), dtype=torch.bool, device=device
)
tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
patch_attn_mask[:, 0, :] = torch.arange(
patch_attn_mask.size(2), device=patch_attn_mask.device
).unsqueeze(0) < mask_shapes.unsqueeze(1)
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
)
return self.resampler(vision_embedding, tgt_sizes)
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
# Get all special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
slice_start_id: int = image_inputs.slice_start_id
slice_end_id: int = image_inputs.slice_end_id
media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, image_inputs)
_SUPPORT_VERSION = {
(2, 6): MiniCPMV2_6,
(4, 0): MiniCPMV4_0,
}
class MiniCPMV:
@@ -809,7 +971,7 @@ class MiniCPMV:
# Dispatch class based on version
instance_class = _SUPPORT_VERSION.get(version)
if instance_class is None:
raise ValueError("Currently, MiniCPMV only supports versions 2.6")
raise ValueError("Currently, MiniCPMV only supports versions 2.6 and 4.0")
try:
minicpmv = instance_class(