[Fix] Address remaining issues of supporting MiniCPMV (#2977)
This commit is contained in:
@@ -78,6 +78,7 @@ Another valuable resource is the [vLLM Models Directory](https://github.com/vllm
|
|||||||
To port a model from vLLM to SGLang, you can compare these two files [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of Attention with RadixAttention. The other parts are almost identical. Specifically,
|
To port a model from vLLM to SGLang, you can compare these two files [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of Attention with RadixAttention. The other parts are almost identical. Specifically,
|
||||||
- Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`.
|
- Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`.
|
||||||
- Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
|
- Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
|
||||||
|
- Replace the multi-headed `Attention` of the ViT with SGLang's `VisionAttention`.
|
||||||
- Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`).
|
- Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`).
|
||||||
- Remove `Sample`.
|
- Remove `Sample`.
|
||||||
- Change `forward()` functions, and add `forward_batch`.
|
- Change `forward()` functions, and add `forward_batch`.
|
||||||
|
|||||||
@@ -166,6 +166,12 @@ def _fwd_kernel(
|
|||||||
def context_attention_fwd(
|
def context_attention_fwd(
|
||||||
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
|
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
q, k, v: [b * s, head, head_dim]
|
||||||
|
b_start_loc: [b]
|
||||||
|
b_seq_len: [b]
|
||||||
|
out: [b * s, head, head_dim]
|
||||||
|
"""
|
||||||
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
|
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
|
||||||
BLOCK = 128
|
BLOCK = 128
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
from sglang.srt.distributed import parallel_state
|
from sglang.srt.distributed import parallel_state
|
||||||
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T
|
|||||||
|
|
||||||
|
|
||||||
class VisionAttention(nn.Module):
|
class VisionAttention(nn.Module):
|
||||||
"""Multi-headed attention without any cache, mostly used for ViT."""
|
r"""
|
||||||
|
Multi-headed attention without any cache, mostly used for ViT.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
|
||||||
|
use_context_forward (bool, default to True):
|
||||||
|
if ``True``, a flash_attn style attention will be applied
|
||||||
|
Otherwise, a full-sequence attention will be applied.
|
||||||
|
use_full_precision_softmax (bool, default to False):
|
||||||
|
if ``True``, the softmax will be performed in full-precision
|
||||||
|
Otherwise, it will be performed in half-precision
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -72,25 +86,39 @@ class VisionAttention(nn.Module):
|
|||||||
projection_size: int,
|
projection_size: int,
|
||||||
use_qkv_parallel: bool,
|
use_qkv_parallel: bool,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
use_context_forward: bool = True,
|
||||||
|
use_full_precision_softmax: bool = False,
|
||||||
|
flatten_batch: bool = False,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.use_context_forward = use_context_forward
|
||||||
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||||
|
self.dropout = dropout
|
||||||
|
self.head_size = embed_dim // num_heads
|
||||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||||
projection_size, num_heads
|
projection_size, num_heads
|
||||||
)
|
)
|
||||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||||
num_heads, world_size
|
num_heads, world_size
|
||||||
)
|
)
|
||||||
# self.tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
# num_heads = self.num_heads_per_partition
|
if self.use_context_forward:
|
||||||
|
self.qkv_backend = VisionTritonAttention()
|
||||||
|
else:
|
||||||
|
self.qkv_backend = VisionSdpaAttention(
|
||||||
|
head_size=self.head_size,
|
||||||
|
dropout=dropout,
|
||||||
|
flatten_batch=flatten_batch,
|
||||||
|
use_full_precision_softmax=use_full_precision_softmax,
|
||||||
|
)
|
||||||
|
|
||||||
self.use_qkv_parallel = use_qkv_parallel
|
self.use_qkv_parallel = use_qkv_parallel
|
||||||
if use_qkv_parallel:
|
if use_qkv_parallel:
|
||||||
self.head_dim = embed_dim // num_heads
|
|
||||||
self.qkv_proj = QKVParallelLinear(
|
self.qkv_proj = QKVParallelLinear(
|
||||||
hidden_size=embed_dim,
|
hidden_size=embed_dim,
|
||||||
head_size=self.head_dim,
|
head_size=self.head_size,
|
||||||
total_num_heads=num_heads,
|
total_num_heads=num_heads,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.qkv_proj",
|
prefix=f"{prefix}.qkv_proj",
|
||||||
@@ -114,12 +142,15 @@ class VisionAttention(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
rotary_pos_emb: torch.Tensor = None,
|
rotary_pos_emb: torch.Tensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
x: [b, s, embed_dim]
|
||||||
|
cu_seqlens: [b]
|
||||||
|
Returns:
|
||||||
|
[s, b, num_heads * head]
|
||||||
"""
|
"""
|
||||||
Input shape: [b, s, embed_dim]
|
|
||||||
Output shape: [s, b, num_heads * head_size]
|
|
||||||
"""
|
|
||||||
|
|
||||||
bsz, s, _ = x.shape
|
bsz, s, _ = x.shape
|
||||||
if self.use_qkv_parallel:
|
if self.use_qkv_parallel:
|
||||||
# [b, s, embed_dim] --> [b, s, embed_dim]
|
# [b, s, embed_dim] --> [b, s, embed_dim]
|
||||||
@@ -136,19 +167,19 @@ class VisionAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# [b, s, embed_dim] --> [s, b, embed_dim]
|
# [b, s, embed_dim] --> [s, b, embed_dim]
|
||||||
x = rearrange(x, "b s ... -> s b ...")
|
x = rearrange(x, "b s ... -> s b ...")
|
||||||
# [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
|
# [s, b, embed_dim] --> [s, b, head * 3 * head_size]
|
||||||
qkv, _ = self.qkv_proj(x)
|
qkv, _ = self.qkv_proj(x)
|
||||||
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
|
# [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
|
||||||
new_x_shape = qkv.size()[:-1] + (
|
new_x_shape = qkv.size()[:-1] + (
|
||||||
self.num_attention_heads_per_partition,
|
self.num_attention_heads_per_partition,
|
||||||
3 * self.hidden_size_per_attention_head,
|
3 * self.hidden_size_per_attention_head,
|
||||||
)
|
)
|
||||||
qkv = qkv.view(*new_x_shape)
|
qkv = qkv.view(*new_x_shape)
|
||||||
|
|
||||||
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
|
# [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
|
||||||
q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
|
q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
|
||||||
|
|
||||||
# [s, b, head, head_dim] --> [b, s, head, head_dim]
|
# [s, b, head, head_size] --> [b, s, head, head_size]
|
||||||
q, k, v = [
|
q, k, v = [
|
||||||
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
|
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
|
||||||
]
|
]
|
||||||
@@ -160,45 +191,217 @@ class VisionAttention(nn.Module):
|
|||||||
if self.use_qkv_parallel:
|
if self.use_qkv_parallel:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# [b, s, head, head_dim] --> [b * s, head, head_dim]
|
# [b, s, head, head_size] --> [b * s, head, head_size]
|
||||||
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
||||||
|
|
||||||
# [b * s, num_heads, head_size]
|
output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
|
||||||
|
|
||||||
|
if self.use_qkv_parallel:
|
||||||
|
# [b * s, h, head_size] --> [b, s, h * head_size]
|
||||||
|
output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
|
||||||
|
|
||||||
|
# [b, s, h * head_size] --> [b, s, h * head_size]
|
||||||
|
output, _ = self.proj(output)
|
||||||
|
else:
|
||||||
|
# [b * s, h, head_size] --> [s, b, h * head_size]
|
||||||
|
context_layer = rearrange(
|
||||||
|
output, "(b s) h d -> s b (h d)", b=bsz, s=s
|
||||||
|
).contiguous()
|
||||||
|
|
||||||
|
# [s, b, h * head_size] --> [s, b, h * head_size]
|
||||||
|
output, _ = self.proj(context_layer)
|
||||||
|
|
||||||
|
# [s, b, h * head_size] --> [b, s, h * head_size]
|
||||||
|
output = output.view(bsz, s, -1)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class VisionSdpaAttention(nn.Module):
    r"""
    Scaled-dot-product attention backend for vision encoders (no KV cache).

    Args:
        head_size (int): per-head hidden size; used to derive the softmax
            scale on the full-precision path.
        dropout (float, default 0.0): attention dropout probability.
        flatten_batch (bool, default False): forwarded to
            :meth:`generate_patch_attention_mask` when no mask is supplied.
        use_full_precision_softmax (bool, default False):
            if ``True``, attention is computed manually with a float32
            softmax; otherwise ``F.scaled_dot_product_attention`` is used.
    """

    # Class-level cache of generated masks, shared by every instance.
    # TODO: Should it be released after used? (the cache is unbounded)
    _mask_cache = {}

    def __init__(
        self,
        head_size: int,
        dropout: float = 0.0,
        flatten_batch: bool = False,
        use_full_precision_softmax: bool = False,
    ):
        super().__init__()
        self.head_size = head_size
        self.flatten_batch = flatten_batch
        self.use_full_precision_softmax = use_full_precision_softmax
        self.dropout = dropout

    def generate_patch_attention_mask(
        self,
        s: int,
        bsz: int,
        device,
        cu_seqlens: Optional[torch.Tensor],
        flatten_batch: bool = False,
        dtype=torch.bfloat16,
    ) -> torch.Tensor:
        r"""
        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.

        When `flatten_batch` is True:
            - All sequences in the batch are flattened into a single dimension
            - `s` represents the total number of tokens across all sequences in the batch
            - Returns a unified mask of shape `(1, 1, s, s)`

        When `flatten_batch` is False:
            - Each sequence has its own attention mask
            - `s` represents the maximum sequence length in the batch
            - Returns separate masks of shape `(b, 1, s, s)`

        Args:
            s: sequence length covered by the mask (see above).
            bsz: batch size; only used as part of the cache key.
            device: device the returned mask is placed on.
            cu_seqlens: cumulative sequence boundaries; consecutive entries
                delimit one sequence. Must not be None.
            flatten_batch: (bool):
                If True, treats all sequences in the batch as a single flattened sequence
                If False, generates separate masks for each sequence
            dtype: floating dtype of the returned additive mask.

        Returns:
            Additive mask tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`:
            0 where attention is allowed, `torch.finfo(dtype).min` elsewhere.

        Raises:
            ValueError: if `cu_seqlens` is None.
        """
        # Validate before touching cu_seqlens; building the cache key first
        # would fail with an AttributeError and make this check unreachable.
        if cu_seqlens is None:
            raise ValueError("Internal Error: cu_seqlens cannot be None")

        boundaries = cu_seqlens.cpu().tolist()
        # dtype is part of the key: the mask stores finfo(dtype).min, and
        # casting a wider dtype's min to a narrower one overflows to -inf,
        # which yields NaN after softmax on fully masked rows.
        cache_key = (s, bsz, flatten_batch, dtype, tuple(boundaries))

        cached_mask = VisionSdpaAttention._mask_cache.get(cache_key)
        if cached_mask is not None:
            return cached_mask.to(device=device, dtype=dtype)

        if flatten_batch:
            # Block-diagonal mask: tokens only attend within their own sequence.
            mask = torch.zeros([1, 1, s, s], device=device, dtype=torch.bool)
            for start, end in zip(boundaries[:-1], boundaries[1:]):
                mask[..., start:end, start:end] = True
        else:
            # [1, 1, 1, s]
            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
            # [1, 1, s, 1]
            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
            # [b, 1, 1, 1]
            seq_lens = (
                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
            )

            mask = (row_indices < seq_lens) & (col_indices < seq_lens)

        # Convert to additive format (allowed -> 0, masked -> finfo(dtype).min)
        mask = (~mask).to(dtype) * torch.finfo(dtype).min

        VisionSdpaAttention._mask_cache[cache_key] = mask

        return mask

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz: int,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            q, k, v: [b * s, h, head_size]
            bsz: batch size `b` used to un-flatten the leading dimension.
            cu_seqlens: cumulative sequence boundaries; required when
                `attention_mask` is None.
            attention_mask: optional precomputed additive mask.
        Returns:
             [b * s, h, head_size]
        """

        s = q.shape[0] // bsz

        # [b, 1, s, s]
        if attention_mask is None:
            attention_mask = self.generate_patch_attention_mask(
                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
            )
        q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]

        if self.use_full_precision_softmax:
            scale = self.head_size**-0.5
            k_transposed = rearrange(k, "b h s d -> b h d s")
            attn_weights = torch.matmul(q, k_transposed) * scale
            del k, k_transposed
            attn_weights = attn_weights + attention_mask
            del attention_mask
            # full-precision
            attn_weights = nn.functional.softmax(
                attn_weights, dim=-1, dtype=torch.float32
            ).to(q.dtype)
            attn_weights = nn.functional.dropout(
                attn_weights, p=self.dropout, training=False
            )
            output = torch.matmul(attn_weights, v)
            del attn_weights, v
        else:
            # SDPA
            # [b, h, s, head_size]
            # scaled_dot_product_attention applies dropout whenever
            # dropout_p > 0, regardless of module mode, so disable it
            # explicitly outside of training.
            output = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attention_mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        # [b, h, s, head_size] --> [b * s, h, head_size]
        output = rearrange(output, "b h s d -> (b s) h d")

        return output
|
||||||
|
|
||||||
|
|
||||||
|
class VisionTritonAttention(nn.Module):
    """
    Triton-implemented attention without a causal mask
    """

    def __init__(
        self,
    ):
        super().__init__()

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        _bsz: int,
        cu_seqlens: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            q, k, v: [b * s, head, head_size]
            _bsz: unused; kept so both attention backends share a call signature.
            cu_seqlens: cumulative sequence boundaries; must not be None.
            attention_mask: accepted (and ignored) so that callers can pass
                the same positional arguments as for the SDPA backend.
        Returns:
             [b * s, h, head_size]
        Raises:
            ValueError: if `cu_seqlens` is None.
        """
        if cu_seqlens is None:
            raise ValueError("Internal Error: cu_seqlens cannot be None")

        # [b * s, head, head_size]
        output = torch.empty_like(q)
        # Keep the metadata on the same device as q so the kernel sees
        # every tensor on one device.
        cu_seqlens = cu_seqlens.to(q.device)
        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seq_lens.max().item()
        context_attention_fwd(
            q,
            k,
            v,
            output,
            cu_seqlens,
            seq_lens,
            max_seqlen,
            is_causal=False,
        )

        return output
|
||||||
|
|||||||
@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor):
|
|||||||
class MiniCPMVImageProcessor(BaseImageProcessor):
|
class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
self.IMAGE_TOKEN = "(<image>./</image>)"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_images_task(images, input_text):
|
def _process_images_task(images, input_text):
|
||||||
@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
|||||||
async def process_images_async(
|
async def process_images_async(
|
||||||
self,
|
self,
|
||||||
image_data: List[Union[str, bytes]],
|
image_data: List[Union[str, bytes]],
|
||||||
input_text,
|
input_ids,
|
||||||
request_obj,
|
request_obj,
|
||||||
max_req_input_len,
|
max_req_input_len,
|
||||||
):
|
):
|
||||||
@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
|||||||
image_data = [image_data]
|
image_data = [image_data]
|
||||||
|
|
||||||
image_hashes, image_sizes = [], []
|
image_hashes, image_sizes = [], []
|
||||||
raw_images = []
|
all_frames = []
|
||||||
IMAGE_TOKEN = "(<image>./</image>)"
|
|
||||||
|
|
||||||
# roughly calculate the max number of frames
|
# roughly calculate the max number of frames under the max_req_input_len limit
|
||||||
# TODO: the process should be applied to all the visual inputs
|
|
||||||
def calculate_max_num_frames() -> int:
|
def calculate_max_num_frames() -> int:
|
||||||
# Model-specific
|
# Model-specific
|
||||||
NUM_TOKEN_PER_FRAME = 330
|
NUM_TOKEN_PER_FRAME = 330
|
||||||
|
|
||||||
ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
|
ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
|
||||||
return min(ret, 100)
|
return min(ret, 100)
|
||||||
|
|
||||||
# if cuda OOM set a smaller number
|
|
||||||
MAX_NUM_FRAMES = calculate_max_num_frames()
|
MAX_NUM_FRAMES = calculate_max_num_frames()
|
||||||
print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
|
|
||||||
|
|
||||||
def encode_video(video_path):
|
# print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
|
||||||
|
|
||||||
|
def get_estimated_frames_list():
|
||||||
|
"""
|
||||||
|
estimate the total frame count from all visual input
|
||||||
|
"""
|
||||||
|
# Before processing inputs
|
||||||
|
estimated_frames_list = []
|
||||||
|
for image in image_data:
|
||||||
|
if isinstance(image, str) and image.startswith("video:"):
|
||||||
|
path = image[len("video:") :]
|
||||||
|
# Estimate frames for the video
|
||||||
|
vr = VideoReader(path, ctx=cpu(0))
|
||||||
|
num_frames = len(vr)
|
||||||
|
else:
|
||||||
|
# For images, each contributes one frame
|
||||||
|
num_frames = 1
|
||||||
|
estimated_frames_list.append(num_frames)
|
||||||
|
|
||||||
|
return estimated_frames_list
|
||||||
|
|
||||||
|
estimated_frames_list = get_estimated_frames_list()
|
||||||
|
total_frame_count = sum(estimated_frames_list)
|
||||||
|
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
|
||||||
|
|
||||||
|
def encode_video(video_path, frame_count_limit=None):
|
||||||
if not os.path.exists(video_path):
|
if not os.path.exists(video_path):
|
||||||
logger.error(f"Video {video_path} does not exist")
|
logger.error(f"Video {video_path} does not exist")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if MAX_NUM_FRAMES == 0:
|
if frame_count_limit == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def uniform_sample(l, n):
|
def uniform_sample(l, n):
|
||||||
@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
|||||||
vr = VideoReader(video_path, ctx=cpu(0))
|
vr = VideoReader(video_path, ctx=cpu(0))
|
||||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||||
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||||
if len(frame_idx) > MAX_NUM_FRAMES:
|
if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
|
||||||
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
|
frame_idx = uniform_sample(frame_idx, frame_count_limit)
|
||||||
frames = vr.get_batch(frame_idx).asnumpy()
|
frames = vr.get_batch(frame_idx).asnumpy()
|
||||||
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
|
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
if isinstance(input_text, list):
|
if isinstance(input_ids, list):
|
||||||
assert len(input_text) and isinstance(input_text[0], int)
|
assert len(input_ids) and isinstance(input_ids[0], int)
|
||||||
input_text = self._processor.tokenizer.decode(input_text)
|
input_text = self._processor.tokenizer.decode(input_ids)
|
||||||
|
else:
|
||||||
|
input_text = input_ids
|
||||||
# MiniCPMV requires each frame of video as a single image token
|
# MiniCPMV requires each frame of video as a single image token
|
||||||
text_parts = input_text.split(IMAGE_TOKEN)
|
text_parts = input_text.split(self.IMAGE_TOKEN)
|
||||||
new_text_parts = []
|
new_text_parts = []
|
||||||
|
|
||||||
for image_index, image in enumerate(image_data):
|
# Process each input with allocated frames
|
||||||
|
for image_index, (image, estimated_frames) in enumerate(
|
||||||
|
zip(image_data, estimated_frames_list)
|
||||||
|
):
|
||||||
|
if len(all_frames) >= MAX_NUM_FRAMES:
|
||||||
|
frames_to_process = 0
|
||||||
|
else:
|
||||||
|
frames_to_process = max(1, int(estimated_frames * scaling_factor))
|
||||||
|
|
||||||
|
if frames_to_process == 0:
|
||||||
|
frames = []
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
if isinstance(image, str) and image.startswith("video:"):
|
if isinstance(image, str) and image.startswith("video:"):
|
||||||
path = image[len("video:") :]
|
path = image[len("video:") :]
|
||||||
frames = encode_video(path)
|
frames = encode_video(path, frame_count_limit=frames_to_process)
|
||||||
else:
|
else:
|
||||||
raw_image, size = load_image(image)
|
raw_image, _size = load_image(image)
|
||||||
frames = [raw_image]
|
frames = [raw_image]
|
||||||
if len(frames) == 0:
|
if len(frames) == 0:
|
||||||
continue
|
continue
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
print(e)
|
print(e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
image_sizes += frames[0].size * len(frames)
|
image_sizes += frames[0].size * len(frames)
|
||||||
image_hashes += [hash(image)] * len(frames)
|
image_hashes += [hash(image)] * len(frames)
|
||||||
raw_images += frames
|
all_frames += frames
|
||||||
|
|
||||||
|
assert frames_to_process == len(frames)
|
||||||
|
|
||||||
new_text_parts.append(text_parts[image_index])
|
new_text_parts.append(text_parts[image_index])
|
||||||
new_text_parts.append(IMAGE_TOKEN * len(frames))
|
|
||||||
|
if frames_to_process != 0:
|
||||||
|
new_text_parts.append(self.IMAGE_TOKEN * len(frames))
|
||||||
|
|
||||||
new_text_parts.append(text_parts[-1])
|
new_text_parts.append(text_parts[-1])
|
||||||
|
|
||||||
input_text = "".join(new_text_parts)
|
input_text = "".join(new_text_parts)
|
||||||
if len(raw_images) == 0:
|
|
||||||
|
if len(all_frames) == 0:
|
||||||
return None
|
return None
|
||||||
res = await self._process_images(images=raw_images, input_text=input_text)
|
res = await self._process_images(images=all_frames, input_text=input_text)
|
||||||
pixel_values = res["pixel_values"]
|
pixel_values = res["pixel_values"]
|
||||||
tgt_sizes = res["tgt_sizes"]
|
tgt_sizes = res["tgt_sizes"]
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
|||||||
if tokenizer.slice_start_id:
|
if tokenizer.slice_start_id:
|
||||||
slice_start_id = [tokenizer.slice_start_id]
|
slice_start_id = [tokenizer.slice_start_id]
|
||||||
slice_end_id = [tokenizer.slice_end_id]
|
slice_end_id = [tokenizer.slice_end_id]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids.flatten().tolist(),
|
"input_ids": input_ids.flatten().tolist(),
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||||
# Copyright 2023 The vLLM team.
|
# Copyright 2023 The SGLang team.
|
||||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
@@ -20,7 +20,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
|
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
|
||||||
from functools import cached_property, partial
|
from functools import partial
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
@@ -33,16 +33,13 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.types
|
import torch.types
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.init import trunc_normal_
|
from torch.nn.init import trunc_normal_
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
|
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
||||||
|
|
||||||
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.activation import get_act_fn
|
from sglang.srt.layers.activation import get_act_fn
|
||||||
@@ -63,6 +60,88 @@ from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
|
|||||||
RawImageType = Union[Image.Image, torch.Tensor]
|
RawImageType = Union[Image.Image, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
# sin/cos positional embedding helpers are adapted from:
|
||||||
|
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||||
|
def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0)
) -> np.ndarray:
    """
    Build sin/cos positional embeddings for a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: positions to be encoded: size (M,) for version (2, 0) (any shape
         is flattened), or (H, W) for other versions
    version: (2, 0) yields a flat (M, D) table; any other value keeps the
             2D layout and yields (H, W, D)
    out: (M, D) / (H, W, D) numpy array, sin half followed by cos half
    """
    assert embed_dim % 2 == 0
    # Frequencies 1 / 10000^(2i/D) for i in [0, D/2)
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    if version == (2, 0):
        pos = pos.reshape(-1)  # (M,)
        out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
        emb_sin = np.sin(out)  # (M, D/2)
        emb_cos = np.cos(out)  # (M, D/2)
        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    else:
        out = np.einsum("hw,d->hwd", pos, omega)  # (H, W, D/2), outer product
        emb_sin = np.sin(out)  # (H, W, D/2)
        emb_cos = np.cos(out)  # (H, W, D/2)
        emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (H, W, D)
    return emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0)
) -> np.ndarray:
    """
    Build 2D sin/cos positional embeddings from a coordinate grid.

    embed_dim: output dimension for each position (must be even); each of
               the two grid channels is encoded with embed_dim // 2 dims
    grid: array whose first axis holds the two coordinate channels
          (grid[0] and grid[1])
    version: forwarded to get_1d_sincos_pos_embed_from_grid; (2, 0) yields
             a flat (H*W, D) table, other values yield (H, W, D)
    out: (H*W, D) / (H, W, D) numpy array
    """
    assert embed_dim % 2 == 0

    # use half of dimensions to encode each grid channel
    emb_h = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[0], version
    )  # (H*W, D/2) or (H, W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[1], version
    )  # (H*W, D/2) or (H, W, D/2)

    # Concatenate along the feature axis, which is the last axis in both
    # layouts (axis 1 of the flat table, axis -1 of the 2D layout).
    concat_axis = 1 if version == (2, 0) else -1
    emb = np.concatenate([emb_h, emb_w], axis=concat_axis)
    return emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed(
    embed_dim: int,
    grid_size: Union[int, Tuple[int, int]],
    cls_token: bool = False,
    version: Tuple[int, int] = (2, 0),
) -> np.ndarray:
    """
    Build a 2D sin/cos positional embedding table for a patch grid.

    embed_dim: output dimension for each position
    grid_size: int (square grid) or (height, width) tuple
    cls_token: for version (2, 0) only, prepend an all-zero row for a
               class token; ignored for other versions
    version: (2, 0) yields a flat table, other values keep the 2D layout
    return:
    pos_embed: for version (2, 0): [grid_h*grid_w, embed_dim] or
               [1+grid_h*grid_w, embed_dim] (w/o or w/ cls_token);
               otherwise [grid_h, grid_w, embed_dim]
    """
    if isinstance(grid_size, int):
        grid_h_size, grid_w_size = grid_size, grid_size
    else:
        grid_h_size, grid_w_size = grid_size[0], grid_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size)

    is_flat_version = version == (2, 0)
    if is_flat_version:
        grid = grid.reshape([2, 1, grid_h_size, grid_w_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)

    # A class-token row only makes sense for the flat (N, D) layout.
    if is_flat_version and cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
||||||
|
|
||||||
|
|
||||||
class Idefics2VisionMLP(nn.Module):
|
class Idefics2VisionMLP(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -116,6 +195,10 @@ class Idefics2EncoderLayer(nn.Module):
|
|||||||
projection_size=config.intermediate_size,
|
projection_size=config.intermediate_size,
|
||||||
use_qkv_parallel=True,
|
use_qkv_parallel=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
dropout=config.attention_dropout,
|
||||||
|
use_context_forward=False,
|
||||||
|
use_full_precision_softmax=True,
|
||||||
|
flatten_batch=False,
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||||
@@ -126,7 +209,6 @@ class Idefics2EncoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -136,11 +218,8 @@ class Idefics2EncoderLayer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.layer_norm1(hidden_states)
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(hidden_states, cu_seqlens=cu_seqlens)
|
||||||
hidden_states,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
# , forward_batch=forward_batch
|
|
||||||
)
|
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.layer_norm2(hidden_states)
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
@@ -181,7 +260,6 @@ class Idefics2Encoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs_embeds: torch.Tensor,
|
inputs_embeds: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -195,7 +273,8 @@ class Idefics2Encoder(nn.Module):
|
|||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
for encoder_layer in self.layers:
|
for encoder_layer in self.layers:
|
||||||
layer_outputs = encoder_layer(
|
layer_outputs = encoder_layer(
|
||||||
hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
|
hidden_states,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs
|
hidden_states = layer_outputs
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -232,19 +311,14 @@ class Idefics2VisionEmbeddings(nn.Module):
|
|||||||
self.num_positions = self.num_patches
|
self.num_positions = self.num_patches
|
||||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||||
|
|
||||||
def forward(
|
def get_position_ids(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
patch_attention_mask: torch.BoolTensor,
|
patch_attention_mask: torch.BoolTensor,
|
||||||
tgt_sizes: Optional[torch.IntTensor] = None,
|
tgt_sizes: Optional[torch.IntTensor] = None,
|
||||||
) -> torch.Tensor:
|
):
|
||||||
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
||||||
target_dtype = self.patch_embedding.weight.dtype
|
|
||||||
pixel_values = pixel_values.to(
|
|
||||||
device=self.patch_embedding.weight.device, dtype=target_dtype
|
|
||||||
)
|
|
||||||
patch_embeds = self.patch_embedding(pixel_values)
|
|
||||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
|
||||||
max_nb_patches_h, max_nb_patches_w = (
|
max_nb_patches_h, max_nb_patches_w = (
|
||||||
max_im_h // self.patch_size,
|
max_im_h // self.patch_size,
|
||||||
max_im_w // self.patch_size,
|
max_im_w // self.patch_size,
|
||||||
@@ -277,6 +351,24 @@ class Idefics2VisionEmbeddings(nn.Module):
|
|||||||
).flatten()
|
).flatten()
|
||||||
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
||||||
position_ids = position_ids.to(self.position_embedding.weight.device)
|
position_ids = position_ids.to(self.position_embedding.weight.device)
|
||||||
|
return position_ids
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
patch_attention_mask: torch.BoolTensor,
|
||||||
|
tgt_sizes: Optional[torch.IntTensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
target_dtype = self.patch_embedding.weight.dtype
|
||||||
|
pixel_values = pixel_values.to(
|
||||||
|
device=self.patch_embedding.weight.device, dtype=target_dtype
|
||||||
|
)
|
||||||
|
patch_embeds = self.patch_embedding(pixel_values)
|
||||||
|
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||||
|
position_ids = self.get_position_ids(
|
||||||
|
pixel_values, patch_attention_mask, tgt_sizes
|
||||||
|
)
|
||||||
|
|
||||||
embeddings = embeddings + self.position_embedding(position_ids)
|
embeddings = embeddings + self.position_embedding(position_ids)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@@ -287,7 +379,6 @@ class Idefics2VisionTransformer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -302,8 +393,6 @@ class Idefics2VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
|
def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
|
||||||
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,)
|
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,)
|
||||||
|
|
||||||
# 做 prefix sum 来得到 cu_seqlens,注意在最前面插一个 0 作为 offset
|
|
||||||
cu_seqlens = torch.cat(
|
cu_seqlens = torch.cat(
|
||||||
[
|
[
|
||||||
torch.tensor([0], device=patch_len.device, dtype=torch.int32),
|
torch.tensor([0], device=patch_len.device, dtype=torch.int32),
|
||||||
@@ -316,19 +405,18 @@ class Idefics2VisionTransformer(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values,
|
pixel_values,
|
||||||
forward_batch: ForwardBatch,
|
|
||||||
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
tgt_sizes: Optional[torch.IntTensor] = None,
|
tgt_sizes: Optional[torch.IntTensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embeddings(
|
hidden_states = self.embeddings(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
patch_attention_mask=patch_attention_mask,
|
patch_attention_mask=patch_attention_mask,
|
||||||
# forward_batch=forward_batch,
|
|
||||||
tgt_sizes=tgt_sizes,
|
tgt_sizes=tgt_sizes,
|
||||||
)
|
)
|
||||||
cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
|
cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
|
hidden_states,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
)
|
)
|
||||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||||
return last_hidden_state
|
return last_hidden_state
|
||||||
@@ -573,14 +661,12 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
# multimodal_config = config.model_config.multimodal_config
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# All MiniCPM-V models disable `tie_word_embeddings` but
|
# All MiniCPM-V models disable `tie_word_embeddings` but
|
||||||
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
|
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
|
||||||
# check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
|
# check `tie_word_embeddings` until SGLang integrate MiniCPM-V model
|
||||||
# and config class
|
# and config class
|
||||||
self.config = config
|
self.config = config
|
||||||
# self.multimodal_config = multimodal_config
|
|
||||||
|
|
||||||
self.version = get_version_by_config(self.config)
|
self.version = get_version_by_config(self.config)
|
||||||
self.llm = self.init_llm(config=config, quant_config=quant_config)
|
self.llm = self.init_llm(config=config, quant_config=quant_config)
|
||||||
@@ -598,13 +684,6 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def sampler(self):
|
|
||||||
if hasattr(self.llm, "sampler"):
|
|
||||||
return self.llm.sampler
|
|
||||||
|
|
||||||
return get_sampler()
|
|
||||||
|
|
||||||
def _get_image_bounds(
|
def _get_image_bounds(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@@ -666,7 +745,6 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
image_inputs: Optional[MiniCPMVImageInputs],
|
image_inputs: Optional[MiniCPMVImageInputs],
|
||||||
forward_batch: ForwardBatch,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
|
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
@@ -680,10 +758,7 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
.to(vlm_embedding.device)
|
.to(vlm_embedding.device)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vision_hidden_states = self.get_vision_hidden_states(
|
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
|
||||||
forward_batch, image_inputs
|
|
||||||
)
|
|
||||||
|
|
||||||
# See NOTE in _parse_and_validate_inputs
|
# See NOTE in _parse_and_validate_inputs
|
||||||
image_bounds = image_inputs["image_bounds"]
|
image_bounds = image_inputs["image_bounds"]
|
||||||
if len(image_bounds) > 0:
|
if len(image_bounds) > 0:
|
||||||
@@ -693,6 +768,7 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
for start, end in image_bounds.tolist()
|
for start, end in image_bounds.tolist()
|
||||||
]
|
]
|
||||||
).to(vlm_embedding.device)
|
).to(vlm_embedding.device)
|
||||||
|
|
||||||
vlm_embedding.scatter_(
|
vlm_embedding.scatter_(
|
||||||
0,
|
0,
|
||||||
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
|
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
|
||||||
@@ -839,7 +915,7 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||||
|
|
||||||
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch)
|
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
|
||||||
|
|
||||||
# always pass the input via `inputs_embeds`
|
# always pass the input via `inputs_embeds`
|
||||||
# to make sure the computation graph is consistent
|
# to make sure the computation graph is consistent
|
||||||
@@ -857,29 +933,6 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
input_ids, hidden_states, self.llm.lm_head, forward_batch
|
input_ids, hidden_states, self.llm.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_logits(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> Optional[torch.Tensor]:
|
|
||||||
return self.llm.compute_logits(hidden_states, sampling_metadata)
|
|
||||||
|
|
||||||
def sample(
|
|
||||||
self,
|
|
||||||
logits: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> Optional[SamplerOutput]:
|
|
||||||
next_tokens = self.sampler(logits, sampling_metadata)
|
|
||||||
return next_tokens
|
|
||||||
|
|
||||||
def get_mm_mapping(self) -> MultiModelKeys:
|
|
||||||
"""
|
|
||||||
Get the module prefix in multimodal models
|
|
||||||
"""
|
|
||||||
return MultiModelKeys.from_string_field(
|
|
||||||
language_model="llm", connector="resampler", tower_model="vpm"
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_llm(
|
def init_llm(
|
||||||
self,
|
self,
|
||||||
config: Qwen2Config,
|
config: Qwen2Config,
|
||||||
@@ -910,9 +963,7 @@ class MiniCPMVBaseModel(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_vision_hidden_states(
|
def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
|
||||||
self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs
|
|
||||||
) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@@ -1019,7 +1070,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
|||||||
|
|
||||||
def get_vision_hidden_states(
|
def get_vision_hidden_states(
|
||||||
self,
|
self,
|
||||||
forward_batch: ForwardBatch,
|
|
||||||
data: MiniCPMVImageInputs,
|
data: MiniCPMVImageInputs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
pixel_values = data["data"]
|
pixel_values = data["data"]
|
||||||
@@ -1042,15 +1092,18 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
|||||||
patch_attn_mask = torch.zeros(
|
patch_attn_mask = torch.zeros(
|
||||||
(B, 1, max_patches), dtype=torch.bool, device=device
|
(B, 1, max_patches), dtype=torch.bool, device=device
|
||||||
)
|
)
|
||||||
for i in range(B):
|
|
||||||
patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
|
tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
|
||||||
|
mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
|
||||||
|
patch_attn_mask[:, 0, :] = torch.arange(
|
||||||
|
patch_attn_mask.size(2), device=patch_attn_mask.device
|
||||||
|
).unsqueeze(0) < mask_shapes.unsqueeze(1)
|
||||||
|
|
||||||
vision_embedding = self.vpm(
|
vision_embedding = self.vpm(
|
||||||
all_pixel_values.type(dtype),
|
all_pixel_values.type(dtype),
|
||||||
forward_batch=forward_batch,
|
|
||||||
patch_attention_mask=patch_attn_mask,
|
patch_attention_mask=patch_attn_mask,
|
||||||
tgt_sizes=tgt_sizes,
|
tgt_sizes=tgt_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.resampler(vision_embedding, tgt_sizes)
|
return self.resampler(vision_embedding, tgt_sizes)
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||||
@@ -1138,7 +1191,7 @@ class MiniCPMV:
|
|||||||
"""
|
"""
|
||||||
Different versions of MiniCPMV use different visual encoders and LLMs,
|
Different versions of MiniCPMV use different visual encoders and LLMs,
|
||||||
which is not conducive to the current integration logic of LoRA and
|
which is not conducive to the current integration logic of LoRA and
|
||||||
bitsandbytes in vLLM. Therefore, it is necessary to separate them.
|
bitsandbytes in SGLang. Therefore, it is necessary to separate them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Ensure that the LoRA support check passes when the class is not
|
# Ensure that the LoRA support check passes when the class is not
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from transformers.models.mllama.modeling_mllama import (
|
|||||||
import sglang.srt.distributed.parallel_state as ps
|
import sglang.srt.distributed.parallel_state as ps
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.activation import get_act_fn
|
from sglang.srt.layers.activation import get_act_fn
|
||||||
|
from sglang.srt.layers.attention.vision import VisionAttention
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
ColumnParallelLinear,
|
ColumnParallelLinear,
|
||||||
@@ -145,61 +146,6 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):
|
|||||||
return hidden_state
|
return hidden_state
|
||||||
|
|
||||||
|
|
||||||
class MllamaVisionSdpaAttention(nn.Module):
|
|
||||||
def __init__(self, config: config_mllama.MllamaVisionConfig):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
model_parallel_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.embed_dim = config.hidden_size
|
|
||||||
self.num_heads = config.attention_heads
|
|
||||||
self.head_dim = config.hidden_size // config.attention_heads
|
|
||||||
self.num_local_heads = self.num_heads // model_parallel_size
|
|
||||||
self.q_size = self.num_local_heads * self.head_dim
|
|
||||||
self.kv_size = self.num_local_heads * self.head_dim
|
|
||||||
|
|
||||||
self.qkv_proj = QKVParallelLinear(
|
|
||||||
self.embed_dim,
|
|
||||||
self.head_dim,
|
|
||||||
self.num_heads,
|
|
||||||
bias=False,
|
|
||||||
)
|
|
||||||
self.o_proj = RowParallelLinear(
|
|
||||||
self.num_heads * self.head_dim,
|
|
||||||
self.embed_dim,
|
|
||||||
bias=False,
|
|
||||||
input_is_parallel=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_state: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
qkv, _ = self.qkv_proj(hidden_state)
|
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
|
||||||
q = q.view(
|
|
||||||
q.shape[0], q.shape[1], self.num_local_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
k = k.view(
|
|
||||||
k.shape[0], k.shape[1], self.num_local_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
v = v.view(
|
|
||||||
v.shape[0], v.shape[1], self.num_local_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
|
|
||||||
# TODO: remove padding in image encoder
|
|
||||||
attn_output = F.scaled_dot_product_attention(
|
|
||||||
q, k, v, attn_mask=attention_mask, dropout_p=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
||||||
attn_output = attn_output.reshape(
|
|
||||||
attn_output.shape[0], attn_output.shape[1], -1
|
|
||||||
)
|
|
||||||
output, _ = self.o_proj(attn_output)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class MllamaVisionMLP(nn.Module):
|
class MllamaVisionMLP(nn.Module):
|
||||||
def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
|
def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -237,7 +183,17 @@ class MllamaVisionEncoderLayer(nn.Module):
|
|||||||
self.is_gated = is_gated
|
self.is_gated = is_gated
|
||||||
self.intermediate_size = config.intermediate_size
|
self.intermediate_size = config.intermediate_size
|
||||||
|
|
||||||
self.self_attn = MllamaVisionSdpaAttention(config)
|
self.self_attn = VisionAttention(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_attention_heads,
|
||||||
|
self.hidden_size,
|
||||||
|
use_qkv_parallel=True,
|
||||||
|
quant_config=None,
|
||||||
|
dropout=0.0,
|
||||||
|
use_context_forward=False,
|
||||||
|
use_full_precision_softmax=False,
|
||||||
|
flatten_batch=False,
|
||||||
|
)
|
||||||
self.mlp = MllamaVisionMLP(config)
|
self.mlp = MllamaVisionMLP(config)
|
||||||
|
|
||||||
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
|
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
|
||||||
@@ -992,6 +948,10 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
if "vision_model" in name:
|
||||||
|
# adapt to VisionAttention
|
||||||
|
name = name.replace("self_attn.o_proj", "self_attn.proj")
|
||||||
|
|
||||||
param = params_dict.pop(name)
|
param = params_dict.pop(name)
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
@@ -249,6 +249,9 @@ class Qwen2Model(nn.Module):
|
|||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
if hasattr(self.config, "scale_emb"):
|
||||||
|
return self.embed_tokens(input_ids) * self.config.scale_emb
|
||||||
|
else:
|
||||||
return self.embed_tokens(input_ids)
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -30,12 +30,10 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange
|
||||||
from vllm.model_executor.layers.activation import QuickGELU
|
from vllm.model_executor.layers.activation import QuickGELU
|
||||||
|
|
||||||
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
||||||
from sglang.srt.distributed import parallel_state
|
|
||||||
from sglang.srt.distributed import utils as dist_utils
|
|
||||||
from sglang.srt.hf_transformers_utils import get_processor
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
from sglang.srt.layers.attention.vision import VisionAttention
|
from sglang.srt.layers.attention.vision import VisionAttention
|
||||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||||
@@ -118,6 +116,7 @@ class Qwen2VisionBlock(nn.Module):
|
|||||||
mlp_ratio: float,
|
mlp_ratio: float,
|
||||||
act_layer: Type[nn.Module] = QuickGELU,
|
act_layer: Type[nn.Module] = QuickGELU,
|
||||||
norm_layer: Type[nn.Module] = None,
|
norm_layer: Type[nn.Module] = None,
|
||||||
|
attn_implementation: Optional[str] = "sdpa",
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -126,12 +125,24 @@ class Qwen2VisionBlock(nn.Module):
|
|||||||
self.norm1 = norm_layer(dim)
|
self.norm1 = norm_layer(dim)
|
||||||
self.norm2 = norm_layer(dim)
|
self.norm2 = norm_layer(dim)
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
if attn_implementation == "sdpa":
|
||||||
|
use_context_forward = False
|
||||||
|
use_full_precision_softmax = False
|
||||||
|
elif attn_implementation == "flash_attention_2":
|
||||||
|
use_full_precision_softmax = False
|
||||||
|
use_context_forward = True
|
||||||
|
elif attn_implementation == "eager":
|
||||||
|
use_full_precision_softmax = True
|
||||||
|
use_context_forward = False
|
||||||
|
|
||||||
self.attn = VisionAttention(
|
self.attn = VisionAttention(
|
||||||
embed_dim=dim,
|
embed_dim=dim,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
projection_size=dim,
|
projection_size=dim,
|
||||||
use_qkv_parallel=False,
|
use_qkv_parallel=False,
|
||||||
|
use_context_forward=use_context_forward,
|
||||||
|
use_full_precision_softmax=use_full_precision_softmax,
|
||||||
|
flatten_batch=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.mlp = Qwen2VisionMLP(
|
self.mlp = Qwen2VisionMLP(
|
||||||
@@ -286,7 +297,6 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
||||||
head_dim = embed_dim // num_heads
|
head_dim = embed_dim // num_heads
|
||||||
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
|
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
|
||||||
|
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Qwen2VisionBlock(
|
Qwen2VisionBlock(
|
||||||
@@ -294,6 +304,7 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
|
attn_implementation="sdpa",
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
for _ in range(depth)
|
for _ in range(depth)
|
||||||
@@ -482,10 +493,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
opensource models), the shape will be `(3, seq_len)`,
|
opensource models), the shape will be `(3, seq_len)`,
|
||||||
otherwise it will be `(seq_len,).
|
otherwise it will be `(seq_len,).
|
||||||
(Use input_metadata.mrope_positions to replace it)
|
(Use input_metadata.mrope_positions to replace it)
|
||||||
pixel_values: Pixel values to be fed to a model.
|
|
||||||
`None` if no images are passed.
|
|
||||||
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
|
||||||
`None` if no images are passed.
|
|
||||||
"""
|
"""
|
||||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||||
positions = forward_batch.mrope_positions
|
positions = forward_batch.mrope_positions
|
||||||
@@ -540,15 +547,18 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
num_image_tokens = self.calculate_num_image_tokens(
|
num_image_tokens = self.calculate_num_image_tokens(
|
||||||
image_grid_thws[idx]
|
image_grid_thws[idx]
|
||||||
)
|
)
|
||||||
|
|
||||||
left_idx = start_idx + (image_offset - prefix_len)
|
left_idx = start_idx + (image_offset - prefix_len)
|
||||||
right_idx = (
|
right_idx = (
|
||||||
start_idx + (image_offset - prefix_len) + num_image_tokens
|
start_idx + (image_offset - prefix_len) + num_image_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
inputs_embeds[left_idx:right_idx] = image_embeds[
|
inputs_embeds[left_idx:right_idx] = image_embeds[
|
||||||
image_embeds_offset : image_embeds_offset + num_image_tokens
|
image_embeds_offset : image_embeds_offset + num_image_tokens
|
||||||
]
|
]
|
||||||
image_embeds_offset += num_image_tokens
|
image_embeds_offset += num_image_tokens
|
||||||
|
|
||||||
|
input_ids = None
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
|
|||||||
@@ -444,8 +444,6 @@ def load_image(image_file: Union[str, bytes]):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid image: {image}")
|
raise ValueError(f"Invalid image: {image}")
|
||||||
|
|
||||||
# if image_size is None:
|
|
||||||
# image_size = image.size
|
|
||||||
return image, image_size
|
return image, image_size
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ suites = {
|
|||||||
"test_update_weights_from_disk.py",
|
"test_update_weights_from_disk.py",
|
||||||
"test_update_weights_from_tensor.py",
|
"test_update_weights_from_tensor.py",
|
||||||
"test_vision_chunked_prefill.py",
|
"test_vision_chunked_prefill.py",
|
||||||
|
"test_vision_llm.py",
|
||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
"test_w8a8_quantization.py",
|
"test_w8a8_quantization.py",
|
||||||
"test_fp8_kvcache.py",
|
"test_fp8_kvcache.py",
|
||||||
@@ -72,7 +73,6 @@ for target_suite_name, target_tests in suites.items():
|
|||||||
tests.remove(target_suite_name)
|
tests.remove(target_suite_name)
|
||||||
tests.extend(target_tests)
|
tests.extend(target_tests)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
arg_parser = argparse.ArgumentParser()
|
arg_parser = argparse.ArgumentParser()
|
||||||
arg_parser.add_argument(
|
arg_parser.add_argument(
|
||||||
|
|||||||
210
test/srt/test_vision_llm.py
Normal file
210
test/srt/test_vision_llm.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||||||
|
|
||||||
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.conversation import generate_chat_conv
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
MiniCPMV = "openbmb/MiniCPM-V-2_6"
|
||||||
|
|
||||||
|
|
||||||
|
# Test the logits output between HF and SGLang
|
||||||
|
class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
|
||||||
|
@classmethod
def setUpClass(cls):
    """Prepare class-level fixtures shared by the logits comparison tests.

    Sets the example image URL, picks CUDA when available (CPU otherwise),
    initialises ``model_path``, ``chat_template`` and ``processor`` to empty
    strings, and downloads the example image into ``cls.main_image``.

    Raises:
        requests.RequestException: If the download times out or the server
            answers with an HTTP error status.
    """
    cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
    cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Empty placeholders; presumably filled in by concrete subclasses
    # (not visible from here, confirm against the subclasses).
    cls.model_path = ""
    cls.chat_template = ""
    cls.processor = ""
    # Bound the download so a stalled connection cannot hang the suite, and
    # fail on HTTP errors here rather than as an image-decoding error below.
    response = requests.get(cls.image_url, timeout=30)
    response.raise_for_status()
    cls.main_image = Image.open(BytesIO(response.content))
|
||||||
|
|
||||||
|
def compare_outputs(self, sglang_output: torch.Tensor, hf_output: torch.Tensor):
|
||||||
|
# Convert to float32 for numerical stability if needed
|
||||||
|
hf = hf_output.float()
|
||||||
|
sg = sglang_output.float()
|
||||||
|
|
||||||
|
# Basic shape and dtype comparison
|
||||||
|
print("\n=== Basic Properties ===")
|
||||||
|
print(f"Shapes match: {hf.shape == sg.shape}")
|
||||||
|
print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
|
||||||
|
print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")
|
||||||
|
|
||||||
|
# Move tensors to CPU for numpy operations
|
||||||
|
hf_np = hf.cpu().numpy()
|
||||||
|
sg_np = sg.cpu().numpy()
|
||||||
|
|
||||||
|
# Statistical metrics
|
||||||
|
print("\n=== Statistical Metrics ===")
|
||||||
|
print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}")
|
||||||
|
print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}")
|
||||||
|
print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
|
||||||
|
print(
|
||||||
|
f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cosine similarity (across feature dimension)
|
||||||
|
cos_sim = F.cosine_similarity(hf, sg)
|
||||||
|
print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
|
||||||
|
print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")
|
||||||
|
|
||||||
|
# Find largest absolute differences
|
||||||
|
print("\n=== Largest Absolute Differences ===")
|
||||||
|
diffs = torch.abs(hf - sg)
|
||||||
|
flat_diffs = diffs.flatten()
|
||||||
|
|
||||||
|
# Get indices of top 10 differences
|
||||||
|
top_k = 10
|
||||||
|
top_values, top_flat_indices = torch.topk(flat_diffs, top_k)
|
||||||
|
|
||||||
|
# Convert flat indices to multidimensional indices
|
||||||
|
top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)
|
||||||
|
|
||||||
|
print(f"\nTop {top_k} largest absolute differences:")
|
||||||
|
print(
|
||||||
|
"Index".ljust(30)
|
||||||
|
+ "Difference".ljust(15)
|
||||||
|
+ "HF Value".ljust(15)
|
||||||
|
+ "SGLang Value"
|
||||||
|
)
|
||||||
|
print("-" * 75)
|
||||||
|
|
||||||
|
for i in range(top_k):
|
||||||
|
# Get the index tuple for this difference
|
||||||
|
idx = tuple(dim[i] for dim in top_indices)
|
||||||
|
diff_val = top_values[i].item()
|
||||||
|
hf_val = hf[idx].item()
|
||||||
|
sg_val = sg[idx].item()
|
||||||
|
|
||||||
|
# Format the index tuple and values
|
||||||
|
idx_str = str(idx)
|
||||||
|
print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")
|
||||||
|
|
||||||
|
np.testing.assert_allclose(hf_np, sg_np)
|
||||||
|
|
||||||
|
def get_processor_output(self):
|
||||||
|
json_str = f"""
|
||||||
|
{{
|
||||||
|
"model": "{self.model_path}",
|
||||||
|
"messages": [
|
||||||
|
{{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {{
|
||||||
|
"url": "{self.image_url}"
|
||||||
|
}}
|
||||||
|
}},
|
||||||
|
{{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Whats in this picture?"
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
req = ChatCompletionRequest.model_validate_json(json_str)
|
||||||
|
|
||||||
|
conv = generate_chat_conv(req, template_name=self.chat_template)
|
||||||
|
|
||||||
|
text = conv.get_prompt()
|
||||||
|
|
||||||
|
# Process inputs using processor
|
||||||
|
# FIXME: the formal arguments may differ
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[text],
|
||||||
|
images=[self.main_image],
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def get_sglang_model(self):
|
||||||
|
model_runner = ModelRunner(
|
||||||
|
model_config=ModelConfig(self.model_path, model_override_args="{}"),
|
||||||
|
mem_fraction_static=0.8,
|
||||||
|
gpu_id=0,
|
||||||
|
tp_rank=0,
|
||||||
|
tp_size=1,
|
||||||
|
nccl_port=12435,
|
||||||
|
server_args=ServerArgs(
|
||||||
|
model_path=self.model_path,
|
||||||
|
disable_cuda_graph=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return model_runner.model
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiniCPMVLogits(VisionLLMLogitsBase):
    """Compares the MiniCPM-V embedding output of the HF model with SGLang."""

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer, processor and HF reference model once."""
        super().setUpClass()
        cls.model_path = MiniCPMV
        cls.tokenizer = AutoTokenizer.from_pretrained(
            cls.model_path,
            trust_remote_code=True,
        )
        cls.processor = AutoProcessor.from_pretrained(
            cls.model_path,
            trust_remote_code=True,
        )
        cls.chat_template = "minicpmv"

        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        reference_model = AutoModel.from_pretrained(
            cls.model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        cls.model = reference_model.eval()
        cls.model.to(cls.device)

    def _hf_embedding(self, inputs):
        """Return the HF model's embedding for the processed inputs."""
        with torch.no_grad():
            hf_inputs = dict(
                input_ids=inputs.input_ids,
                image_bound=inputs.image_bound,
                pixel_values=inputs.pixel_values,
                tgt_sizes=inputs.tgt_sizes,
            )
            embedding, _ = self.model.get_vllm_embedding(hf_inputs)
            return embedding.squeeze(0)

    def _sglang_embedding(self, inputs):
        """Return the SGLang model's embedding for the processed inputs."""
        with torch.no_grad():
            sglang_model = self.get_sglang_model()
            flat_ids = inputs["input_ids"].to(self.device).flatten()
            tok = self.tokenizer
            parsed_images = sglang_model._parse_and_validate_inputs(
                input_ids=flat_ids,
                pixel_values=[inputs["pixel_values"]],
                tgt_sizes=[inputs["tgt_sizes"]],
                im_start_id=[tok.im_start_id],
                im_end_id=[tok.im_end_id],
                slice_start_id=[tok.slice_start_id],
                slice_end_id=[tok.slice_end_id],
            )
            embedding, _ = sglang_model.get_embedding(
                input_ids=flat_ids,
                image_inputs=parsed_images,
            )
            return embedding

    async def test_encode_output(self):
        """Both implementations must produce matching embeddings."""
        inputs = self.get_processor_output()
        # The HF reference is evaluated before the SGLang model is created.
        hf_output = self._hf_embedding(inputs)
        sglang_output = self._sglang_embedding(inputs)
        self.compare_outputs(sglang_output, hf_output)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the unittest runner when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
||||||
@@ -180,7 +180,9 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
|||||||
assert response.usage.total_tokens > 0
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
def prepare_video_messages(self, video_path):
|
def prepare_video_messages(self, video_path):
|
||||||
max_frames_num = 32
|
# the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa
|
||||||
|
# the size of the video embeds differs from the `modality` argument when preprocessed
|
||||||
|
max_frames_num = 12
|
||||||
vr = VideoReader(video_path, ctx=cpu(0))
|
vr = VideoReader(video_path, ctx=cpu(0))
|
||||||
total_frame_num = len(vr)
|
total_frame_num = len(vr)
|
||||||
uniform_sampled_frames = np.linspace(
|
uniform_sampled_frames = np.linspace(
|
||||||
|
|||||||
Reference in New Issue
Block a user