[Fix] Address remaining issues of supporting MiniCPMV (#2977)
This commit is contained in:
@@ -166,6 +166,12 @@ def _fwd_kernel(
|
||||
def context_attention_fwd(
|
||||
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
|
||||
):
|
||||
"""
|
||||
q, k, v: [b * s, head, head_dim]
|
||||
b_start_loc: [b]
|
||||
b_seq_len: [b]
|
||||
out: [b * s, head, head_dim]
|
||||
"""
|
||||
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
|
||||
BLOCK = 128
|
||||
else:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from sglang.srt.distributed import parallel_state
|
||||
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T
|
||||
|
||||
|
||||
class VisionAttention(nn.Module):
|
||||
"""Multi-headed attention without any cache, mostly used for ViT."""
|
||||
r"""
|
||||
Multi-headed attention without any cache, mostly used for ViT.
|
||||
|
||||
|
||||
Args:
|
||||
use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
|
||||
use_context_forward (bool, default to True):
|
||||
if ``True``, a flash_attn style attention will be applied
|
||||
Otherwise, a full-sequence attention will be applied.
|
||||
use_full_precision_softmax (bool, default to False):
|
||||
if ``True``, the softmax will be performed in full-precision
|
||||
Otherwise, it will be performed in half-precision
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -72,25 +86,39 @@ class VisionAttention(nn.Module):
|
||||
projection_size: int,
|
||||
use_qkv_parallel: bool,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
dropout: float = 0.0,
|
||||
use_context_forward: bool = True,
|
||||
use_full_precision_softmax: bool = False,
|
||||
flatten_batch: bool = False,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.use_context_forward = use_context_forward
|
||||
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
|
||||
self.dropout = dropout
|
||||
self.head_size = embed_dim // num_heads
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
projection_size, num_heads
|
||||
)
|
||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||
num_heads, world_size
|
||||
)
|
||||
# self.tp_size = get_tensor_model_parallel_world_size()
|
||||
# num_heads = self.num_heads_per_partition
|
||||
|
||||
if self.use_context_forward:
|
||||
self.qkv_backend = VisionTritonAttention()
|
||||
else:
|
||||
self.qkv_backend = VisionSdpaAttention(
|
||||
head_size=self.head_size,
|
||||
dropout=dropout,
|
||||
flatten_batch=flatten_batch,
|
||||
use_full_precision_softmax=use_full_precision_softmax,
|
||||
)
|
||||
|
||||
self.use_qkv_parallel = use_qkv_parallel
|
||||
if use_qkv_parallel:
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=embed_dim,
|
||||
head_size=self.head_dim,
|
||||
head_size=self.head_size,
|
||||
total_num_heads=num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
@@ -114,12 +142,15 @@ class VisionAttention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
rotary_pos_emb: torch.Tensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
x: [b, s, embed_dim]
|
||||
cu_seqlens: [b]
|
||||
Returns:
|
||||
[s, b, num_heads * head]
|
||||
"""
|
||||
Input shape: [b, s, embed_dim]
|
||||
Output shape: [s, b, num_heads * head_size]
|
||||
"""
|
||||
|
||||
bsz, s, _ = x.shape
|
||||
if self.use_qkv_parallel:
|
||||
# [b, s, embed_dim] --> [b, s, embed_dim]
|
||||
@@ -136,19 +167,19 @@ class VisionAttention(nn.Module):
|
||||
else:
|
||||
# [b, s, embed_dim] --> [s, b, embed_dim]
|
||||
x = rearrange(x, "b s ... -> s b ...")
|
||||
# [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
|
||||
# [s, b, embed_dim] --> [s, b, head * 3 * head_size]
|
||||
qkv, _ = self.qkv_proj(x)
|
||||
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
|
||||
# [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
|
||||
new_x_shape = qkv.size()[:-1] + (
|
||||
self.num_attention_heads_per_partition,
|
||||
3 * self.hidden_size_per_attention_head,
|
||||
)
|
||||
qkv = qkv.view(*new_x_shape)
|
||||
|
||||
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
|
||||
# [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
|
||||
q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
|
||||
|
||||
# [s, b, head, head_dim] --> [b, s, head, head_dim]
|
||||
# [s, b, head, head_size] --> [b, s, head, head_size]
|
||||
q, k, v = [
|
||||
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
|
||||
]
|
||||
@@ -160,45 +191,217 @@ class VisionAttention(nn.Module):
|
||||
if self.use_qkv_parallel:
|
||||
pass
|
||||
else:
|
||||
# [b, s, head, head_dim] --> [b * s, head, head_dim]
|
||||
# [b, s, head, head_size] --> [b * s, head, head_size]
|
||||
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
||||
|
||||
# [b * s, num_heads, head_size]
|
||||
output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
|
||||
|
||||
if self.use_qkv_parallel:
|
||||
# [b * s, h, head_size] --> [b, s, h * head_size]
|
||||
output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
|
||||
|
||||
# [b, s, h * head_size] --> [b, s, h * head_size]
|
||||
output, _ = self.proj(output)
|
||||
else:
|
||||
# [b * s, h, head_size] --> [s, b, h * head_size]
|
||||
context_layer = rearrange(
|
||||
output, "(b s) h d -> s b (h d)", b=bsz, s=s
|
||||
).contiguous()
|
||||
|
||||
# [s, b, h * head_size] --> [s, b, h * head_size]
|
||||
output, _ = self.proj(context_layer)
|
||||
|
||||
# [s, b, h * head_size] --> [b, s, h * head_size]
|
||||
output = output.view(bsz, s, -1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class VisionSdpaAttention(nn.Module):
|
||||
r"""
|
||||
Scaled Dot Product Attention inner product
|
||||
|
||||
"""
|
||||
|
||||
# TODO: Should it be released after used?
|
||||
_mask_cache = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
dropout: float = 0.0,
|
||||
flatten_batch: bool = False,
|
||||
use_full_precision_softmax: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.flatten_batch = flatten_batch
|
||||
self.use_full_precision_softmax = use_full_precision_softmax
|
||||
self.dropout = dropout
|
||||
|
||||
def generate_patch_attention_mask(
|
||||
self,
|
||||
s: int,
|
||||
bsz: int,
|
||||
device,
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
flatten_batch: bool = False,
|
||||
dtype=torch.bfloat16,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
|
||||
|
||||
When `flatten_batch` is True:
|
||||
- All sequences in the batch are flattened into a single dimension
|
||||
- `s` represents the total number of tokens across all sequences in the batch
|
||||
- Returns a unified mask of shape `(1, 1, s, s)`
|
||||
|
||||
When `flatten_batch` is False:
|
||||
- Each sequence has its own attention mask
|
||||
- `s` represents the maximum sequence length in the batch
|
||||
- Returns separate masks of shape `(b, 1, s, s)`
|
||||
|
||||
Args:
|
||||
flatten_batch: (bool):
|
||||
If True, treats all sequences in the batch as a single flattened sequence
|
||||
If False, generates separate masks for each sequence
|
||||
|
||||
Returns:
|
||||
Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
|
||||
"""
|
||||
|
||||
cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
|
||||
|
||||
if cache_key in VisionSdpaAttention._mask_cache:
|
||||
cached_mask = VisionSdpaAttention._mask_cache[cache_key]
|
||||
# print(f"cache hit for key: {cache_key}")
|
||||
return cached_mask.to(device=device, dtype=dtype)
|
||||
|
||||
if cu_seqlens is None:
|
||||
raise ValueError("Internal Error: cu_seqlens cannot be None")
|
||||
|
||||
if flatten_batch:
|
||||
mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start = cu_seqlens[i - 1]
|
||||
end = cu_seqlens[i]
|
||||
mask[
|
||||
...,
|
||||
start:end,
|
||||
start:end,
|
||||
] = True
|
||||
else:
|
||||
# [1, 1, 1, s]
|
||||
row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
|
||||
# [1, 1, s, 1]
|
||||
col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
|
||||
# [b, 1, 1, 1]
|
||||
seq_lens = (
|
||||
(cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
|
||||
)
|
||||
|
||||
mask = (row_indices < seq_lens) & (col_indices < seq_lens)
|
||||
|
||||
# Convert to attention mask format (False -> 0, True -> -inf)
|
||||
mask = (~mask).to(dtype) * torch.finfo(dtype).min
|
||||
|
||||
VisionSdpaAttention._mask_cache[cache_key] = mask
|
||||
|
||||
return mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
bsz: int,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
cu_seqlens: [b]
|
||||
Returns:
|
||||
[b * s, h, head_size]
|
||||
"""
|
||||
|
||||
s = q.shape[0] // bsz
|
||||
|
||||
# [b, 1, s, s]
|
||||
if attention_mask is None:
|
||||
attention_mask = self.generate_patch_attention_mask(
|
||||
s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
|
||||
)
|
||||
q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
|
||||
# [b, 1, s]
|
||||
if self.use_full_precision_softmax:
|
||||
scale = self.head_size**-0.5
|
||||
k_transposed = rearrange(k, "b h s d -> b h d s")
|
||||
attn_weights = torch.matmul(q, k_transposed) * scale
|
||||
del k, k_transposed
|
||||
attn_weights = attn_weights + attention_mask
|
||||
del attention_mask
|
||||
# full-precision
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(q.dtype)
|
||||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=False
|
||||
)
|
||||
output = torch.matmul(attn_weights, v)
|
||||
del attn_weights, v
|
||||
else:
|
||||
# SDPA
|
||||
# [b, h, s, head_size]
|
||||
output = F.scaled_dot_product_attention(
|
||||
q, k, v, attention_mask, dropout_p=self.dropout
|
||||
)
|
||||
|
||||
# [b, h, s, head_size] --> [b * s, h, head_size]
|
||||
output = rearrange(output, "b h s d -> (b s) h d")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class VisionTritonAttention(nn.Module):
|
||||
"""
|
||||
Triton-implemented attention without a causal mask
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
_bsz: int,
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
cu_seqlens: [b]
|
||||
Returns:
|
||||
[b * s, h, head_size]
|
||||
"""
|
||||
|
||||
# [b * s, head, head_size]
|
||||
output = torch.empty_like(q)
|
||||
|
||||
seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
|
||||
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
max_seqlen = seq_lens.max().item()
|
||||
|
||||
context_attention_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
cu_seqlens.cuda(),
|
||||
seq_lens,
|
||||
seq_lens.cuda(),
|
||||
max_seqlen,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
if self.use_qkv_parallel:
|
||||
|
||||
# [b * s, head, head_dim] --> [b, s, head * head_dim]
|
||||
output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
|
||||
|
||||
# [b, s, head, head_dim] --> [b, s, head, head_dim]
|
||||
output, _ = self.proj(output)
|
||||
else:
|
||||
# [b * s, head, head_dim] --> [b, s, head, head_dim]
|
||||
context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
|
||||
# [s, b, num_heads * head_size]
|
||||
context_layer = rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
|
||||
# [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
|
||||
output, _ = self.proj(context_layer)
|
||||
|
||||
output = output.view(bsz, s, -1)
|
||||
|
||||
return output
|
||||
|
||||
@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor):
|
||||
class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "(<image>./</image>)"
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(images, input_text):
|
||||
@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
input_ids,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
):
|
||||
@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
image_data = [image_data]
|
||||
|
||||
image_hashes, image_sizes = [], []
|
||||
raw_images = []
|
||||
IMAGE_TOKEN = "(<image>./</image>)"
|
||||
all_frames = []
|
||||
|
||||
# roughly calculate the max number of frames
|
||||
# TODO: the process should be applied to all the visual inputs
|
||||
# roughly calculate the max number of frames under the max_req_input_len limit
|
||||
def calculate_max_num_frames() -> int:
|
||||
# Model-specific
|
||||
NUM_TOKEN_PER_FRAME = 330
|
||||
|
||||
ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
|
||||
ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
|
||||
return min(ret, 100)
|
||||
|
||||
# if cuda OOM set a smaller number
|
||||
MAX_NUM_FRAMES = calculate_max_num_frames()
|
||||
print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
|
||||
|
||||
def encode_video(video_path):
|
||||
# print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
|
||||
|
||||
def get_estimated_frames_list():
|
||||
"""
|
||||
estimate the total frame count from all visual input
|
||||
"""
|
||||
# Before processing inputs
|
||||
estimated_frames_list = []
|
||||
for image in image_data:
|
||||
if isinstance(image, str) and image.startswith("video:"):
|
||||
path = image[len("video:") :]
|
||||
# Estimate frames for the video
|
||||
vr = VideoReader(path, ctx=cpu(0))
|
||||
num_frames = len(vr)
|
||||
else:
|
||||
# For images, each contributes one frame
|
||||
num_frames = 1
|
||||
estimated_frames_list.append(num_frames)
|
||||
|
||||
return estimated_frames_list
|
||||
|
||||
estimated_frames_list = get_estimated_frames_list()
|
||||
total_frame_count = sum(estimated_frames_list)
|
||||
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
|
||||
|
||||
def encode_video(video_path, frame_count_limit=None):
|
||||
if not os.path.exists(video_path):
|
||||
logger.error(f"Video {video_path} does not exist")
|
||||
return []
|
||||
|
||||
if MAX_NUM_FRAMES == 0:
|
||||
if frame_count_limit == 0:
|
||||
return []
|
||||
|
||||
def uniform_sample(l, n):
|
||||
@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
vr = VideoReader(video_path, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||
if len(frame_idx) > MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
|
||||
if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
|
||||
frame_idx = uniform_sample(frame_idx, frame_count_limit)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
|
||||
return frames
|
||||
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
|
||||
if isinstance(input_ids, list):
|
||||
assert len(input_ids) and isinstance(input_ids[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_ids)
|
||||
else:
|
||||
input_text = input_ids
|
||||
# MiniCPMV requires each frame of video as a single image token
|
||||
text_parts = input_text.split(IMAGE_TOKEN)
|
||||
text_parts = input_text.split(self.IMAGE_TOKEN)
|
||||
new_text_parts = []
|
||||
|
||||
for image_index, image in enumerate(image_data):
|
||||
try:
|
||||
if isinstance(image, str) and image.startswith("video:"):
|
||||
path = image[len("video:") :]
|
||||
frames = encode_video(path)
|
||||
else:
|
||||
raw_image, size = load_image(image)
|
||||
frames = [raw_image]
|
||||
if len(frames) == 0:
|
||||
continue
|
||||
except FileNotFoundError as e:
|
||||
print(e)
|
||||
return None
|
||||
# Process each input with allocated frames
|
||||
for image_index, (image, estimated_frames) in enumerate(
|
||||
zip(image_data, estimated_frames_list)
|
||||
):
|
||||
if len(all_frames) >= MAX_NUM_FRAMES:
|
||||
frames_to_process = 0
|
||||
else:
|
||||
frames_to_process = max(1, int(estimated_frames * scaling_factor))
|
||||
|
||||
if frames_to_process == 0:
|
||||
frames = []
|
||||
else:
|
||||
try:
|
||||
if isinstance(image, str) and image.startswith("video:"):
|
||||
path = image[len("video:") :]
|
||||
frames = encode_video(path, frame_count_limit=frames_to_process)
|
||||
else:
|
||||
raw_image, _size = load_image(image)
|
||||
frames = [raw_image]
|
||||
if len(frames) == 0:
|
||||
continue
|
||||
except FileNotFoundError as e:
|
||||
print(e)
|
||||
return None
|
||||
image_sizes += frames[0].size * len(frames)
|
||||
image_hashes += [hash(image)] * len(frames)
|
||||
all_frames += frames
|
||||
|
||||
assert frames_to_process == len(frames)
|
||||
|
||||
image_sizes += frames[0].size * len(frames)
|
||||
image_hashes += [hash(image)] * len(frames)
|
||||
raw_images += frames
|
||||
new_text_parts.append(text_parts[image_index])
|
||||
new_text_parts.append(IMAGE_TOKEN * len(frames))
|
||||
|
||||
if frames_to_process != 0:
|
||||
new_text_parts.append(self.IMAGE_TOKEN * len(frames))
|
||||
|
||||
new_text_parts.append(text_parts[-1])
|
||||
|
||||
input_text = "".join(new_text_parts)
|
||||
if len(raw_images) == 0:
|
||||
|
||||
if len(all_frames) == 0:
|
||||
return None
|
||||
res = await self._process_images(images=raw_images, input_text=input_text)
|
||||
res = await self._process_images(images=all_frames, input_text=input_text)
|
||||
pixel_values = res["pixel_values"]
|
||||
tgt_sizes = res["tgt_sizes"]
|
||||
input_ids = res["input_ids"]
|
||||
@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
if tokenizer.slice_start_id:
|
||||
slice_start_id = [tokenizer.slice_start_id]
|
||||
slice_end_id = [tokenizer.slice_end_id]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.flatten().tolist(),
|
||||
"pixel_values": pixel_values,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 The SGLang team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
@@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
|
||||
from functools import cached_property, partial
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@@ -33,16 +33,13 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.types
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from torch.nn.init import trunc_normal_
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
|
||||
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.activation import get_act_fn
|
||||
@@ -63,6 +60,88 @@ from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
|
||||
RawImageType = Union[Image.Image, torch.Tensor]
|
||||
|
||||
|
||||
# sin/cos positional embedding helpers are adapted from:
|
||||
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||
def get_1d_sincos_pos_embed_from_grid(
|
||||
embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0)
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,) / (H, W)
|
||||
out: (M, D) / (H, W, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
if version == (2, 0):
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
else:
|
||||
out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
|
||||
emb_sin = np.sin(out) # (H, W, D/2)
|
||||
emb_cos = np.cos(out) # (H, W, D/2)
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(
|
||||
embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0)
|
||||
) -> torch.Tensor:
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(
|
||||
embed_dim // 2, grid[0], version
|
||||
) # (H*W, D/2) or (H, W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(
|
||||
embed_dim // 2, grid[1], version
|
||||
) # (H*W, D/2) or (H, W, D/2)
|
||||
|
||||
if version == (2, 0):
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
else:
|
||||
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim: int,
|
||||
grid_size: Union[int, Tuple[int, int]],
|
||||
cls_token: bool = False,
|
||||
version: Tuple[int, int] = (2, 0),
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or
|
||||
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
if isinstance(grid_size, int):
|
||||
grid_h_size, grid_w_size = grid_size, grid_size
|
||||
else:
|
||||
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
|
||||
|
||||
grid_h = np.arange(grid_h_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_w_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size)
|
||||
|
||||
if version == (2, 0):
|
||||
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
|
||||
if cls_token:
|
||||
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
||||
else:
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
|
||||
return pos_embed
|
||||
|
||||
|
||||
class Idefics2VisionMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -116,6 +195,10 @@ class Idefics2EncoderLayer(nn.Module):
|
||||
projection_size=config.intermediate_size,
|
||||
use_qkv_parallel=True,
|
||||
quant_config=quant_config,
|
||||
dropout=config.attention_dropout,
|
||||
use_context_forward=False,
|
||||
use_full_precision_softmax=True,
|
||||
flatten_batch=False,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
@@ -126,7 +209,6 @@ class Idefics2EncoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@@ -136,11 +218,8 @@ class Idefics2EncoderLayer(nn.Module):
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
# , forward_batch=forward_batch
|
||||
)
|
||||
hidden_states = self.self_attn(hidden_states, cu_seqlens=cu_seqlens)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
@@ -181,7 +260,6 @@ class Idefics2Encoder(nn.Module):
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
@@ -195,7 +273,8 @@ class Idefics2Encoder(nn.Module):
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
hidden_states = layer_outputs
|
||||
return hidden_states
|
||||
@@ -232,19 +311,14 @@ class Idefics2VisionEmbeddings(nn.Module):
|
||||
self.num_positions = self.num_patches
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
|
||||
def forward(
|
||||
def get_position_ids(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
patch_attention_mask: torch.BoolTensor,
|
||||
tgt_sizes: Optional[torch.IntTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
pixel_values = pixel_values.to(
|
||||
device=self.patch_embedding.weight.device, dtype=target_dtype
|
||||
)
|
||||
patch_embeds = self.patch_embedding(pixel_values)
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
max_nb_patches_h, max_nb_patches_w = (
|
||||
max_im_h // self.patch_size,
|
||||
max_im_w // self.patch_size,
|
||||
@@ -277,6 +351,24 @@ class Idefics2VisionEmbeddings(nn.Module):
|
||||
).flatten()
|
||||
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
||||
position_ids = position_ids.to(self.position_embedding.weight.device)
|
||||
return position_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
patch_attention_mask: torch.BoolTensor,
|
||||
tgt_sizes: Optional[torch.IntTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
pixel_values = pixel_values.to(
|
||||
device=self.patch_embedding.weight.device, dtype=target_dtype
|
||||
)
|
||||
patch_embeds = self.patch_embedding(pixel_values)
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
position_ids = self.get_position_ids(
|
||||
pixel_values, patch_attention_mask, tgt_sizes
|
||||
)
|
||||
|
||||
embeddings = embeddings + self.position_embedding(position_ids)
|
||||
return embeddings
|
||||
|
||||
@@ -287,7 +379,6 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -302,8 +393,6 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
|
||||
def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
|
||||
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,)
|
||||
|
||||
# 做 prefix sum 来得到 cu_seqlens,注意在最前面插一个 0 作为 offset
|
||||
cu_seqlens = torch.cat(
|
||||
[
|
||||
torch.tensor([0], device=patch_len.device, dtype=torch.int32),
|
||||
@@ -316,19 +405,18 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
forward_batch: ForwardBatch,
|
||||
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
tgt_sizes: Optional[torch.IntTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embeddings(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
# forward_batch=forward_batch,
|
||||
tgt_sizes=tgt_sizes,
|
||||
)
|
||||
cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||
return last_hidden_state
|
||||
@@ -573,14 +661,12 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
# multimodal_config = config.model_config.multimodal_config
|
||||
super().__init__()
|
||||
# All MiniCPM-V models disable `tie_word_embeddings` but
|
||||
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
|
||||
# check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
|
||||
# check `tie_word_embeddings` until SGLang integrate MiniCPM-V model
|
||||
# and config class
|
||||
self.config = config
|
||||
# self.multimodal_config = multimodal_config
|
||||
|
||||
self.version = get_version_by_config(self.config)
|
||||
self.llm = self.init_llm(config=config, quant_config=quant_config)
|
||||
@@ -598,13 +684,6 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@cached_property
|
||||
def sampler(self):
|
||||
if hasattr(self.llm, "sampler"):
|
||||
return self.llm.sampler
|
||||
|
||||
return get_sampler()
|
||||
|
||||
def _get_image_bounds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -666,7 +745,6 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
image_inputs: Optional[MiniCPMVImageInputs],
|
||||
forward_batch: ForwardBatch,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
|
||||
|
||||
@@ -680,10 +758,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
.to(vlm_embedding.device)
|
||||
)
|
||||
else:
|
||||
vision_hidden_states = self.get_vision_hidden_states(
|
||||
forward_batch, image_inputs
|
||||
)
|
||||
|
||||
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
|
||||
# See NOTE in _parse_and_validate_inputs
|
||||
image_bounds = image_inputs["image_bounds"]
|
||||
if len(image_bounds) > 0:
|
||||
@@ -693,6 +768,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
for start, end in image_bounds.tolist()
|
||||
]
|
||||
).to(vlm_embedding.device)
|
||||
|
||||
vlm_embedding.scatter_(
|
||||
0,
|
||||
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
|
||||
@@ -839,7 +915,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||
|
||||
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch)
|
||||
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
|
||||
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
@@ -857,29 +933,6 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
input_ids, hidden_states, self.llm.lm_head, forward_batch
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.llm.compute_logits(hidden_states, sampling_metadata)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="llm", connector="resampler", tower_model="vpm"
|
||||
)
|
||||
|
||||
def init_llm(
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
@@ -910,9 +963,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_vision_hidden_states(
|
||||
self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs
|
||||
) -> torch.Tensor:
|
||||
def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1019,7 +1070,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
|
||||
def get_vision_hidden_states(
|
||||
self,
|
||||
forward_batch: ForwardBatch,
|
||||
data: MiniCPMVImageInputs,
|
||||
) -> torch.Tensor:
|
||||
pixel_values = data["data"]
|
||||
@@ -1042,15 +1092,18 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
patch_attn_mask = torch.zeros(
|
||||
(B, 1, max_patches), dtype=torch.bool, device=device
|
||||
)
|
||||
for i in range(B):
|
||||
patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
|
||||
|
||||
tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
|
||||
mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
|
||||
patch_attn_mask[:, 0, :] = torch.arange(
|
||||
patch_attn_mask.size(2), device=patch_attn_mask.device
|
||||
).unsqueeze(0) < mask_shapes.unsqueeze(1)
|
||||
|
||||
vision_embedding = self.vpm(
|
||||
all_pixel_values.type(dtype),
|
||||
forward_batch=forward_batch,
|
||||
patch_attention_mask=patch_attn_mask,
|
||||
tgt_sizes=tgt_sizes,
|
||||
)
|
||||
|
||||
return self.resampler(vision_embedding, tgt_sizes)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
@@ -1138,7 +1191,7 @@ class MiniCPMV:
|
||||
"""
|
||||
Different versions of MiniCPMV use different visual encoders and LLMs,
|
||||
which is not conducive to the current integration logic of LoRA and
|
||||
bitsandbytes in vLLM. Therefore, it is necessary to separate them.
|
||||
bitsandbytes in SGLang. Therefore, it is necessary to separate them.
|
||||
"""
|
||||
|
||||
# Ensure that the LoRA support check passes when the class is not
|
||||
|
||||
@@ -17,6 +17,7 @@ from transformers.models.mllama.modeling_mllama import (
|
||||
import sglang.srt.distributed.parallel_state as ps
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.activation import get_act_fn
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@@ -145,61 +146,6 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):
|
||||
return hidden_state
|
||||
|
||||
|
||||
class MllamaVisionSdpaAttention(nn.Module):
|
||||
def __init__(self, config: config_mllama.MllamaVisionConfig):
|
||||
super().__init__()
|
||||
|
||||
model_parallel_size = get_tensor_model_parallel_world_size()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.attention_heads
|
||||
self.head_dim = config.hidden_size // config.attention_heads
|
||||
self.num_local_heads = self.num_heads // model_parallel_size
|
||||
self.q_size = self.num_local_heads * self.head_dim
|
||||
self.kv_size = self.num_local_heads * self.head_dim
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.head_dim,
|
||||
self.embed_dim,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_state)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q = q.view(
|
||||
q.shape[0], q.shape[1], self.num_local_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
k = k.view(
|
||||
k.shape[0], k.shape[1], self.num_local_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
v = v.view(
|
||||
v.shape[0], v.shape[1], self.num_local_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
# TODO: remove padding in image encoder
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attention_mask, dropout_p=0.0
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(
|
||||
attn_output.shape[0], attn_output.shape[1], -1
|
||||
)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class MllamaVisionMLP(nn.Module):
|
||||
def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
@@ -237,7 +183,17 @@ class MllamaVisionEncoderLayer(nn.Module):
|
||||
self.is_gated = is_gated
|
||||
self.intermediate_size = config.intermediate_size
|
||||
|
||||
self.self_attn = MllamaVisionSdpaAttention(config)
|
||||
self.self_attn = VisionAttention(
|
||||
self.hidden_size,
|
||||
self.num_attention_heads,
|
||||
self.hidden_size,
|
||||
use_qkv_parallel=True,
|
||||
quant_config=None,
|
||||
dropout=0.0,
|
||||
use_context_forward=False,
|
||||
use_full_precision_softmax=False,
|
||||
flatten_batch=False,
|
||||
)
|
||||
self.mlp = MllamaVisionMLP(config)
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
|
||||
@@ -992,6 +948,10 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if "vision_model" in name:
|
||||
# adapt to VisionAttention
|
||||
name = name.replace("self_attn.o_proj", "self_attn.proj")
|
||||
|
||||
param = params_dict.pop(name)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@@ -249,7 +249,10 @@ class Qwen2Model(nn.Module):
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
if hasattr(self.config, "scale_emb"):
|
||||
return self.embed_tokens(input_ids) * self.config.scale_emb
|
||||
else:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -30,12 +30,10 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from einops import rearrange
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
|
||||
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
||||
from sglang.srt.distributed import parallel_state
|
||||
from sglang.srt.distributed import utils as dist_utils
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
@@ -118,6 +116,7 @@ class Qwen2VisionBlock(nn.Module):
|
||||
mlp_ratio: float,
|
||||
act_layer: Type[nn.Module] = QuickGELU,
|
||||
norm_layer: Type[nn.Module] = None,
|
||||
attn_implementation: Optional[str] = "sdpa",
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -126,12 +125,24 @@ class Qwen2VisionBlock(nn.Module):
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
if attn_implementation == "sdpa":
|
||||
use_context_forward = False
|
||||
use_full_precision_softmax = False
|
||||
elif attn_implementation == "flash_attention_2":
|
||||
use_full_precision_softmax = False
|
||||
use_context_forward = True
|
||||
elif attn_implementation == "eager":
|
||||
use_full_precision_softmax = True
|
||||
use_context_forward = False
|
||||
|
||||
self.attn = VisionAttention(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
use_qkv_parallel=False,
|
||||
use_context_forward=use_context_forward,
|
||||
use_full_precision_softmax=use_full_precision_softmax,
|
||||
flatten_batch=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mlp = Qwen2VisionMLP(
|
||||
@@ -286,7 +297,6 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
||||
head_dim = embed_dim // num_heads
|
||||
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen2VisionBlock(
|
||||
@@ -294,6 +304,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
norm_layer=norm_layer,
|
||||
attn_implementation="sdpa",
|
||||
quant_config=quant_config,
|
||||
)
|
||||
for _ in range(depth)
|
||||
@@ -482,10 +493,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
opensource models), the shape will be `(3, seq_len)`,
|
||||
otherwise it will be `(seq_len,).
|
||||
(Use input_metadata.mrope_positions to replace it)
|
||||
pixel_values: Pixel values to be fed to a model.
|
||||
`None` if no images are passed.
|
||||
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
||||
`None` if no images are passed.
|
||||
"""
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
positions = forward_batch.mrope_positions
|
||||
@@ -540,15 +547,18 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
num_image_tokens = self.calculate_num_image_tokens(
|
||||
image_grid_thws[idx]
|
||||
)
|
||||
|
||||
left_idx = start_idx + (image_offset - prefix_len)
|
||||
right_idx = (
|
||||
start_idx + (image_offset - prefix_len) + num_image_tokens
|
||||
)
|
||||
|
||||
inputs_embeds[left_idx:right_idx] = image_embeds[
|
||||
image_embeds_offset : image_embeds_offset + num_image_tokens
|
||||
]
|
||||
image_embeds_offset += num_image_tokens
|
||||
|
||||
input_ids = None
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
|
||||
@@ -444,8 +444,6 @@ def load_image(image_file: Union[str, bytes]):
|
||||
else:
|
||||
raise ValueError(f"Invalid image: {image}")
|
||||
|
||||
# if image_size is None:
|
||||
# image_size = image.size
|
||||
return image, image_size
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user