[Feature] Support minicpmv v2.6 (#2785)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
This commit is contained in:
Mick
2025-01-19 06:14:19 +08:00
committed by GitHub
parent c2f212d672
commit 3d93f84a00
20 changed files with 1715 additions and 139 deletions

View File

@@ -88,7 +88,6 @@ register_chat_template(
)
)
register_chat_template(
ChatTemplate(
name="claude",
@@ -101,7 +100,6 @@ register_chat_template(
)
)
register_chat_template(
ChatTemplate(
name="chatml",
@@ -116,7 +114,6 @@ register_chat_template(
)
)
register_chat_template(
ChatTemplate(
name="chatml-llava",
@@ -132,7 +129,6 @@ register_chat_template(
)
)
# There is a default system prompt for qwen
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
@@ -219,6 +215,21 @@ register_chat_template(
)
)
# https://huggingface.co/openbmb/MiniCPM-V-2_6
register_chat_template(
    ChatTemplate(
        name="minicpmv",
        default_system_prompt=None,
        # NOTE(review): role markers here are plain "user:"/"assistant:"
        # prefixes while stop_str uses ChatML tokens — confirm this matches
        # the model's tokenizer_config chat template.
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("user:", " "),
            "assistant": ("assistant:", "</s>"),
        },
        stop_str=("<|im_end|>", "<|endoftext|>"),
        # MiniCPM-V expects every image (or video frame) to appear in the
        # prompt as this literal placeholder string.
        image_token="(<image>./</image>)",
    )
)
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
ChatTemplate(

View File

@@ -402,6 +402,7 @@ def is_multimodal_model(model_architectures: List[str]):
or "LlavaVidForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures
or "MiniCPMV" in model_architectures
):
return True
else:

View File

@@ -452,7 +452,6 @@ def generate_chat_conv(
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
return conv
@@ -555,3 +554,17 @@ register_conv_template(
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)
# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
register_conv_template(
    Conversation(
        name="minicpmv",
        system_message="You are a helpful assistant",
        # NOTE(review): the template ends with "." rather than the
        # "<|im_end|>\n" separator — verify this renders the intended
        # "<|im_start|>system\n...<|im_end|>" framing.
        system_template="<|im_start|>system\n{system_message}.",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep="<|im_end|>\n",
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        stop_str=("<|im_end|>", "<|endoftext|>"),
        # Placeholder inserted once per image / sampled video frame.
        image_token="(<image>./</image>)",
    )
)

View File

@@ -0,0 +1,204 @@
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange, repeat
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.quantization import QuantizationConfig
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate half of the last dimension of ``x`` for rotary embeddings.

    With ``interleaved=False`` the last dim is split into two halves
    (a, b) and the result is cat(-b, a).  With ``interleaved=True`` the
    last dim is treated as pairs (even, odd) and each pair is mapped to
    (-odd, even).
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        # Stack (-odd, even) pairs and flatten the trailing (d, 2) axes
        # back into one dim — same layout as "... d two -> ... (d two)".
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Applies rotary embedding to the first ``2 * cos.shape[-1]`` channels of
    ``x`` and passes the remaining channels through unchanged.
    """
    ro_dim = 2 * cos.shape[-1]
    assert ro_dim <= x.shape[-1]
    if interleaved:
        # "... d -> ... 1 (d 2)": duplicate each entry adjacently, then add
        # a broadcast head axis.
        cos_full = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin_full = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
    else:
        # "... d -> ... 1 (2 d)": tile the whole vector twice, then add a
        # broadcast head axis.
        cos_full = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
        sin_full = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
    rotary_part, passthrough = x[..., :ro_dim], x[..., ro_dim:]
    rotated = rotary_part * cos_full + rotate_half(rotary_part, interleaved) * sin_full
    return torch.cat((rotated, passthrough), dim=-1)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings derived from ``freqs`` to ``t``.

    The computation runs in float32 and the result is cast back to the
    dtype of ``t``.
    """
    cos_vals, sin_vals = freqs.cos(), freqs.sin()
    return apply_rotary_emb_torch(t.float(), cos_vals, sin_vals).type_as(t)
class VisionAttention(nn.Module):
    """Multi-headed attention without any cache, mostly used for ViT.

    Two QKV layouts are supported, selected by ``use_qkv_parallel``:
    a fused head-parallel ``QKVParallelLinear`` or a packed [q|k|v]
    ``ColumnParallelLinear``.  Attention itself runs through the Triton
    ``context_attention_fwd`` kernel (non-causal, variable-length batches
    described by ``cu_seqlens``), so this module requires CUDA.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        use_qkv_parallel: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Per-partition sizes under tensor-model parallelism.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )
        # self.tp_size = get_tensor_model_parallel_world_size()
        # num_heads = self.num_heads_per_partition
        self.use_qkv_parallel = use_qkv_parallel
        if use_qkv_parallel:
            self.head_dim = embed_dim // num_heads
            self.qkv_proj = QKVParallelLinear(
                hidden_size=embed_dim,
                head_size=self.head_dim,
                total_num_heads=num_heads,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
        else:
            # Packed projection: output is [q | k | v] along the last dim.
            self.qkv_proj = ColumnParallelLinear(
                input_size=embed_dim,
                output_size=3 * projection_size,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
        self.proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Input shape: [b, s, embed_dim]
        Output shape: [b, s, num_heads * head_size]

        NOTE(review): the original docstring claimed output [s, b, ...],
        but both branches end with a [b, s, ...]-shaped tensor (see the
        final ``view``); confirm against callers.

        NOTE(review): ``cu_seqlens`` defaults to None but is dereferenced
        unconditionally below — it is effectively required.
        """
        bsz, s, _ = x.shape
        if self.use_qkv_parallel:
            # [b, s, embed_dim] --> [b, s, embed_dim]
            qkv, _ = self.qkv_proj(x)
            q, k, v = qkv.chunk(3, dim=-1)
            # [b, s, embed_dim] --> [b * s, num_heads, head_size]
            q, k, v = [
                x.reshape(
                    bsz * s, self.num_attention_heads_per_partition, -1
                ).contiguous()
                for x in (q, k, v)
            ]
        else:
            # [b, s, embed_dim] --> [s, b, embed_dim]
            x = rearrange(x, "b s ... -> s b ...")
            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
            qkv, _ = self.qkv_proj(x)
            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
            new_x_shape = qkv.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            qkv = qkv.view(*new_x_shape)
            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
            # [s, b, head, head_dim] --> [b, s, head, head_dim]
            q, k, v = [
                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
            ]
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
        if self.use_qkv_parallel:
            # Already in [b * s, head, head_dim] layout for the kernel.
            pass
        else:
            # [b, s, head, head_dim] --> [b * s, head, head_dim]
            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
        # [b * s, num_heads, head_size]
        output = torch.empty_like(q)
        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
        max_seqlen = seq_lens.max().item()
        context_attention_fwd(
            q,
            k,
            v,
            output,
            cu_seqlens.cuda(),
            seq_lens,
            max_seqlen,
            is_causal=False,
        )
        if self.use_qkv_parallel:
            # [b * s, head, head_dim] --> [b, s, head * head_dim]
            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
            # [b, s, head * head_dim] --> [b, s, embed_dim]
            output, _ = self.proj(output)
        else:
            # [b * s, head, head_dim] --> [b, s, head, head_dim]
            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
            # [s, b, num_heads * head_size]
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
            # [s, b, num_heads * head_size] --> [s, b, embed_dim]
            output, _ = self.proj(context_layer)
        # NOTE(review): in the non-parallel branch ``output`` is [s, b, E]
        # here; ``view`` reinterprets memory rather than transposing, which
        # matches the documented [b, s, ...] layout only when bsz == 1 —
        # confirm whether callers ever pass bsz > 1 on this path.
        output = output.view(bsz, s, -1)
        return output

View File

@@ -127,7 +127,7 @@ class LogitsProcessor(nn.Module):
hidden_states,
lm_head: VocabParallelEmbedding,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
):
) -> LogitsProcessorOutput:
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

View File

@@ -56,6 +56,7 @@ class DataParallelController:
def __init__(self, server_args, port_args) -> None:
# Parse args
self.max_total_num_tokens = None
self.server_args = server_args
self.port_args = port_args
self.load_balance_method = LoadBalanceMethod.from_str(
@@ -96,6 +97,8 @@ class DataParallelController:
True,
)
self.max_req_input_len = None
def launch_dp_schedulers(self, server_args, port_args):
base_gpu_id = 0
@@ -189,6 +192,7 @@ class DataParallelController:
scheduler_info.append(scheduler_pipe_readers[i].recv())
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
@@ -231,7 +235,11 @@ def run_data_parallel_controller_process(
try:
controller = DataParallelController(server_args, port_args)
pipe_writer.send(
{"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
{
"status": "ready",
"max_total_num_tokens": controller.max_total_num_tokens,
"max_req_input_len": controller.max_req_input_len,
}
)
if server_args.node_rank == 0:
controller.event_loop()

View File

@@ -9,6 +9,8 @@ from typing import List, Optional, Union
import numpy as np
import transformers
from decord import VideoReader, cpu
from PIL import Image
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.server_args = server_args
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
@@ -126,7 +129,12 @@ class LlavaImageProcessor(BaseImageProcessor):
)
async def process_images_async(
self, image_data: List[Union[str, bytes]], input_text, request_obj
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
@@ -229,6 +237,147 @@ class MllamaImageProcessor(BaseImageProcessor):
return image_inputs
class MiniCPMVImageProcessor(BaseImageProcessor):
    """Image/video preprocessor for MiniCPM-V.

    Expands each "(<image>./</image>)" placeholder in the prompt (one per
    image, or one per sampled frame for "video:<path>" inputs), runs the
    HuggingFace processor, and returns the tensors plus the special token
    ids the model uses to locate image regions in the input sequence.
    """

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_images_task(images, input_text):
        # Executed inside the process pool; ``global_processor`` is set up
        # by the pool initializer (see BaseImageProcessor.__init__).
        result = global_processor.__call__(
            text=input_text, images=images, return_tensors="pt"
        )
        return {
            "input_ids": result["input_ids"],
            "pixel_values": result["pixel_values"],
            "tgt_sizes": result["tgt_sizes"],
        }

    async def _process_images(self, images, input_text):
        """Run the HF processor, off-loading to the worker pool when present."""
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            image_inputs = await loop.run_in_executor(
                self.executor,
                MiniCPMVImageProcessor._process_images_task,
                images,
                input_text,
            )
        else:
            image_inputs = self._processor(
                images=images, text=input_text, return_tensors="pt"
            )
        return image_inputs

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
    ):
        """Build the multimodal inputs for one request.

        Args:
            image_data: image bytes/paths/URLs, or "video:<path>" entries.
            input_text: prompt text, or already-tokenized ids (list of int).
            request_obj: originating request; only ``modalities`` is read.
            max_req_input_len: token budget used to bound video frame count.

        Returns:
            A dict with input_ids, pixel_values, tgt_sizes, image_hashes,
            modalities, and the special image/slice token ids, or ``None``
            when there is nothing to process or an input file is missing.
        """
        if not image_data:
            return None
        if not isinstance(image_data, list):
            image_data = [image_data]

        image_hashes, image_sizes = [], []
        raw_images = []
        IMAGE_TOKEN = "(<image>./</image>)"

        # Roughly calculate the max number of frames that fit in the
        # remaining token budget.
        # TODO: the process should be applied to all the visual inputs
        def calculate_max_num_frames() -> int:
            # Model-specific: ~330 tokens per frame for MiniCPM-V.
            NUM_TOKEN_PER_FRAME = 330
            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
            return min(ret, 100)

        # If CUDA OOMs, set a smaller number.
        MAX_NUM_FRAMES = calculate_max_num_frames()
        # Fix: use the module logger instead of a bare print in server code.
        logger.debug("MAX_NUM_FRAMES: %s", MAX_NUM_FRAMES)

        def encode_video(video_path):
            # Decode a video into at most MAX_NUM_FRAMES PIL images,
            # sampling ~1 frame/second and uniformly thinning if needed.
            if not os.path.exists(video_path):
                logger.error(f"Video {video_path} does not exist")
                return []
            if MAX_NUM_FRAMES == 0:
                return []

            def uniform_sample(seq, n):
                # Pick n indices evenly spread across seq (center of each gap).
                gap = len(seq) / n
                idxs = [int(i * gap + gap / 2) for i in range(n)]
                return [seq[i] for i in idxs]

            vr = VideoReader(video_path, ctx=cpu(0))
            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
            frame_idx = [i for i in range(0, len(vr), sample_fps)]
            if len(frame_idx) > MAX_NUM_FRAMES:
                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
            frames = vr.get_batch(frame_idx).asnumpy()
            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
            return frames

        if isinstance(input_text, list):
            # Tokenized prompt: decode back to text so placeholders can be
            # located and expanded.
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

        # MiniCPMV requires each frame of a video to appear as its own
        # image token, so rebuild the prompt around the placeholders.
        text_parts = input_text.split(IMAGE_TOKEN)
        new_text_parts = []
        for image_index, image in enumerate(image_data):
            try:
                if isinstance(image, str) and image.startswith("video:"):
                    path = image[len("video:") :]
                    frames = encode_video(path)
                else:
                    raw_image, _size = load_image(image)
                    frames = [raw_image]
            except FileNotFoundError as e:
                logger.error(e)
                return None

            # Fix: keep the surrounding text even when a video yields no
            # frames; the original dropped this chunk of the prompt.
            new_text_parts.append(text_parts[image_index])
            if len(frames) == 0:
                continue

            # Fix: record one (w, h) tuple per frame; the original
            # ``frames[0].size * len(frames)`` flattened the tuple into
            # interleaved ints.
            image_sizes += [frames[0].size] * len(frames)
            image_hashes += [hash(image)] * len(frames)
            raw_images += frames
            new_text_parts.append(IMAGE_TOKEN * len(frames))

        new_text_parts.append(text_parts[-1])
        input_text = "".join(new_text_parts)

        if len(raw_images) == 0:
            return None

        res = await self._process_images(images=raw_images, input_text=input_text)
        pixel_values = res["pixel_values"]
        tgt_sizes = res["tgt_sizes"]
        input_ids = res["input_ids"]

        # Collect special token ids.
        tokenizer = self._processor.tokenizer
        im_start_id = [tokenizer.im_start_id]
        im_end_id = [tokenizer.im_end_id]
        # Fix: these were only bound when truthy but read unconditionally
        # below, raising UnboundLocalError for tokenizers without slice
        # tokens; default to None (ImageInputs fields default to None too).
        slice_start_id = (
            [tokenizer.slice_start_id]
            if getattr(tokenizer, "slice_start_id", None)
            else None
        )
        slice_end_id = (
            [tokenizer.slice_end_id]
            if getattr(tokenizer, "slice_end_id", None)
            else None
        )
        return {
            "input_ids": input_ids.flatten().tolist(),
            "pixel_values": pixel_values,
            "tgt_sizes": tgt_sizes,
            "image_hashes": image_hashes,
            "modalities": request_obj.modalities or ["image"],
            "im_start_id": im_start_id,
            "im_end_id": im_end_id,
            "slice_start_id": slice_start_id,
            "slice_end_id": slice_end_id,
        }
class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
@@ -289,7 +438,12 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
return self._process_single_image_task(image_data)
async def process_images_async(
self, image_data: List[Union[str, bytes]], input_text, request_obj
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
@@ -350,6 +504,8 @@ def get_image_processor(
return MllamaImageProcessor(hf_config, server_args, processor)
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
elif "MiniCPMV" in hf_config.architectures:
return MiniCPMVImageProcessor(hf_config, server_args, processor)
else:
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)

View File

@@ -52,7 +52,6 @@ from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
@@ -68,7 +67,6 @@ global_server_args_dict = {
"device": ServerArgs.device,
}
logger = logging.getLogger(__name__)
@@ -149,6 +147,16 @@ class ImageInputs:
image_grid_thws: List[Tuple[int, int, int]] = None
mrope_position_delta: Optional[torch.Tensor] = None
# MiniCPMV related
# All the images in the batch should share the same special image
# bound token ids.
im_start_id: Optional[torch.Tensor] = None
im_end_id: Optional[torch.Tensor] = None
slice_start_id: Optional[torch.Tensor] = None
slice_end_id: Optional[torch.Tensor] = None
tgt_sizes: Optional[list] = None
@staticmethod
def from_dict(obj: dict):
ret = ImageInputs(
@@ -168,6 +176,11 @@ class ImageInputs:
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"im_start_id",
"im_end_id",
"slice_start_id",
"slice_end_id",
"tgt_sizes",
]
for arg in optional_args:
if arg in obj:
@@ -1140,7 +1153,6 @@ class ScheduleBatch:
global bid
bid += 1
return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode,

View File

@@ -274,7 +274,6 @@ class Scheduler:
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
# Print debug info
logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, "
@@ -1729,7 +1728,11 @@ def run_scheduler_process(
try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
pipe_writer.send(
{"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
{
"status": "ready",
"max_total_num_tokens": scheduler.max_total_num_tokens,
"max_req_input_len": scheduler.max_req_input_len,
}
)
if scheduler.enable_overlap:
scheduler.event_loop_overlap()

View File

@@ -112,6 +112,7 @@ class TokenizerManager:
port_args: PortArgs,
):
# Parse args
self.server_args = server_args
self.enable_metrics = server_args.enable_metrics
self.log_requests = server_args.log_requests
@@ -207,6 +208,8 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
# Set after scheduler is initialized
self.max_req_input_len = None
# Metrics
if self.enable_metrics:
@@ -281,7 +284,7 @@ class TokenizerManager:
if self.is_generation:
# TODO: also support getting embeddings for multimodal models
image_inputs: Dict = await self.image_processor.process_images_async(
obj.image_data, input_text or input_ids, obj
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]

View File

@@ -237,7 +237,7 @@ class ModelRunner:
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
if not self.is_draft_worker:
# Only initilzie the distributed environment on the target model worker.
# Only initialize the distributed environment on the target model worker.
init_distributed_environment(
backend=backend,
world_size=self.tp_size,

File diff suppressed because it is too large Load Diff

View File

@@ -248,6 +248,9 @@ class Qwen2Model(nn.Module):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@@ -296,7 +299,6 @@ class Qwen2Model(nn.Module):
class Qwen2ForCausalLM(nn.Module):
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
@@ -334,6 +336,9 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@torch.no_grad()
def forward(
self,

View File

@@ -37,9 +37,7 @@ from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.distributed import utils as dist_utils
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
@@ -52,6 +50,7 @@ from sglang.srt.models.qwen2 import Qwen2Model
logger = logging.getLogger(__name__)
# === Vision Inputs === #
@@ -110,118 +109,6 @@ class Qwen2VisionMLP(nn.Module):
return x
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
def apply_rotary_emb_torch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
return output
class Qwen2VisionAttention(nn.Module):
def __init__(
self,
embed_dim: Optional[int] = None,
num_heads: Optional[int] = None,
projection_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads
)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, world_size
)
self.qkv = ColumnParallelLinear(
input_size=embed_dim,
output_size=3 * projection_size,
quant_config=quant_config,
)
self.proj = RowParallelLinear(
input_size=projection_size, output_size=embed_dim, quant_config=quant_config
)
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
new_x_shape = x.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
x = x.view(*new_x_shape)
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
batch_size = q.shape[1]
q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
max_seqlen = (seq_lens).max().item()
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
output = torch.empty_like(q)
context_attention_fwd(
q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
class Qwen2VisionBlock(nn.Module):
def __init__(
@@ -240,10 +127,11 @@ class Qwen2VisionBlock(nn.Module):
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.attn = Qwen2VisionAttention(
self.attn = VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
use_qkv_parallel=False,
quant_config=quant_config,
)
self.mlp = Qwen2VisionMLP(
@@ -253,9 +141,13 @@ class Qwen2VisionBlock(nn.Module):
def forward(
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
attn = self.attn(
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
)
attn = rearrange(attn, "b s ... -> s b ...")
x = x + attn
x = x + self.mlp(self.norm2(x))
return x
@@ -684,10 +576,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
@@ -696,6 +590,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
if "visual" in name and "qkv.weight" in name:
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
@@ -712,6 +607,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
if "visual" in name:
# adapt to VisionAttention
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:

View File

@@ -565,6 +565,7 @@ def launch_engine(
# Assume all schedulers have same scheduler_info
scheduler_info = scheduler_infos[0]
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
def launch_server(

View File

@@ -451,6 +451,8 @@ def load_image(image_file: Union[str, bytes]):
else:
raise ValueError(f"Invalid image: {image}")
# if image_size is None:
# image_size = image.size
return image, image_size

View File

@@ -406,7 +406,7 @@ def popen_launch_server(
base_url: str,
timeout: float,
api_key: Optional[str] = None,
other_args: tuple = (),
other_args: list[str] = (),
env: Optional[dict] = None,
return_stdout_stderr: Optional[tuple] = None,
):