[Feature] Support minicpmv v2.6 (#2785)
Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: yizhang2077 <1109276519@qq.com>
This commit is contained in:
@@ -88,7 +88,6 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="claude",
|
||||
@@ -101,7 +100,6 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="chatml",
|
||||
@@ -116,7 +114,6 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="chatml-llava",
|
||||
@@ -132,7 +129,6 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# There is a default system prompt for qwen
|
||||
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
|
||||
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||
@@ -219,6 +215,21 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
# https://huggingface.co/openbmb/MiniCPM-V-2_6
|
||||
# Register the "minicpmv" chat template.
# - No default system prompt is injected (default_system_prompt=None).
# - Each role maps to a (prefix, suffix) pair wrapped around that role's text.
# - image_token is the literal placeholder string used for an image in the prompt.
# NOTE(review): the assistant suffix is "</s>" while stop_str lists
# "<|im_end|>" / "<|endoftext|>" — confirm this matches the model's tokenizer.
register_chat_template(
    ChatTemplate(
        name="minicpmv",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("user:", " "),
            "assistant": ("assistant:", "</s>"),
        },
        stop_str=("<|im_end|>", "<|endoftext|>"),
        image_token="(<image>./</image>)",
    )
)
|
||||
|
||||
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
|
||||
@@ -402,6 +402,7 @@ def is_multimodal_model(model_architectures: List[str]):
|
||||
or "LlavaVidForCausalLM" in model_architectures
|
||||
or "MllamaForConditionalGeneration" in model_architectures
|
||||
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||
or "MiniCPMV" in model_architectures
|
||||
):
|
||||
return True
|
||||
else:
|
||||
|
||||
@@ -452,7 +452,6 @@ def generate_chat_conv(
|
||||
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], None)
|
||||
|
||||
return conv
|
||||
|
||||
|
||||
@@ -555,3 +554,17 @@ register_conv_template(
|
||||
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
|
||||
# Register the "minicpmv" conversation template.
# - system_template appends a trailing "." after the system message, so the
#   default system_message is written without one.
# - roles are the literal turn headers for the user and the assistant.
# - sep terminates each turn; stop_str lists the strings that end generation.
# - image_token is the literal placeholder string used for an image in the prompt.
register_conv_template(
    Conversation(
        name="minicpmv",
        system_message="You are a helpful assistant",
        system_template="<|im_start|>system\n{system_message}.",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep="<|im_end|>\n",
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        stop_str=("<|im_end|>", "<|endoftext|>"),
        image_token="(<image>./</image>)",
    )
)
|
||||
|
||||
204
python/sglang/srt/layers/attention/vision.py
Normal file
204
python/sglang/srt/layers/attention/vision.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
|
||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
context_attention_fwd,
|
||||
)
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.quantization import QuantizationConfig
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate pairs of features along the last dimension of ``x``.

    With ``interleaved=False`` the last dimension is split into a first and a
    second half and the result is ``cat(-second, first)``. With
    ``interleaved=True`` even/odd positions form the pairs and the result
    keeps them interleaved as ``(-odd, even)``.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        paired = torch.stack((-odds, evens), dim=-1)
        # Merge the trailing (pair_index, 2) axes back into one feature axis.
        return rearrange(paired, "... d two -> ... (d two)", two=2)

    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Apply rotary embedding to the leading ``rotary_dim`` features of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Features beyond ``rotary_dim`` (twice the last dimension of ``cos``) are
    passed through unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]

    # Duplicate each frequency to cover both members of a rotated pair and
    # insert a singleton axis that broadcasts over the heads dimension.
    if interleaved:
        expand_pattern = "... d -> ... 1 (d 2)"
    else:
        expand_pattern = "... d -> ... 1 (2 d)"
    cos = repeat(cos, expand_pattern)
    sin = repeat(sin, expand_pattern)

    rotary_part = x[..., :rotary_dim]
    untouched_part = x[..., rotary_dim:]
    rotated = rotary_part * cos + rotate_half(rotary_part, interleaved) * sin
    return torch.cat([rotated, untouched_part], dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the angles in ``freqs``.

    The rotation is computed on a float32 copy of ``t`` and the result is cast
    back to the dtype of ``t``.
    """
    rotated = apply_rotary_emb_torch(t.float(), freqs.cos(), freqs.sin())
    return rotated.type_as(t)
|
||||
|
||||
|
||||
class VisionAttention(nn.Module):
    """Multi-headed attention without any cache, mostly used for ViT.

    Two projection layouts are supported, selected by ``use_qkv_parallel``:

    * ``True``: a ``QKVParallelLinear`` whose output is chunked into q, k, v.
    * ``False``: a ``ColumnParallelLinear`` producing ``3 * projection_size``
      features laid out per head as ``[q | k | v]``.

    Args:
        embed_dim: Size of the input (and output) hidden dimension.
        num_heads: Total number of attention heads before tensor-parallel split.
        projection_size: Total q (= k = v) width for the non-QKV-parallel path.
        use_qkv_parallel: Which projection layout to build (see above).
        quant_config: Optional quantization config forwarded to the linears.
        prefix: Parameter-name prefix forwarded to the linears.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        use_qkv_parallel: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        world_size = parallel_state.get_tensor_model_parallel_world_size()

        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )
        self.use_qkv_parallel = use_qkv_parallel
        if use_qkv_parallel:
            self.head_dim = embed_dim // num_heads
            self.qkv_proj = QKVParallelLinear(
                hidden_size=embed_dim,
                head_size=self.head_dim,
                total_num_heads=num_heads,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
        else:
            self.qkv_proj = ColumnParallelLinear(
                input_size=embed_dim,
                output_size=3 * projection_size,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
        # The attention output is projection_size wide in the non-QKV-parallel
        # layout; it only coincides with embed_dim when the two are equal.
        self.proj = RowParallelLinear(
            input_size=embed_dim if use_qkv_parallel else projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run non-causal self-attention over ``x``.

        Args:
            x: Input of shape [b, s, embed_dim].
            cu_seqlens: Cumulative sequence boundaries over the flattened
                ``b * s`` tokens. When ``None``, every batch row is treated as
                one sequence of length ``s``.
            rotary_pos_emb: Optional rotary angles applied to q and k.

        Returns:
            Tensor of shape [b, s, embed_dim].
        """
        bsz, s, _ = x.shape
        if self.use_qkv_parallel:
            # [b, s, embed_dim] --> [b, s, 3 * heads * head_size]
            qkv, _ = self.qkv_proj(x)
            q, k, v = qkv.chunk(3, dim=-1)

            # [b, s, heads * head_size] --> [b * s, heads, head_size]
            q, k, v = [
                t.reshape(
                    bsz * s, self.num_attention_heads_per_partition, -1
                ).contiguous()
                for t in (q, k, v)
            ]
        else:
            # [b, s, embed_dim] --> [s, b, embed_dim]
            x = rearrange(x, "b s ... -> s b ...")
            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
            qkv, _ = self.qkv_proj(x)
            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
            new_x_shape = qkv.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            qkv = qkv.view(*new_x_shape)

            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)

            # [s, b, head, head_dim] --> [b, s, head, head_dim]
            q, k, v = [
                rearrange(t, "s b ... -> b s ...").contiguous() for t in (q, k, v)
            ]

        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        if not self.use_qkv_parallel:
            # [b, s, head, head_dim] --> [b * s, head, head_dim]
            q, k, v = [rearrange(t, "b s ... -> (b s) ...") for t in (q, k, v)]

        # [b * s, num_heads, head_size]
        output = torch.empty_like(q)

        if cu_seqlens is None:
            # No packing information: each batch row is its own sequence.
            cu_seqlens = torch.arange(
                0, (bsz + 1) * s, step=s, dtype=torch.int32, device=q.device
            )
        # Keep the bookkeeping tensors on the same device as the activations.
        cu_seqlens = cu_seqlens.to(q.device)
        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seq_lens.max().item()

        context_attention_fwd(
            q,
            k,
            v,
            output,
            cu_seqlens,
            seq_lens,
            max_seqlen,
            is_causal=False,
        )

        if self.use_qkv_parallel:
            # [b * s, head, head_dim] --> [b, s, head * head_dim]
            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)

            # [b, s, head * head_dim] --> [b, s, embed_dim]
            output, _ = self.proj(output)
        else:
            # [b * s, head, head_dim] --> [s, b, head * head_dim]
            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=bsz
            ).contiguous()

            # [s, b, head * head_dim] --> [s, b, embed_dim]
            output, _ = self.proj(context_layer)

            # [s, b, embed_dim] --> [b, s, embed_dim]. This must be a real
            # transpose: a plain view would mix tokens across rows when b > 1.
            output = rearrange(output, "s b d -> b s d").contiguous()

        return output.view(bsz, s, -1)
|
||||
@@ -127,7 +127,7 @@ class LogitsProcessor(nn.Module):
|
||||
hidden_states,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
logits_metadata: Union[LogitsMetadata, ForwardBatch],
|
||||
):
|
||||
) -> LogitsProcessorOutput:
|
||||
if isinstance(logits_metadata, ForwardBatch):
|
||||
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ class DataParallelController:
|
||||
|
||||
def __init__(self, server_args, port_args) -> None:
|
||||
# Parse args
|
||||
self.max_total_num_tokens = None
|
||||
self.server_args = server_args
|
||||
self.port_args = port_args
|
||||
self.load_balance_method = LoadBalanceMethod.from_str(
|
||||
@@ -96,6 +97,8 @@ class DataParallelController:
|
||||
True,
|
||||
)
|
||||
|
||||
self.max_req_input_len = None
|
||||
|
||||
def launch_dp_schedulers(self, server_args, port_args):
|
||||
base_gpu_id = 0
|
||||
|
||||
@@ -189,6 +192,7 @@ class DataParallelController:
|
||||
scheduler_info.append(scheduler_pipe_readers[i].recv())
|
||||
|
||||
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
|
||||
|
||||
def round_robin_scheduler(self, req):
|
||||
self.workers[self.round_robin_counter].send_pyobj(req)
|
||||
@@ -231,7 +235,11 @@ def run_data_parallel_controller_process(
|
||||
try:
|
||||
controller = DataParallelController(server_args, port_args)
|
||||
pipe_writer.send(
|
||||
{"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
|
||||
{
|
||||
"status": "ready",
|
||||
"max_total_num_tokens": controller.max_total_num_tokens,
|
||||
"max_req_input_len": controller.max_req_input_len,
|
||||
}
|
||||
)
|
||||
if server_args.node_rank == 0:
|
||||
controller.event_loop()
|
||||
|
||||
@@ -9,6 +9,8 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import transformers
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
self.hf_config = hf_config
|
||||
self._processor = _processor
|
||||
self.server_args = server_args
|
||||
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
@@ -126,7 +129,12 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
)
|
||||
|
||||
async def process_images_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, request_obj
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
@@ -229,6 +237,147 @@ class MllamaImageProcessor(BaseImageProcessor):
|
||||
return image_inputs
|
||||
|
||||
|
||||
class MiniCPMVImageProcessor(BaseImageProcessor):
    """Image/video preprocessor for MiniCPM-V.

    Expands every image (or every sampled video frame) into one image
    placeholder token in the prompt, runs the HF processor on the result and
    returns the tensors plus the special image-bound token ids.
    """

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_images_task(images, input_text):
        """Run the process-global HF processor; used inside the executor."""
        result = global_processor.__call__(
            text=input_text, images=images, return_tensors="pt"
        )
        return {
            "input_ids": result["input_ids"],
            "pixel_values": result["pixel_values"],
            "tgt_sizes": result["tgt_sizes"],
        }

    async def _process_images(self, images, input_text):
        """Process images in the executor if one exists, otherwise inline."""
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            image_inputs = await loop.run_in_executor(
                self.executor,
                MiniCPMVImageProcessor._process_images_task,
                images,
                input_text,
            )
        else:
            image_inputs = self._processor(
                images=images, text=input_text, return_tensors="pt"
            )

        return image_inputs

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
    ):
        """Build model inputs for a request containing images and/or videos.

        Args:
            image_data: One item or a list of items. A string starting with
                ``"video:"`` is treated as a video path; anything else is
                handed to ``load_image``.
            input_text: Prompt text, or a list of token ids that is decoded
                with the processor's tokenizer.
            request_obj: Request object; only ``modalities`` is read.
            max_req_input_len: Input length limit used to budget video
                frames. ``None`` means no limit is known yet.

        Returns:
            A dict with ``input_ids``, ``pixel_values``, ``tgt_sizes``,
            ``image_hashes``, ``modalities`` and the special token ids, or
            ``None`` when there is nothing to process or a file is missing.
        """
        if not image_data:
            return None

        if not isinstance(image_data, list):
            image_data = [image_data]

        image_hashes = []
        raw_images = []
        IMAGE_TOKEN = "(<image>./</image>)"

        # roughly calculate the max number of frames
        # TODO: the process should be applied to all the visual inputs
        def calculate_max_num_frames() -> int:
            # Model-specific
            NUM_TOKEN_PER_FRAME = 330
            FRAME_CAP = 100

            if max_req_input_len is None:
                # Limit not known; fall back to the hard cap.
                return FRAME_CAP

            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
            # A prompt longer than the limit must not yield a negative budget.
            return max(0, min(ret, FRAME_CAP))

        # if cuda OOM set a smaller number
        MAX_NUM_FRAMES = calculate_max_num_frames()
        logger.debug(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")

        def encode_video(video_path):
            if not os.path.exists(video_path):
                logger.error(f"Video {video_path} does not exist")
                return []

            if MAX_NUM_FRAMES == 0:
                return []

            def uniform_sample(l, n):
                gap = len(l) / n
                idxs = [int(i * gap + gap / 2) for i in range(n)]
                return [l[i] for i in idxs]

            vr = VideoReader(video_path, ctx=cpu(0))
            # Sample roughly one frame per second; the step must be at least
            # 1 or range() raises for very low frame rates.
            sample_fps = max(1, round(vr.get_avg_fps() / 1))  # FPS
            frame_idx = list(range(0, len(vr), sample_fps))
            if len(frame_idx) > MAX_NUM_FRAMES:
                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
            frames = vr.get_batch(frame_idx).asnumpy()
            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
            return frames

        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

        # MiniCPMV requires each frame of video as a single image token
        text_parts = input_text.split(IMAGE_TOKEN)
        new_text_parts = []

        for image_index, image in enumerate(image_data):
            try:
                if isinstance(image, str) and image.startswith("video:"):
                    path = image[len("video:") :]
                    frames = encode_video(path)
                else:
                    raw_image, _ = load_image(image)
                    frames = [raw_image]
            except FileNotFoundError as e:
                logger.error(e)
                return None

            # Keep the text preceding this placeholder even when the visual
            # input yields no frames; only the placeholder itself is dropped.
            new_text_parts.append(text_parts[image_index])
            if len(frames) == 0:
                continue

            image_hashes += [hash(image)] * len(frames)
            raw_images += frames
            new_text_parts.append(IMAGE_TOKEN * len(frames))

        new_text_parts.append(text_parts[-1])
        input_text = "".join(new_text_parts)
        if len(raw_images) == 0:
            return None
        res = await self._process_images(images=raw_images, input_text=input_text)
        pixel_values = res["pixel_values"]
        tgt_sizes = res["tgt_sizes"]
        input_ids = res["input_ids"]

        # Collect special token ids
        tokenizer = self._processor.tokenizer
        im_start_id = [tokenizer.im_start_id]
        im_end_id = [tokenizer.im_end_id]
        # Slice ids are optional; default to None so the return below never
        # references an unbound name.
        slice_start_id, slice_end_id = None, None
        if tokenizer.slice_start_id:
            slice_start_id = [tokenizer.slice_start_id]
            slice_end_id = [tokenizer.slice_end_id]

        return {
            "input_ids": input_ids.flatten().tolist(),
            "pixel_values": pixel_values,
            "tgt_sizes": tgt_sizes,
            "image_hashes": image_hashes,
            "modalities": request_obj.modalities or ["image"],
            "im_start_id": im_start_id,
            "im_end_id": im_end_id,
            "slice_start_id": slice_start_id,
            "slice_end_id": slice_end_id,
        }
|
||||
|
||||
|
||||
class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _image_processor):
|
||||
self.hf_config = hf_config
|
||||
@@ -289,7 +438,12 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
return self._process_single_image_task(image_data)
|
||||
|
||||
async def process_images_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, request_obj
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
@@ -350,6 +504,8 @@ def get_image_processor(
|
||||
return MllamaImageProcessor(hf_config, server_args, processor)
|
||||
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
|
||||
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
|
||||
elif "MiniCPMV" in hf_config.architectures:
|
||||
return MiniCPMVImageProcessor(hf_config, server_args, processor)
|
||||
else:
|
||||
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
|
||||
|
||||
|
||||
@@ -52,7 +52,6 @@ from sglang.srt.server_args import ServerArgs
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
|
||||
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
|
||||
# Put some global args for easy access
|
||||
@@ -68,7 +67,6 @@ global_server_args_dict = {
|
||||
"device": ServerArgs.device,
|
||||
}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -149,6 +147,16 @@ class ImageInputs:
|
||||
image_grid_thws: List[Tuple[int, int, int]] = None
|
||||
mrope_position_delta: Optional[torch.Tensor] = None
|
||||
|
||||
# MiniCPMV related
|
||||
# All the images in the batch should share the same special image
|
||||
# bound token ids.
|
||||
im_start_id: Optional[torch.Tensor] = None
|
||||
im_end_id: Optional[torch.Tensor] = None
|
||||
slice_start_id: Optional[torch.Tensor] = None
|
||||
slice_end_id: Optional[torch.Tensor] = None
|
||||
|
||||
tgt_sizes: Optional[list] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj: dict):
|
||||
ret = ImageInputs(
|
||||
@@ -168,6 +176,11 @@ class ImageInputs:
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"image_grid_thws",
|
||||
"im_start_id",
|
||||
"im_end_id",
|
||||
"slice_start_id",
|
||||
"slice_end_id",
|
||||
"tgt_sizes",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if arg in obj:
|
||||
@@ -1140,7 +1153,6 @@ class ScheduleBatch:
|
||||
|
||||
global bid
|
||||
bid += 1
|
||||
|
||||
return ModelWorkerBatch(
|
||||
bid=bid,
|
||||
forward_mode=self.forward_mode,
|
||||
|
||||
@@ -274,7 +274,6 @@ class Scheduler:
|
||||
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
|
||||
global_server_args_dict.update(worker_global_server_args_dict)
|
||||
set_random_seed(self.random_seed)
|
||||
|
||||
# Print debug info
|
||||
logger.info(
|
||||
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
||||
@@ -1729,7 +1728,11 @@ def run_scheduler_process(
|
||||
try:
|
||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
||||
pipe_writer.send(
|
||||
{"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
|
||||
{
|
||||
"status": "ready",
|
||||
"max_total_num_tokens": scheduler.max_total_num_tokens,
|
||||
"max_req_input_len": scheduler.max_req_input_len,
|
||||
}
|
||||
)
|
||||
if scheduler.enable_overlap:
|
||||
scheduler.event_loop_overlap()
|
||||
|
||||
@@ -112,6 +112,7 @@ class TokenizerManager:
|
||||
port_args: PortArgs,
|
||||
):
|
||||
# Parse args
|
||||
|
||||
self.server_args = server_args
|
||||
self.enable_metrics = server_args.enable_metrics
|
||||
self.log_requests = server_args.log_requests
|
||||
@@ -207,6 +208,8 @@ class TokenizerManager:
|
||||
self.resume_memory_occupation_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
# Set after scheduler is initialized
|
||||
self.max_req_input_len = None
|
||||
|
||||
# Metrics
|
||||
if self.enable_metrics:
|
||||
@@ -281,7 +284,7 @@ class TokenizerManager:
|
||||
if self.is_generation:
|
||||
# TODO: also support getting embeddings for multimodal models
|
||||
image_inputs: Dict = await self.image_processor.process_images_async(
|
||||
obj.image_data, input_text or input_ids, obj
|
||||
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
|
||||
@@ -237,7 +237,7 @@ class ModelRunner:
|
||||
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
|
||||
|
||||
if not self.is_draft_worker:
|
||||
# Only initilzie the distributed environment on the target model worker.
|
||||
# Only initialize the distributed environment on the target model worker.
|
||||
init_distributed_environment(
|
||||
backend=backend,
|
||||
world_size=self.tp_size,
|
||||
|
||||
1238
python/sglang/srt/models/minicpmv.py
Normal file
1238
python/sglang/srt/models/minicpmv.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -248,6 +248,9 @@ class Qwen2Model(nn.Module):
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -296,7 +299,6 @@ class Qwen2Model(nn.Module):
|
||||
|
||||
|
||||
class Qwen2ForCausalLM(nn.Module):
|
||||
|
||||
# BitsAndBytes specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".gate_proj.",
|
||||
@@ -334,6 +336,9 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -37,9 +37,7 @@ from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
||||
from sglang.srt.distributed import parallel_state
|
||||
from sglang.srt.distributed import utils as dist_utils
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
context_attention_fwd,
|
||||
)
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
@@ -52,6 +50,7 @@ from sglang.srt.models.qwen2 import Qwen2Model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# === Vision Inputs === #
|
||||
|
||||
|
||||
@@ -110,118 +109,6 @@ class Qwen2VisionMLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
||||
if not interleaved:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
else:
|
||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
||||
return rearrange(
|
||||
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(
|
||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
x: (batch_size, seqlen, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
||||
"""
|
||||
ro_dim = cos.shape[-1] * 2
|
||||
assert ro_dim <= x.shape[-1]
|
||||
cos = repeat(
|
||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||
)
|
||||
sin = repeat(
|
||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||
)
|
||||
return torch.cat(
|
||||
[
|
||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
||||
x[..., ro_dim:],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
||||
t_ = t.float()
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
|
||||
return output
|
||||
|
||||
|
||||
class Qwen2VisionAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: Optional[int] = None,
|
||||
num_heads: Optional[int] = None,
|
||||
projection_size: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Per attention head and per partition values.
|
||||
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
projection_size, num_heads
|
||||
)
|
||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||
num_heads, world_size
|
||||
)
|
||||
|
||||
self.qkv = ColumnParallelLinear(
|
||||
input_size=embed_dim,
|
||||
output_size=3 * projection_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.proj = RowParallelLinear(
|
||||
input_size=projection_size, output_size=embed_dim, quant_config=quant_config
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
|
||||
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
|
||||
new_x_shape = x.size()[:-1] + (
|
||||
self.num_attention_heads_per_partition,
|
||||
3 * self.hidden_size_per_attention_head,
|
||||
)
|
||||
x = x.view(*new_x_shape)
|
||||
|
||||
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
|
||||
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
|
||||
batch_size = q.shape[1]
|
||||
|
||||
q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
|
||||
if rotary_pos_emb is not None:
|
||||
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
|
||||
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
|
||||
|
||||
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
max_seqlen = (seq_lens).max().item()
|
||||
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
||||
|
||||
output = torch.empty_like(q)
|
||||
context_attention_fwd(
|
||||
q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
|
||||
)
|
||||
|
||||
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
|
||||
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
|
||||
class Qwen2VisionBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -240,10 +127,11 @@ class Qwen2VisionBlock(nn.Module):
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
||||
self.attn = Qwen2VisionAttention(
|
||||
self.attn = VisionAttention(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
use_qkv_parallel=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mlp = Qwen2VisionMLP(
|
||||
@@ -253,9 +141,13 @@ class Qwen2VisionBlock(nn.Module):
|
||||
def forward(
|
||||
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
||||
hidden_states = self.norm1(x)
|
||||
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
|
||||
attn = self.attn(
|
||||
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
||||
)
|
||||
attn = rearrange(attn, "b s ... -> s b ...")
|
||||
x = x + attn
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
@@ -684,10 +576,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
@@ -696,6 +590,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
|
||||
if "visual" in name and "qkv.weight" in name:
|
||||
visual_num_heads = self.config.vision_config.num_heads
|
||||
visual_embed_dim = self.config.vision_config.embed_dim
|
||||
@@ -712,6 +607,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
|
||||
if "visual" in name:
|
||||
# adapt to VisionAttention
|
||||
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
||||
|
||||
try:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
|
||||
@@ -565,6 +565,7 @@ def launch_engine(
|
||||
|
||||
# Assume all schedulers have same scheduler_info
|
||||
scheduler_info = scheduler_infos[0]
|
||||
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
||||
|
||||
|
||||
def launch_server(
|
||||
|
||||
@@ -451,6 +451,8 @@ def load_image(image_file: Union[str, bytes]):
|
||||
else:
|
||||
raise ValueError(f"Invalid image: {image}")
|
||||
|
||||
# if image_size is None:
|
||||
# image_size = image.size
|
||||
return image, image_size
|
||||
|
||||
|
||||
|
||||
@@ -406,7 +406,7 @@ def popen_launch_server(
|
||||
base_url: str,
|
||||
timeout: float,
|
||||
api_key: Optional[str] = None,
|
||||
other_args: tuple = (),
|
||||
other_args: list[str] = (),
|
||||
env: Optional[dict] = None,
|
||||
return_stdout_stderr: Optional[tuple] = None,
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user