diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md
index 1cc7b8747..860841816 100644
--- a/docs/references/supported_models.md
+++ b/docs/references/supported_models.md
@@ -24,7 +24,7 @@
- InternLM 2
- Exaone 3
- BaiChuan2
-- MiniCPM / MiniCPM 3
+- MiniCPM / MiniCPM 3 / MiniCPM-V
- XVERSE / XVERSE MoE
- SmolLM
- GLM-4
diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index 4a774c4fb..845e1e52d 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -88,7 +88,6 @@ register_chat_template(
)
)
-
register_chat_template(
ChatTemplate(
name="claude",
@@ -101,7 +100,6 @@ register_chat_template(
)
)
-
register_chat_template(
ChatTemplate(
name="chatml",
@@ -116,7 +114,6 @@ register_chat_template(
)
)
-
register_chat_template(
ChatTemplate(
name="chatml-llava",
@@ -132,7 +129,6 @@ register_chat_template(
)
)
-
# There is default system prompt for qwen
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
@@ -219,6 +215,21 @@ register_chat_template(
)
)
+# https://huggingface.co/openbmb/MiniCPM-V-2_6
+register_chat_template(
+ ChatTemplate(
+ name="minicpmv",
+ default_system_prompt=None,
+ role_prefix_and_suffix={
+ "system": ("", " "),
+ "user": ("user:", " "),
+ "assistant": ("assistant:", ""),
+ },
+ stop_str=("<|im_end|>", "<|endoftext|>"),
+ image_token="(./)",
+ )
+)
+
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
ChatTemplate(
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index d087a2f23..6d144f844 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -402,6 +402,7 @@ def is_multimodal_model(model_architectures: List[str]):
or "LlavaVidForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures
+ or "MiniCPMV" in model_architectures
):
return True
else:
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index 60dba87cb..3a775aa1e 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -452,7 +452,6 @@ def generate_chat_conv(
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
-
return conv
@@ -555,3 +554,17 @@ register_conv_template(
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)
+
+# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
+register_conv_template(
+ Conversation(
+ name="minicpmv",
+ system_message="You are a helpful assistant",
+ system_template="<|im_start|>system\n{system_message}.",
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
+ sep="<|im_end|>\n",
+ sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+ stop_str=("<|im_end|>", "<|endoftext|>"),
+ image_token="(./)",
+ )
+)
diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py
new file mode 100644
index 000000000..f66456b04
--- /dev/null
+++ b/python/sglang/srt/layers/attention/vision.py
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from vllm.distributed import parallel_state
+from vllm.distributed import utils as dist_utils
+
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+ context_attention_fwd,
+)
+from sglang.srt.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from sglang.srt.layers.quantization import QuantizationConfig
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+ if not interleaved:
+ x1, x2 = x.chunk(2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+ else:
+ x1, x2 = x[..., ::2], x[..., 1::2]
+ return rearrange(
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+ )
+
+
+def apply_rotary_emb_torch(
+ x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
+) -> torch.Tensor:
+ """
+ x: (batch_size, seqlen, nheads, headdim)
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+ """
+ ro_dim = cos.shape[-1] * 2
+ assert ro_dim <= x.shape[-1]
+ cos = repeat(
+ cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+ )
+ sin = repeat(
+ sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+ )
+ return torch.cat(
+ [
+ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+ x[..., ro_dim:],
+ ],
+ dim=-1,
+ )
+
+
+def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+ t_ = t.float()
+ cos = freqs.cos()
+ sin = freqs.sin()
+ output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+ return output
+
+
+class VisionAttention(nn.Module):
+ """Multi-headed attention without any cache, mostly used for ViT."""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ projection_size: int,
+ use_qkv_parallel: bool,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ world_size = parallel_state.get_tensor_model_parallel_world_size()
+
+ self.hidden_size_per_attention_head = dist_utils.divide(
+ projection_size, num_heads
+ )
+ self.num_attention_heads_per_partition = dist_utils.divide(
+ num_heads, world_size
+ )
+ # self.tp_size = get_tensor_model_parallel_world_size()
+ # num_heads = self.num_heads_per_partition
+ self.use_qkv_parallel = use_qkv_parallel
+ if use_qkv_parallel:
+ self.head_dim = embed_dim // num_heads
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size=embed_dim,
+ head_size=self.head_dim,
+ total_num_heads=num_heads,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+ else:
+ self.qkv_proj = ColumnParallelLinear(
+ input_size=embed_dim,
+ output_size=3 * projection_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+ self.proj = RowParallelLinear(
+ input_size=embed_dim,
+ output_size=embed_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.out_proj",
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ rotary_pos_emb: torch.Tensor = None,
+ ) -> torch.Tensor:
+ """
+ Input shape: [b, s, embed_dim]
+        Output shape: [b, s, num_heads * head_size]
+ """
+
+ bsz, s, _ = x.shape
+ if self.use_qkv_parallel:
+ # [b, s, embed_dim] --> [b, s, embed_dim]
+ qkv, _ = self.qkv_proj(x)
+ q, k, v = qkv.chunk(3, dim=-1)
+
+ # [b, s, embed_dim] --> [b * s, num_heads, head_size]
+ q, k, v = [
+ x.reshape(
+ bsz * s, self.num_attention_heads_per_partition, -1
+ ).contiguous()
+ for x in (q, k, v)
+ ]
+ else:
+ # [b, s, embed_dim] --> [s, b, embed_dim]
+ x = rearrange(x, "b s ... -> s b ...")
+ # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+ qkv, _ = self.qkv_proj(x)
+ # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+ new_x_shape = qkv.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ qkv = qkv.view(*new_x_shape)
+
+ # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+ q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+
+ # [s, b, head, head_dim] --> [b, s, head, head_dim]
+ q, k, v = [
+ rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+ ]
+
+ if rotary_pos_emb is not None:
+ q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+ k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+ if self.use_qkv_parallel:
+ pass
+ else:
+ # [b, s, head, head_dim] --> [b * s, head, head_dim]
+ q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+
+ # [b * s, num_heads, head_size]
+ output = torch.empty_like(q)
+
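+        # cu_seqlens holds the cumulative (prefix-summed) sequence lengths, so the
+        # adjacent differences below recover the per-sequence lengths for the
+        # variable-length attention kernel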
+ seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
+ max_seqlen = seq_lens.max().item()
+
+ context_attention_fwd(
+ q,
+ k,
+ v,
+ output,
+ cu_seqlens.cuda(),
+ seq_lens,
+ max_seqlen,
+ is_causal=False,
+ )
+
+ if self.use_qkv_parallel:
+
+ # [b * s, head, head_dim] --> [b, s, head * head_dim]
+ output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
+
+            # [b, s, head * head_dim] --> [b, s, embed_dim]
+ output, _ = self.proj(output)
+ else:
+ # [b * s, head, head_dim] --> [b, s, head, head_dim]
+ context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+
+ # [s, b, num_heads * head_size]
+ context_layer = rearrange(
+ context_layer, "b s h d -> s b (h d)"
+ ).contiguous()
+
+ # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+ output, _ = self.proj(context_layer)
+
+ output = output.view(bsz, s, -1)
+
+ return output
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index a4fe49051..10f264677 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -127,7 +127,7 @@ class LogitsProcessor(nn.Module):
hidden_states,
lm_head: VocabParallelEmbedding,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
- ):
+ ) -> LogitsProcessorOutput:
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 4f57ac5b2..3b959b1ba 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -56,6 +56,7 @@ class DataParallelController:
def __init__(self, server_args, port_args) -> None:
# Parse args
+ self.max_total_num_tokens = None
self.server_args = server_args
self.port_args = port_args
self.load_balance_method = LoadBalanceMethod.from_str(
@@ -96,6 +97,8 @@ class DataParallelController:
True,
)
+ self.max_req_input_len = None
+
def launch_dp_schedulers(self, server_args, port_args):
base_gpu_id = 0
@@ -189,6 +192,7 @@ class DataParallelController:
scheduler_info.append(scheduler_pipe_readers[i].recv())
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
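+        # Reported in the "ready" message below so multimodal image processors can
+        # bound the size of visual inputs (e.g. the number of video frames)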
+ self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
@@ -231,7 +235,11 @@ def run_data_parallel_controller_process(
try:
controller = DataParallelController(server_args, port_args)
pipe_writer.send(
- {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
+ {
+ "status": "ready",
+ "max_total_num_tokens": controller.max_total_num_tokens,
+ "max_req_input_len": controller.max_req_input_len,
+ }
)
if server_args.node_rank == 0:
controller.event_loop()
diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py
index 7120fa48d..c8ebbed78 100644
--- a/python/sglang/srt/managers/image_processor.py
+++ b/python/sglang/srt/managers/image_processor.py
@@ -9,6 +9,8 @@ from typing import List, Optional, Union
import numpy as np
import transformers
+from decord import VideoReader, cpu
+from PIL import Image
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
+ self.server_args = server_args
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
@@ -126,7 +129,12 @@ class LlavaImageProcessor(BaseImageProcessor):
)
async def process_images_async(
- self, image_data: List[Union[str, bytes]], input_text, request_obj
+ self,
+ image_data: List[Union[str, bytes]],
+ input_text,
+ request_obj,
+ *args,
+ **kwargs,
):
if not image_data:
return None
@@ -229,6 +237,147 @@ class MllamaImageProcessor(BaseImageProcessor):
return image_inputs
+class MiniCPMVImageProcessor(BaseImageProcessor):
+ def __init__(self, hf_config, server_args, _processor):
+ super().__init__(hf_config, server_args, _processor)
+
+ @staticmethod
+ def _process_images_task(images, input_text):
+ result = global_processor.__call__(
+ text=input_text, images=images, return_tensors="pt"
+ )
+ return {
+ "input_ids": result["input_ids"],
+ "pixel_values": result["pixel_values"],
+ "tgt_sizes": result["tgt_sizes"],
+ }
+
+ async def _process_images(self, images, input_text):
+ if self.executor is not None:
+ loop = asyncio.get_event_loop()
+ image_inputs = await loop.run_in_executor(
+ self.executor,
+ MiniCPMVImageProcessor._process_images_task,
+ images,
+ input_text,
+ )
+ else:
+ image_inputs = self._processor(
+ images=images, text=input_text, return_tensors="pt"
+ )
+
+ return image_inputs
+
+ async def process_images_async(
+ self,
+ image_data: List[Union[str, bytes]],
+ input_text,
+ request_obj,
+ max_req_input_len,
+ ):
+ if not image_data:
+ return None
+
+ if not isinstance(image_data, list):
+ image_data = [image_data]
+
+ image_hashes, image_sizes = [], []
+ raw_images = []
+ IMAGE_TOKEN = "(./)"
+
+ # roughly calculate the max number of frames
+ # TODO: the process should be applied to all the visual inputs
+ def calculate_max_num_frames() -> int:
+ # Model-specific
+ NUM_TOKEN_PER_FRAME = 330
+
+ ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
+ return min(ret, 100)
+
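+        # e.g. with max_req_input_len around 4k tokens and a short text prompt,
+        # this allows roughly a dozen frames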
+ # if cuda OOM set a smaller number
+ MAX_NUM_FRAMES = calculate_max_num_frames()
+        logger.debug(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+ def encode_video(video_path):
+ if not os.path.exists(video_path):
+ logger.error(f"Video {video_path} does not exist")
+ return []
+
+ if MAX_NUM_FRAMES == 0:
+ return []
+
+ def uniform_sample(l, n):
+ gap = len(l) / n
+ idxs = [int(i * gap + gap / 2) for i in range(n)]
+ return [l[i] for i in idxs]
+
+ vr = VideoReader(video_path, ctx=cpu(0))
+ sample_fps = round(vr.get_avg_fps() / 1) # FPS
+ frame_idx = [i for i in range(0, len(vr), sample_fps)]
+ if len(frame_idx) > MAX_NUM_FRAMES:
+ frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+ frames = vr.get_batch(frame_idx).asnumpy()
+ frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+ return frames
+
+ if isinstance(input_text, list):
+ assert len(input_text) and isinstance(input_text[0], int)
+ input_text = self._processor.tokenizer.decode(input_text)
+
+ # MiniCPMV requires each frame of video as a single image token
+ text_parts = input_text.split(IMAGE_TOKEN)
+ new_text_parts = []
+
+ for image_index, image in enumerate(image_data):
+ try:
+ if isinstance(image, str) and image.startswith("video:"):
+ path = image[len("video:") :]
+ frames = encode_video(path)
+ else:
+ raw_image, size = load_image(image)
+ frames = [raw_image]
+ if len(frames) == 0:
+ continue
+ except FileNotFoundError as e:
+                logger.error(e)
+ return None
+
+            image_sizes += [frames[0].size] * len(frames)
+ image_hashes += [hash(image)] * len(frames)
+ raw_images += frames
+ new_text_parts.append(text_parts[image_index])
+ new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+ new_text_parts.append(text_parts[-1])
+ input_text = "".join(new_text_parts)
+ if len(raw_images) == 0:
+ return None
+ res = await self._process_images(images=raw_images, input_text=input_text)
+ pixel_values = res["pixel_values"]
+ tgt_sizes = res["tgt_sizes"]
+ input_ids = res["input_ids"]
+
+ # Collect special token ids
+ tokenizer = self._processor.tokenizer
+ im_start_id = [tokenizer.im_start_id]
+ im_end_id = [tokenizer.im_end_id]
+        # Default to None so the return dict below is always well-defined
+        slice_start_id, slice_end_id = None, None
+        if tokenizer.slice_start_id:
+ slice_start_id = [tokenizer.slice_start_id]
+ slice_end_id = [tokenizer.slice_end_id]
+
+ return {
+ "input_ids": input_ids.flatten().tolist(),
+ "pixel_values": pixel_values,
+ "tgt_sizes": tgt_sizes,
+ "image_hashes": image_hashes,
+ "modalities": request_obj.modalities or ["image"],
+ "im_start_id": im_start_id,
+ "im_end_id": im_end_id,
+ "slice_start_id": slice_start_id,
+ "slice_end_id": slice_end_id,
+ }
+
+
class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
@@ -289,7 +438,12 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
return self._process_single_image_task(image_data)
async def process_images_async(
- self, image_data: List[Union[str, bytes]], input_text, request_obj
+ self,
+ image_data: List[Union[str, bytes]],
+ input_text,
+ request_obj,
+ *args,
+ **kwargs,
):
if not image_data:
return None
@@ -350,6 +504,8 @@ def get_image_processor(
return MllamaImageProcessor(hf_config, server_args, processor)
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
+ elif "MiniCPMV" in hf_config.architectures:
+ return MiniCPMVImageProcessor(hf_config, server_args, processor)
else:
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 6c3800785..faf05a7ff 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -52,7 +52,6 @@ from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
-
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
@@ -68,7 +67,6 @@ global_server_args_dict = {
"device": ServerArgs.device,
}
-
logger = logging.getLogger(__name__)
@@ -149,6 +147,16 @@ class ImageInputs:
image_grid_thws: List[Tuple[int, int, int]] = None
mrope_position_delta: Optional[torch.Tensor] = None
+ # MiniCPMV related
+ # All the images in the batch should share the same special image
+ # bound token ids.
+ im_start_id: Optional[torch.Tensor] = None
+ im_end_id: Optional[torch.Tensor] = None
+ slice_start_id: Optional[torch.Tensor] = None
+ slice_end_id: Optional[torch.Tensor] = None
+
+ tgt_sizes: Optional[list] = None
+
@staticmethod
def from_dict(obj: dict):
ret = ImageInputs(
@@ -168,6 +176,11 @@ class ImageInputs:
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
+ "im_start_id",
+ "im_end_id",
+ "slice_start_id",
+ "slice_end_id",
+ "tgt_sizes",
]
for arg in optional_args:
if arg in obj:
@@ -1140,7 +1153,6 @@ class ScheduleBatch:
global bid
bid += 1
-
return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index bc963e008..5ed0fde34 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -274,7 +274,6 @@ class Scheduler:
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
-
# Print debug info
logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, "
@@ -1729,7 +1728,11 @@ def run_scheduler_process(
try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
pipe_writer.send(
- {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+ {
+ "status": "ready",
+ "max_total_num_tokens": scheduler.max_total_num_tokens,
+ "max_req_input_len": scheduler.max_req_input_len,
+ }
)
if scheduler.enable_overlap:
scheduler.event_loop_overlap()
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 18ac7503c..9dcc986d9 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -112,6 +112,7 @@ class TokenizerManager:
port_args: PortArgs,
):
# Parse args
+
self.server_args = server_args
self.enable_metrics = server_args.enable_metrics
self.log_requests = server_args.log_requests
@@ -207,6 +208,8 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ # Set after scheduler is initialized
+ self.max_req_input_len = None
# Metrics
if self.enable_metrics:
@@ -281,7 +284,7 @@ class TokenizerManager:
if self.is_generation:
# TODO: also support getting embeddings for multimodal models
image_inputs: Dict = await self.image_processor.process_images_async(
- obj.image_data, input_text or input_ids, obj
+ obj.image_data, input_text or input_ids, obj, self.max_req_input_len
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 261914694..bca4711eb 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -237,7 +237,7 @@ class ModelRunner:
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
if not self.is_draft_worker:
- # Only initilzie the distributed environment on the target model worker.
+ # Only initialize the distributed environment on the target model worker.
init_distributed_environment(
backend=backend,
world_size=self.tp_size,
diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py
new file mode 100644
index 000000000..5ff941b6c
--- /dev/null
+++ b/python/sglang/srt/models/minicpmv.py
@@ -0,0 +1,1238 @@
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
+from functools import cached_property, partial
+from typing import (
+ Any,
+ Callable,
+ Iterable,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ TypedDict,
+ Union,
+)
+
+import torch
+import torch.types
+from PIL import Image
+from torch import nn
+from torch.nn.init import trunc_normal_
+from transformers import PretrainedConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import (
+ ColumnParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.utils import set_default_torch_dtype
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
+
+RawImageType = Union[Image.Image, torch.Tensor]
+
+
+class Idefics2VisionMLP(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.activation_fn = get_act_fn(config.hidden_act)
+ self.fc1 = ColumnParallelLinear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc1",
+ )
+ self.fc2 = RowParallelLinear(
+ config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc2",
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states, _ = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states, _ = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Idefics2EncoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.embed_dim = config.hidden_size
+
+ self.num_heads = config.num_attention_heads
+ tp_size = get_tensor_model_parallel_world_size()
+ num_heads_per_partition = divide(self.num_heads, tp_size)
+ self.self_attn = VisionAttention(
+ embed_dim=config.hidden_size,
+ num_heads=num_heads_per_partition,
+ projection_size=config.intermediate_size,
+ use_qkv_parallel=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
+ )
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = Idefics2VisionMLP(config, quant_config=quant_config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+
+ """
+ residual = hidden_states
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ # , forward_batch=forward_batch
+ )
+ hidden_states = residual + hidden_states
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Idefics2Encoder(nn.Module):
+ """
+    Transformer encoder consisting of `config.num_hidden_layers` self-attention
+    layers. Each layer is an [`Idefics2EncoderLayer`].
+
+ Args:
+ config: Idefics2Config
+ """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ Idefics2EncoderLayer(
+ config,
+ quant_config=quant_config,
+ )
+ for _ in range(config.num_hidden_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> torch.Tensor:
+ r"""
+ Args:
+ inputs_embeds (torch.Tensor):
+ Optionally, instead of passing `input_ids` you can choose to
+ directly pass an embedded representation.
+ This is useful if you want more control over how to convert
+                `input_ids` indices into associated vectors than the model's
+ internal embedding lookup matrix.
+ """
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(
+ hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
+ )
+ hidden_states = layer_outputs
+ return hidden_states
+
+
+class Idefics2VisionEmbeddings(nn.Module):
+ """
+    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings`
+    to enable images of variable resolution.
+
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision
+ Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
+ which allows treating images in their native aspect ratio and without the
+ need to resize them to the same fixed size. In particular, we start from the
+    original pre-trained SigLIP model (which uses fixed-size square images) and
+    adapt it by training on images of variable resolutions.
+ """
+
+ def __init__(self, config: PretrainedConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ patch_attention_mask: torch.BoolTensor,
+ tgt_sizes: Optional[torch.IntTensor] = None,
+ ) -> torch.Tensor:
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+ target_dtype = self.patch_embedding.weight.dtype
+ pixel_values = pixel_values.to(
+ device=self.patch_embedding.weight.device, dtype=target_dtype
+ )
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+ max_nb_patches_h, max_nb_patches_w = (
+ max_im_h // self.patch_size,
+ max_im_w // self.patch_size,
+ )
+ boundaries = torch.arange(
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
+ )
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
+ )
+
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+
+ if tgt_sizes is not None:
+ nb_patches_h = tgt_sizes[batch_idx][0]
+ nb_patches_w = tgt_sizes[batch_idx][1]
+ else:
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+ bucket_coords_h = torch.bucketize(
+ fractional_coords_h, boundaries, right=True
+ )
+ bucket_coords_w = torch.bucketize(
+ fractional_coords_w, boundaries, right=True
+ )
+ pos_ids = (
+ bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
+ ).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+ position_ids = position_ids.to(self.position_embedding.weight.device)
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+class Idefics2VisionTransformer(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ embed_dim = config.hidden_size
+ self.config = config
+ self.embeddings = Idefics2VisionEmbeddings(config)
+ self.encoder = Idefics2Encoder(config=config, quant_config=quant_config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
+ patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,)
+
+        # Prefix-sum patch_len to get cu_seqlens; note that a leading 0 is inserted as the offset
+ cu_seqlens = torch.cat(
+ [
+ torch.tensor([0], device=patch_len.device, dtype=torch.int32),
+ torch.cumsum(patch_len, dim=0, dtype=torch.int32),
+ ],
+ dim=0,
+ ).to(tgt_sizes.device)
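+        # e.g. tgt_sizes [[2, 3], [2, 2]] -> patch_len [6, 4] -> cu_seqlens [0, 6, 10]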
+ return cu_seqlens
+
+ def forward(
+ self,
+ pixel_values,
+ forward_batch: ForwardBatch,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ tgt_sizes: Optional[torch.IntTensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.embeddings(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ # forward_batch=forward_batch,
+ tgt_sizes=tgt_sizes,
+ )
+ cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
+ encoder_outputs = self.encoder(
+ hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
+ )
+ last_hidden_state = self.post_layernorm(encoder_outputs)
+ return last_hidden_state
+
+
+class MiniCPMVImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ data: List[torch.Tensor]
+ """
+ Shape: `(batch_size * num_images, num_channels, height, width)`
+
+ Note that the image size may vary, so we pass it as a list
+ instead of a batched tensor.
+ """
+
+ image_bounds: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, 2)`
+
+ This should be in `(start, stop)` format.
+ """
+
+ tgt_sizes: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, 2)`
+
+ This should be in `(height, width)` format.
+ """
+
+
+class MiniCPMVImageEmbeddingInputs(TypedDict):
+ type: Literal["image_embeds"]
+ data: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of the language model backbone.
+ """
+
+ image_bounds: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, 2)`
+
+ This should be in `(start, stop)` format.
+ """
+
+
+MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs]
+
+DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
+
+
+class BaseResampler(nn.Module):
+ """
+    A 2D perceiver-resampler network with one cross-attention layer, using
+    (grid_size**2) learnable queries and 2D sincos pos_emb.
+ Outputs:
+ A tensor with the shape of (grid_size**2, embed_dim)
+ """
+
+ def __init__(
+ self,
+ num_queries: int,
+ embed_dim: int,
+ num_heads: int,
+ kv_dim: Optional[int] = None,
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+ do_post_projection: bool = True,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ self.num_queries = num_queries
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+
+ self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
+ trunc_normal_(self.query, std=0.02)
+ if kv_dim is not None and kv_dim != embed_dim:
+ self.kv_proj = ReplicatedLinear(
+ kv_dim,
+ embed_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.kv_proj",
+ )
+ else:
+ # Maintain the same return value with ReplicatedLinear.forward
+ self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
+ nn.Identity()(*args, **kwargs),
+ None,
+ )
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
+ self.ln_q = norm_layer(embed_dim)
+ self.ln_kv = norm_layer(embed_dim)
+ self.do_post_projection = do_post_projection
+ self.ln_post = norm_layer(embed_dim) if do_post_projection else None
+ self.proj = (
+ nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
+ if do_post_projection
+ else None
+ )
+
+ def _init_weights(self, m: nn.Module) -> None:
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def _repeat(self, query, N: int):
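+        # (num_queries, embed_dim) -> (num_queries, N, embed_dim): replicate the
+        # learned queries for each item in the batch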
+ return query.unsqueeze(1).repeat(1, N, 1)
+
+
+class Resampler2_5(BaseResampler):
+
+ def __init__(
+ self,
+ num_queries: int,
+ embed_dim: int,
+ num_heads: int,
+ kv_dim: Optional[int] = None,
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+ max_size: Tuple[int, int] = (70, 70),
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__(
+ num_queries,
+ embed_dim,
+ num_heads,
+ kv_dim,
+ norm_layer,
+ quant_config=quant_config,
+ prefix=prefix,
+ )
+
+ self.max_size = max_size
+ self._set_2d_pos_cache(self.max_size)
+
+ self.apply(self._init_weights)
+
+ def _set_2d_pos_cache(
+ self, max_size: Tuple[int, int], device: torch.types.Device = "cpu"
+ ) -> None:
+ pos_embed_arr = get_2d_sincos_pos_embed(
+ self.embed_dim, max_size, version=(2, 5)
+ )
+ pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
+ self.register_buffer("pos_embed", pos_embed, persistent=False)
+
+ def _adjust_pos_cache(
+ self, tgt_sizes: torch.Tensor, device: torch.types.Device
+ ) -> None:
+ max_h = tgt_sizes[:, 0].max().item()
+ max_w = tgt_sizes[:, 1].max().item()
+ assert isinstance(max_h, int) and isinstance(max_w, int)
+
+ if max_h > self.max_size[0] or max_w > self.max_size[1]:
+ self.max_size = (
+ max(max_h, self.max_size[0]),
+ max(max_w, self.max_size[1]),
+ )
+ self._set_2d_pos_cache(self.max_size, device)
+
+ def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:
+ assert x.shape[0] == tgt_sizes.shape[0]
+ bs = x.shape[0]
+
+ device = x.device
+ dtype = x.dtype
+
+ patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+
+ self._adjust_pos_cache(tgt_sizes, device=device)
+
+ max_patch_len = patch_len.max().item()
+ assert isinstance(max_patch_len, int)
+
+ key_padding_mask = torch.zeros(
+ (bs, max_patch_len), dtype=torch.bool, device=device
+ )
+
+ pos_embed = []
+ for i in range(bs):
+ tgt_h, tgt_w = tgt_sizes[i].tolist()
+ pos_embed.append(
+ self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)
+ ) # patches * D
+ key_padding_mask[i, patch_len[i] :] = True
+ pos_embed = torch.nn.utils.rnn.pad_sequence(
+ pos_embed, batch_first=True, padding_value=0.0
+ ).permute(
+ 1, 0, 2
+ ) # BLD => L * B * D
+ x, _ = self.kv_proj(x) # B * L * D
+ x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
+
+ q = self.ln_q(self.query) # Q * D
+
+ out = self.attn(
+ self._repeat(q, bs), # Q * B * D
+ x + pos_embed, # L * B * D + L * B * D
+ x,
+ key_padding_mask=key_padding_mask,
+ )[0]
+ # out: Q * B * D
+ x = out.permute(1, 0, 2) # B * Q * D
+
+ x = self.ln_post(x)
+ x = x @ self.proj
+ return x
+
+
+def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
+ version_float = getattr(config, "version", None)
+
+ # The old configs do not include version number
+ # TODO: Remove this after the HF repos are updated
+ if version_float is None:
+ if config.hidden_size == 2304 and config.query_num == 64:
+ return 2, 0
+ return 2, 5
+
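+    # e.g. a config with version == 2.6 yields the tuple (2, 6)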
+ version_str = str(version_float)
+ return tuple(int(x) for x in version_str.split("."))
+
+
+class MiniCPMVBaseModel(nn.Module):
+ """
+ The abstract class of MiniCPMV can only be inherited, but cannot be
+ instantiated.
+ """
+
+ def __init__(
+ self,
+ *,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ # multimodal_config = config.model_config.multimodal_config
+ super().__init__()
+ # All MiniCPM-V models disable `tie_word_embeddings` but
+ # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
+        # check `tie_word_embeddings` until vLLM integrates the MiniCPM-V model
+ # and config class
+ self.config = config
+ # self.multimodal_config = multimodal_config
+
+ self.version = get_version_by_config(self.config)
+ self.llm = self.init_llm(config=config, quant_config=quant_config)
+ self.vpm = self.init_vision_module(config, quant_config)
+ self.vision_dim = (
+ self.vpm.embed_dim
+ if self.version == (2, 0)
+ else self.vpm.embeddings.embed_dim
+ )
+ self.embed_dim = self.config.hidden_size
+
+ self.resampler = self.init_resampler(
+ self.embed_dim, self.vision_dim, quant_config=quant_config
+ )
+
+ self.logits_processor = LogitsProcessor(config)
+
+ @cached_property
+ def sampler(self):
+ if hasattr(self.llm, "sampler"):
+ return self.llm.sampler
+
+ return get_sampler()
+
+ def _get_image_bounds(
+ self,
+ input_ids: torch.Tensor,
+ pad_values: List[int],
+ im_start_id: torch.Tensor,
+ im_end_id: torch.Tensor,
+ slice_start_id: Optional[torch.Tensor] = None,
+ slice_end_id: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+        Returns a tensor indicating the bounds (start and end token positions) of the images
+ """
+ # All the images in the batch should share the same special image
+ # bound token ids.
+ start_cond = input_ids == im_start_id[0]
+ end_cond = input_ids == im_end_id[0]
+ if slice_start_id is not None:
+ start_cond |= input_ids == slice_start_id[0]
+ end_cond |= input_ids == slice_end_id[0]
+
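+        # Image embeddings occupy the tokens strictly after each start token and
+        # before the matching end token, hence the +1 offset below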
+ (image_start_tokens,) = torch.where(start_cond)
+ image_start_tokens += 1
+ (image_end_tokens,) = torch.where(end_cond)
+
+        # the im_start_id token may already be part of the cached prefix, but it is still needed to locate the image embeddings
+ if len(image_start_tokens) != len(image_end_tokens):
+ if (
+ len(image_start_tokens) + 1 == len(image_end_tokens)
+ and input_ids[0] in pad_values
+ and image_end_tokens[0] < image_start_tokens[0]
+ ):
+ image_start_tokens = torch.cat(
+ [
+ torch.tensor([0], device=image_start_tokens.device),
+ image_start_tokens,
+ ]
+ )
+ valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))
+
+ if valid_image_nums == 0:
+ return torch.zeros((0, 2), device=input_ids.device)
+
+ # Filter out pairs where start_token >= end_token
+ valid_pairs = []
+ for i in range(valid_image_nums):
+ start_token = image_start_tokens[i]
+ end_token = image_end_tokens[i]
+ if start_token < end_token:
+ valid_pairs.append((start_token, end_token))
+
+ if not valid_pairs:
+ return torch.zeros((0, 2), device=input_ids.device)
+
+ # Convert valid pairs to tensor
+ valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+ return valid_pairs_tensor
+
+ def get_embedding(
+ self,
+ input_ids: torch.Tensor,
+ image_inputs: Optional[MiniCPMVImageInputs],
+ forward_batch: ForwardBatch,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
+
+ if image_inputs is None: # No image
+ vision_hidden_states = torch.tensor([], device=input_ids.device)
+ else:
+ if image_inputs["type"] == "image_embeds":
+ vision_hidden_states = (
+ image_inputs["data"]
+ .type(vlm_embedding.dtype)
+ .to(vlm_embedding.device)
+ )
+ else:
+ vision_hidden_states = self.get_vision_hidden_states(
+ forward_batch, image_inputs
+ )
+
+ # See NOTE in _parse_and_validate_inputs
+ image_bounds = image_inputs["image_bounds"]
+ if len(image_bounds) > 0:
+ image_indices = torch.stack(
+ [
+ torch.arange(start, end, dtype=torch.long)
+ for start, end in image_bounds.tolist()
+ ]
+ ).to(vlm_embedding.device)
+ vlm_embedding.scatter_(
+ 0,
+ image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
+ vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
+ )
+
+ return vlm_embedding, vision_hidden_states
+
+ def _parse_and_validate_inputs(
+ self,
+ input_ids: torch.Tensor,
+ **kwargs: object,
+ ) -> Optional[MiniCPMVImageInputs]:
+ pixel_values = kwargs.pop("pixel_values", [])
+ tgt_sizes = kwargs.pop("tgt_sizes", [])
+ im_start_id = kwargs.pop("im_start_id", None)
+ im_end_id = kwargs.pop("im_end_id", None)
+ slice_start_id = kwargs.pop("slice_start_id", None)
+ slice_end_id = kwargs.pop("slice_end_id", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ pad_values = kwargs.pop("pad_values", None)
+
+ if image_embeds is not None:
+ image_bounds = self._get_image_bounds(
+ input_ids=input_ids,
+ pad_values=pad_values,
+ im_start_id=im_start_id,
+ im_end_id=im_end_id,
+ slice_start_id=slice_start_id,
+ slice_end_id=slice_end_id,
+ )
+ if not isinstance(image_embeds, (torch.Tensor, list)):
+ raise ValueError(
+ f"Incorrect type of image embeds. "
+ f"Got type: {type(image_embeds)}"
+ )
+
+ if isinstance(image_embeds, list):
+ image_embeds = torch.concat(image_embeds)
+
+ return MiniCPMVImageEmbeddingInputs(
+ image_bounds=image_bounds,
+ data=image_embeds,
+ type="image_embeds",
+ )
+
+ if not isinstance(pixel_values, (torch.Tensor, list)):
+ raise ValueError(
+ "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
+ )
+
+ if not isinstance(tgt_sizes, (torch.Tensor, list)):
+ raise ValueError(
+ "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
+ )
+
+ if len(pixel_values) != len(tgt_sizes):
+ raise ValueError(
+ "Inconsistent batch lengths, found: "
+ f"{len(pixel_values)} vs. {len(tgt_sizes)}"
+ )
+
+ pixel_values_flat: List[torch.Tensor] = []
+ tgt_sizes_flat: List[torch.Tensor] = []
+ for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
+ if len(pixel_b) != len(tgt_b):
+ raise ValueError(
+ "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
+ )
+
+ for pixel_n, tgt_n in zip(pixel_b, tgt_b):
+ pixel_values_flat += pixel_n
+ tgt_sizes_flat += tgt_n
+
+        # NOTE: Input IDs do not contain image tokens during memory profiling,
+ # so we allow it to be empty
+ if len(pixel_values_flat) != len(tgt_sizes_flat):
+ raise ValueError(
+ "Inconsistent flattened lengths, found: "
+ f"{len(pixel_values_flat)} vs. "
+ f"{len(tgt_sizes_flat)}"
+ )
+
+ if len(pixel_values_flat) == 0:
+ return None
+
+ image_bounds = self._get_image_bounds(
+ input_ids=input_ids,
+ pad_values=pad_values,
+ im_start_id=im_start_id,
+ im_end_id=im_end_id,
+ slice_start_id=slice_start_id,
+ slice_end_id=slice_end_id,
+ )
+ return MiniCPMVImagePixelInputs(
+ image_bounds=image_bounds.to(device=input_ids.device),
+ data=pixel_values_flat,
+ tgt_sizes=torch.stack(tgt_sizes_flat),
+ type="pixel_values",
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
+ None
+ ]:
+ kwargs.update(
+ {
+ "pixel_values": (
+ None
+ if forward_batch.image_inputs is None
+ else [
+ i.pixel_values
+ for i in forward_batch.image_inputs
+ if i is not None
+ ]
+ ),
+ "tgt_sizes": (
+ None
+ if forward_batch.image_inputs is None
+ else [
+ i.tgt_sizes
+ for i in forward_batch.image_inputs
+ if i is not None
+ ]
+ ),
+ "im_start_id": forward_batch.image_inputs[0].im_start_id,
+ "im_end_id": forward_batch.image_inputs[0].im_end_id,
+ "slice_start_id": forward_batch.image_inputs[0].slice_start_id,
+ "slice_end_id": forward_batch.image_inputs[0].slice_end_id,
+ "pad_values": forward_batch.image_inputs[0].pad_values,
+ }
+ )
+
+ image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
+
+ # Clamp input ids. This is because the input_ids for the image tokens are
+ # filled with the hash values of the image for the prefix matching in the radix attention.
+        # These values are useless because their embeddings will be replaced by vision embeddings anyway.
+ input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
+ vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch)
+
+ # always pass the input via `inputs_embeds`
+ # to make sure the computation graph is consistent
+ # for `torch.compile` integration
+ input_ids = None
+
+ hidden_states = self.llm.model(
+ input_ids=input_ids,
+ positions=positions,
+ forward_batch=forward_batch,
+ input_embeds=vlm_embeddings,
+ )
+
+ return self.logits_processor(
+ input_ids, hidden_states, self.llm.lm_head, forward_batch
+ )
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.llm.compute_logits(hidden_states, sampling_metadata)
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="llm", connector="resampler", tower_model="vpm"
+ )
+
+ def init_llm(
+ self,
+ config: Qwen2Config,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> nn.Module:
+ raise NotImplementedError
+
+ def init_vision_module(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig],
+ ) -> nn.Module:
+ raise NotImplementedError
+
+ def init_resampler(
+ self,
+ embed_dim: int,
+ vision_dim: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> nn.Module:
+ raise NotImplementedError
+
+ def get_vision_embedding(
+ self,
+ pixel_values: List[torch.Tensor],
+ patch_attn_mask: Optional[torch.Tensor] = None,
+ tgt_sizes: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ raise NotImplementedError
+
+ def get_vision_hidden_states(
+ self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs
+ ) -> torch.Tensor:
+ raise NotImplementedError
+
+
+class MiniCPMV2_6(MiniCPMVBaseModel):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+ # LoRA specific attributes
+ supported_lora_modules = [
+ # vision encoder
+ "fc1",
+ "fc2",
+ "out_proj",
+ # language model
+ "qkv_proj", # same name with vision encoder
+ "o_proj",
+ "gate_up_proj",
+ "down_proj",
+ # resampler
+ "kv_proj",
+ ]
+
+    # BitsAndBytes specific attributes
+ bitsandbytes_stacked_params_mapping = {
+ # shard_name, weight_name, index
+ "q_proj": ("qkv_proj", 0),
+ "k_proj": ("qkv_proj", 1),
+ "v_proj": ("qkv_proj", 2),
+ "gate_proj": ("gate_up_proj", 0),
+ "up_proj": ("gate_up_proj", 1),
+ }
+
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__(config=config, quant_config=quant_config)
+ assert self.version == (2, 6)
+
+ def init_llm(
+ self,
+ config: Qwen2Config,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> nn.Module:
+ return Qwen2ForCausalLM(config=config, quant_config=quant_config)
+
+ def init_vision_module(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig],
+ ) -> nn.Module:
+ model = Idefics2VisionTransformer(
+ config=config.vision_config, quant_config=quant_config
+ )
+ if self.config.drop_vision_last_layer:
+ model.encoder.layers = model.encoder.layers[:-1]
+
+ setattr(model, "embed_dim", model.embeddings.embed_dim)
+ setattr(model, "patch_size", model.embeddings.patch_size)
+ return model
+
+ def init_resampler(
+ self,
+ embed_dim: int,
+ vision_dim: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> nn.Module:
+ with set_default_torch_dtype(torch.float16):
+ # The resampler in 2.6 remains consistent with the one in 2.5.
+ resampler = Resampler2_5(
+ num_queries=self.config.query_num,
+ embed_dim=embed_dim,
+ num_heads=embed_dim // 128,
+ kv_dim=vision_dim,
+ quant_config=quant_config,
+ )
+
+ return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+
+ def get_vision_embedding(
+ self,
+ pixel_values: List[torch.Tensor],
+ patch_attn_mask: Optional[torch.Tensor] = None,
+ tgt_sizes: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ vision_embedding = self.vpm(
+ pixel_values,
+ patch_attention_mask=patch_attn_mask,
+ tgt_sizes=tgt_sizes,
+ )
+ return vision_embedding
+
+ def get_vision_hidden_states(
+ self,
+ forward_batch: ForwardBatch,
+ data: MiniCPMVImageInputs,
+ ) -> torch.Tensor:
+ pixel_values = data["data"]
+ tgt_sizes = data["tgt_sizes"]
+
+ device = self.vpm.embeddings.position_embedding.weight.device
+ dtype = self.vpm.embeddings.position_embedding.weight.dtype
+ all_pixel_values_lst = [
+ i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
+ ]
+
+ max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
+ assert isinstance(max_patches, int)
+
+ all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+ all_pixel_values_lst, batch_first=True, padding_value=0.0
+ )
+ B, L, _ = all_pixel_values.shape
+ all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+ patch_attn_mask = torch.zeros(
+ (B, 1, max_patches), dtype=torch.bool, device=device
+ )
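+        # Mark the first tgt_h * tgt_w positions of each padded patch sequence as valid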
+ for i in range(B):
+ patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+ vision_embedding = self.vpm(
+ all_pixel_values.type(dtype),
+ forward_batch=forward_batch,
+ patch_attention_mask=patch_attn_mask,
+ tgt_sizes=tgt_sizes,
+ )
+
+ return self.resampler(vision_embedding, tgt_sizes)
+
+ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
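+        """Replace the tokens between image/slice start and end markers with
+        per-image pad values (derived from the image hashes) so that radix-cache
+        prefix matching treats identical images as identical prefixes."""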
+ if not isinstance(image_inputs.im_start_id, list) or not isinstance(
+ image_inputs.im_end_id, list
+ ):
+ return input_ids
+
+ new_input_ids = []
+ last_idx = 0
+ image_idx = -1
+ image_inputs.image_offsets = []
+
+ # Get all special token IDs
+ im_start_id = (
+ image_inputs.im_start_id[0].item()
+ if isinstance(image_inputs.im_start_id[0], torch.Tensor)
+ else image_inputs.im_start_id[0]
+ )
+ im_end_id = (
+ image_inputs.im_end_id[0].item()
+ if isinstance(image_inputs.im_end_id[0], torch.Tensor)
+ else image_inputs.im_end_id[0]
+ )
+ slice_start_id = (
+ image_inputs.slice_start_id[0].item()
+ if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
+ else image_inputs.slice_start_id[0]
+ )
+ slice_end_id = (
+ image_inputs.slice_end_id[0].item()
+ if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
+ else image_inputs.slice_end_id[0]
+ )
+
+ # Find all start and end positions for both types
+ start_indices = [
+ i
+ for i, x in enumerate(input_ids)
+ if x == im_start_id or x == slice_start_id
+ ]
+ end_indices = [
+ i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
+ ]
+
+ if len(start_indices) != len(end_indices):
+ return input_ids
+ # Process each region (both image and slice)
+ for start_idx, end_idx in zip(start_indices, end_indices):
+ # Add non-image tokens before this region
+ new_input_ids.extend(
+ input_ids[last_idx : start_idx + 1]
+ ) # include start token
+
+ is_image_start = input_ids[start_idx] == im_start_id
+
+ if is_image_start:
+ image_inputs.image_offsets += [start_idx]
+ image_idx += 1
+
+ num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
+
+ # Generate pad_ids
+ pad_values = [image_inputs.pad_values[image_idx]]
+
+ pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
+ pad_ids = pad_ids[:num_tokens]
+
+ # Add pad_ids
+ new_input_ids.extend(pad_ids)
+
+ # Update last_idx to after end token
+ last_idx = end_idx
+
+ # Add remaining tokens after last region
+ new_input_ids.extend(input_ids[last_idx:])
+ assert len(input_ids) == len(new_input_ids)
+ return new_input_ids
+
+
+_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
+
+
+class MiniCPMV:
+ """
+ Different versions of MiniCPMV use different visual encoders and LLMs,
+    which does not fit the current LoRA and bitsandbytes integration logic in
+    vLLM, so each version is dispatched to its own class.
+ """
+
+ # Ensure that the LoRA support check passes when the class is not
+ # initialized, but set all these attributes to empty.
+ packed_modules_mapping = {}
+ supported_lora_modules = []
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+ minicpmv: nn.Module
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+
+ if not hasattr(config, "version"):
+ version = (2, 6)
+ else:
+ version = str(config.version).split(".")
+ version = tuple([int(x) for x in version])
+ # Dispatch class based on version
+ instance_class = _SUPPORT_VERSION.get(version)
+ if instance_class is None:
+        raise ValueError("Currently, MiniCPMV only supports version 2.6")
+
+ try:
+ minicpmv = instance_class(config=config, quant_config=quant_config)
+ self.minicpmv = minicpmv
+ except Exception as e:
+ print(f"Failed to instantiate MiniCPMV: {e}")
+ raise e
+ self.config = config
+
+ def __getattr__(self, name):
+ if name == "minicpmv":
+ return None
+ return getattr(self.minicpmv, name)
+
+ def __call__(self, *args, **kwargs):
+ return self.minicpmv(*args, **kwargs)
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ params_dict = dict(self.minicpmv.named_parameters())
+ for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+ continue
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+ # Models trained using ColossalAI may include these tensors in
+ # the checkpoint. Skip them.
+ continue
+ if name.startswith("model.vision_tower") and name not in params_dict:
+ continue
+
+ # adapt to VisionAttention
+ name = name.replace(r"self_attn.out_proj", r"self_attn.proj")
+
+ if "sampler" in name:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ continue
+
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ # replace the name and load with customized loader
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+
+
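+# sglang's model loader discovers this implementation through EntryClass.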
+EntryClass = MiniCPMV
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 935e743bf..0c01ab9e5 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -248,6 +248,9 @@ class Qwen2Model(nn.Module):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
def forward(
self,
input_ids: torch.Tensor,
@@ -296,7 +299,6 @@ class Qwen2Model(nn.Module):
class Qwen2ForCausalLM(nn.Module):
-
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
@@ -334,6 +336,9 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embeddings(input_ids)
+
@torch.no_grad()
def forward(
self,
diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py
index 83912e894..0fb85679f 100644
--- a/python/sglang/srt/models/qwen2_vl.py
+++ b/python/sglang/srt/models/qwen2_vl.py
@@ -37,9 +37,7 @@ from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.distributed import utils as dist_utils
from sglang.srt.hf_transformers_utils import get_processor
-from sglang.srt.layers.attention.triton_ops.prefill_attention import (
- context_attention_fwd,
-)
+from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
@@ -52,6 +50,7 @@ from sglang.srt.models.qwen2 import Qwen2Model
logger = logging.getLogger(__name__)
+
# === Vision Inputs === #
@@ -110,118 +109,6 @@ class Qwen2VisionMLP(nn.Module):
return x
-def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
- if not interleaved:
- x1, x2 = x.chunk(2, dim=-1)
- return torch.cat((-x2, x1), dim=-1)
- else:
- x1, x2 = x[..., ::2], x[..., 1::2]
- return rearrange(
- torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
- )
-
-
-def apply_rotary_emb_torch(
- x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
-) -> torch.Tensor:
- """
- x: (batch_size, seqlen, nheads, headdim)
- cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
- """
- ro_dim = cos.shape[-1] * 2
- assert ro_dim <= x.shape[-1]
- cos = repeat(
- cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
- )
- sin = repeat(
- sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
- )
- return torch.cat(
- [
- x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
- x[..., ro_dim:],
- ],
- dim=-1,
- )
-
-
-def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
- t_ = t.float()
- cos = freqs.cos()
- sin = freqs.sin()
- output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
- return output
-
-
-class Qwen2VisionAttention(nn.Module):
-
- def __init__(
- self,
- embed_dim: Optional[int] = None,
- num_heads: Optional[int] = None,
- projection_size: Optional[int] = None,
- quant_config: Optional[QuantizationConfig] = None,
- ) -> None:
- super().__init__()
- # Per attention head and per partition values.
- world_size = parallel_state.get_tensor_model_parallel_world_size()
- self.hidden_size_per_attention_head = dist_utils.divide(
- projection_size, num_heads
- )
- self.num_attention_heads_per_partition = dist_utils.divide(
- num_heads, world_size
- )
-
- self.qkv = ColumnParallelLinear(
- input_size=embed_dim,
- output_size=3 * projection_size,
- quant_config=quant_config,
- )
- self.proj = RowParallelLinear(
- input_size=projection_size, output_size=embed_dim, quant_config=quant_config
- )
-
- def forward(
- self,
- x: torch.Tensor,
- cu_seqlens: torch.Tensor,
- rotary_pos_emb: torch.Tensor = None,
- ) -> torch.Tensor:
- # [s, b, c] --> [s, b, head * 3 * head_dim]
- x, _ = self.qkv(x)
-
- # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
- new_x_shape = x.size()[:-1] + (
- self.num_attention_heads_per_partition,
- 3 * self.hidden_size_per_attention_head,
- )
- x = x.view(*new_x_shape)
-
- # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
- q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
- batch_size = q.shape[1]
-
- q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
- if rotary_pos_emb is not None:
- q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
- k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
-
- seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
- max_seqlen = (seq_lens).max().item()
- q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
-
- output = torch.empty_like(q)
- context_attention_fwd(
- q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
- )
-
- context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
- context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
-
- output, _ = self.proj(context_layer)
- return output
-
-
class Qwen2VisionBlock(nn.Module):
def __init__(
@@ -240,10 +127,11 @@ class Qwen2VisionBlock(nn.Module):
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
- self.attn = Qwen2VisionAttention(
+ self.attn = VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
+ use_qkv_parallel=False,
quant_config=quant_config,
)
self.mlp = Qwen2VisionMLP(
@@ -253,9 +141,13 @@ class Qwen2VisionBlock(nn.Module):
def forward(
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
) -> torch.Tensor:
- x = x + self.attn(
- self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+ hidden_states = self.norm1(x)
+ hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+ attn = self.attn(
+ hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
)
+ attn = rearrange(attn, "b s ... -> s b ...")
+ x = x + attn
x = x + self.mlp(self.norm2(x))
return x
@@ -684,10 +576,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
+
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
+
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
@@ -696,6 +590,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
+
if "visual" in name and "qkv.weight" in name:
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
@@ -712,6 +607,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
+
+ if "visual" in name:
+ # adapt to VisionAttention
+ name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index a41b94301..b3526520c 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -565,6 +565,7 @@ def launch_engine(
# Assume all schedulers have same scheduler_info
scheduler_info = scheduler_infos[0]
+ tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
def launch_server(
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 4ba2a6c2c..3e8b95b15 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -451,6 +451,8 @@ def load_image(image_file: Union[str, bytes]):
else:
raise ValueError(f"Invalid image: {image}")
+ # if image_size is None:
+ # image_size = image.size
return image, image_size
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index d3c9b7cab..c1437074f 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -406,7 +406,7 @@ def popen_launch_server(
base_url: str,
timeout: float,
api_key: Optional[str] = None,
- other_args: tuple = (),
+ other_args: list[str] = (),
env: Optional[dict] = None,
return_stdout_stderr: Optional[tuple] = None,
):
diff --git a/test/README.md b/test/README.md
index 3d739cc04..868061bbc 100644
--- a/test/README.md
+++ b/test/README.md
@@ -25,7 +25,7 @@ export OPENAI_API_KEY=sk-*****
python3 test_openai_backend.py
# Run a single test
-python3 -m unittest test_openai_backend.TestOpenAIBackend.test_few_shot_qa
+python3 -m unittest test_openai_backend.TestOpenAIServer.test_few_shot_qa
# Run a suite with multiple files
python3 run_suite.py --suite per-commit
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
index e19e6b01d..163b0511e 100644
--- a/test/srt/test_vision_openai_server.py
+++ b/test/srt/test_vision_openai_server.py
@@ -171,7 +171,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
text = response.choices[0].message.content
assert isinstance(text, str)
print(text)
- assert "man" in text or "cab" in text, text
+ assert "man" in text or "cab" in text or "SUV" in text or "taxi" in text, text
assert "logo" in text or '"S"' in text or "SG" in text, text
assert response.id
assert response.created
@@ -444,5 +444,24 @@ class TestMllamaServer(TestOpenAIVisionServer):
pass
+class TestMinicpmvServer(TestOpenAIVisionServer):
+ @classmethod
+ def setUpClass(cls):
+ cls.model = "openbmb/MiniCPM-V-2_6"
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ cls.api_key = "sk-123456"
+ cls.process = popen_launch_server(
+ cls.model,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=[
+ "--trust-remote-code",
+ "--chat-template",
+ "minicpmv",
+ ],
+ )
+ cls.base_url += "/v1"
+
+
if __name__ == "__main__":
unittest.main()