refactor: move image processors to separate files (#4229)

This commit is contained in:
Mick
2025-03-12 03:35:35 +08:00
committed by GitHub
parent 0f2a2e3c19
commit ff2ce0b86f
22 changed files with 1085 additions and 955 deletions

View File

@@ -41,7 +41,6 @@ from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import (
@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype
@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module):
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
num_heads_per_partition = divide(self.num_heads, tp_size)
self.self_attn = VisionAttention(
embed_dim=config.hidden_size,
num_heads=num_heads_per_partition,
num_heads=self.num_heads,
projection_size=config.intermediate_size,
use_qkv_parallel=True,
quant_config=quant_config,
dropout=config.attention_dropout,
use_context_forward=False,
use_full_precision_softmax=True,
softmax_in_single_precision=True,
flatten_batch=False,
prefix=add_prefix("self_attn", prefix),
)
@@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module):
self,
input_ids: torch.Tensor,
pad_values: List[int],
im_start_id: torch.Tensor,
im_end_id: torch.Tensor,
slice_start_id: Optional[torch.Tensor] = None,
slice_end_id: Optional[torch.Tensor] = None,
im_start_id: int,
im_end_id: int,
slice_start_id: Optional[int] = None,
slice_end_id: Optional[int] = None,
) -> torch.Tensor:
"""
Returns a tensor indicating the bounds (start and end token ids) of the images
"""
# All the images in the batch should share the same special image
# bound token ids.
start_cond = input_ids == im_start_id[0]
end_cond = input_ids == im_end_id[0]
start_cond = input_ids == im_start_id
end_cond = input_ids == im_end_id
if slice_start_id is not None:
start_cond |= input_ids == slice_start_id[0]
end_cond |= input_ids == slice_end_id[0]
start_cond |= input_ids == slice_start_id
end_cond |= input_ids == slice_end_id
(image_start_tokens,) = torch.where(start_cond)
image_start_tokens += 1
@@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module):
if (
len(image_start_tokens) + 1 == len(image_end_tokens)
and input_ids[0] in pad_values
and len(image_start_tokens) != 0
and len(image_end_tokens) != 0
and image_end_tokens[0] < image_start_tokens[0]
):
image_start_tokens = torch.cat(
@@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module):
forward_batch: ForwardBatch,
**kwargs: Any,
) -> torch.Tensor:
if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
None
]:
if (
forward_batch.image_inputs is not None
and len(forward_batch.image_inputs) > 0
and forward_batch.image_inputs[0] is not None
):
# TODO: bath
kwargs.update(
{
"pixel_values": (
@@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
return self.resampler(vision_embedding, tgt_sizes)
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
if not isinstance(image_inputs.im_start_id, list) or not isinstance(
image_inputs.im_end_id, list
):
return input_ids
new_input_ids = []
last_idx = 0
image_idx = -1
image_inputs.image_offsets = []
# Get all special token IDs
im_start_id = (
image_inputs.im_start_id[0].item()
if isinstance(image_inputs.im_start_id[0], torch.Tensor)
else image_inputs.im_start_id[0]
)
im_end_id = (
image_inputs.im_end_id[0].item()
if isinstance(image_inputs.im_end_id[0], torch.Tensor)
else image_inputs.im_end_id[0]
)
slice_start_id = (
image_inputs.slice_start_id[0].item()
if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
else image_inputs.slice_start_id[0]
)
slice_end_id = (
image_inputs.slice_end_id[0].item()
if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
else image_inputs.slice_end_id[0]
)
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
slice_start_id: int = image_inputs.slice_start_id
slice_end_id: int = image_inputs.slice_end_id
# Find all start and end positions for both types
start_indices = [
i
for i, x in enumerate(input_ids)
if x == im_start_id or x == slice_start_id
]
end_indices = [
i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
]
media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
if len(start_indices) != len(end_indices):
return input_ids
# Process each region (both image and slice)
for start_idx, end_idx in zip(start_indices, end_indices):
# Add non-image tokens before this region
new_input_ids.extend(
input_ids[last_idx : start_idx + 1]
) # include start token
is_image_start = input_ids[start_idx] == im_start_id
if is_image_start:
image_inputs.image_offsets += [start_idx]
image_idx += 1
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
# Generate pad_ids
pad_values = [image_inputs.pad_values[image_idx]]
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
pad_ids = pad_ids[:num_tokens]
# Add pad_ids
new_input_ids.extend(pad_ids)
# Update last_idx to after end token
last_idx = end_idx
# Add remaining tokens after last region
new_input_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(new_input_ids)
return new_input_ids
return pattern.pad_input_tokens(input_ids, image_inputs)
_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}

View File

@@ -202,7 +202,7 @@ class MllamaVisionEncoderLayer(nn.Module):
quant_config=None,
dropout=0.0,
use_context_forward=False,
use_full_precision_softmax=False,
softmax_in_single_precision=False,
flatten_batch=False,
prefix=add_prefix("self_attn", prefix),
)

View File

@@ -47,6 +47,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -121,12 +124,12 @@ class Qwen2_5_VisionBlock(nn.Module):
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
if attn_implementation == "sdpa":
use_context_forward = False
use_full_precision_softmax = False
softmax_in_single_precision = False
elif attn_implementation == "flash_attention_2":
use_full_precision_softmax = False
softmax_in_single_precision = False
use_context_forward = True
elif attn_implementation == "eager":
use_full_precision_softmax = True
softmax_in_single_precision = True
use_context_forward = False
self.attn = VisionAttention(
@@ -135,7 +138,7 @@ class Qwen2_5_VisionBlock(nn.Module):
projection_size=dim,
use_qkv_parallel=False,
use_context_forward=use_context_forward,
use_full_precision_softmax=use_full_precision_softmax,
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
@@ -149,12 +152,17 @@ class Qwen2_5_VisionBlock(nn.Module):
)
def forward(
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
attn = self.attn(
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
attn = rearrange(attn, "b s ... -> s b ...")
x = x + attn
@@ -443,6 +451,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
@@ -457,7 +467,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)
x = blk(
x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
)
# adapter
x = self.merger(x)
@@ -522,50 +534,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return num_image_tokens
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
new_input_ids = []
last_idx = 0
image_idx = -1
image_inputs.image_offsets = []
# Get all special token IDs
im_start_id = image_inputs.im_start_id
im_end_id = image_inputs.im_end_id
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
# Find all start and end positions for both types
start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
if len(start_indices) != len(end_indices):
return input_ids
# Process each region (both image and slice)
for start_idx, end_idx in zip(start_indices, end_indices):
# Add non-image tokens before this region
new_input_ids.extend(input_ids[last_idx : start_idx + 1])
is_image_start = input_ids[start_idx] == im_start_id
if is_image_start:
image_inputs.image_offsets += [start_idx]
image_idx += 1
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
# Generate pad_ids
pad_values = [image_inputs.pad_values[image_idx]]
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
pad_ids = pad_ids[:num_tokens]
# Add pad_ids
new_input_ids.extend(pad_ids)
# Update last_idx to after end token
last_idx = end_idx
# Add remaining tokens after last region
new_input_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(new_input_ids)
return new_input_ids
return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
@@ -629,7 +605,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None:
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]

View File

@@ -42,6 +42,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -137,12 +140,12 @@ class Qwen2VisionBlock(nn.Module):
mlp_hidden_dim = int(dim * mlp_ratio)
if attn_implementation == "sdpa":
use_context_forward = False
use_full_precision_softmax = False
softmax_in_single_precision = False
elif attn_implementation == "flash_attention_2":
use_full_precision_softmax = False
softmax_in_single_precision = False
use_context_forward = True
elif attn_implementation == "eager":
use_full_precision_softmax = True
softmax_in_single_precision = True
use_context_forward = False
self.attn = VisionAttention(
@@ -151,7 +154,7 @@ class Qwen2VisionBlock(nn.Module):
projection_size=dim,
use_qkv_parallel=False,
use_context_forward=use_context_forward,
use_full_precision_softmax=use_full_precision_softmax,
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
@@ -165,12 +168,17 @@ class Qwen2VisionBlock(nn.Module):
)
def forward(
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
attn = self.attn(
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
attn = rearrange(attn, "b s ... -> s b ...")
x = x + attn
@@ -392,7 +400,8 @@ class Qwen2VisionTransformer(nn.Module):
# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -402,7 +411,7 @@ class Qwen2VisionTransformer(nn.Module):
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
# adapter
x = self.merger(x)
@@ -425,40 +434,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
)
return num_image_tokens
# Use grid_t * grid_w * grid_h to pad tokens for each image
# add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
image_grid_thws = image_inputs.image_grid_thws
pad_values = image_inputs.pad_values
image_indices = [
idx
for idx, token in enumerate(input_ids)
if token == self.config.image_token_id
]
image_inputs.image_offsets = []
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[image_cnt]
)
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
else:
non_image_tokens = input_ids[
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
]
input_ids_with_image.extend(non_image_tokens)
image_inputs.image_offsets.append(len(input_ids_with_image))
pad_ids = pad_values * (
(num_image_tokens + len(pad_values)) // len(pad_values)
)
input_ids_with_image.extend(pad_ids[:num_image_tokens])
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
return input_ids_with_image
def __init__(
self,
config: Qwen2VLConfig,
@@ -494,6 +469,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# Use grid_t * grid_w * grid_h to pad tokens for each image
# add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
# Get all special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
@@ -556,12 +542,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None:
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
pixel_values = image.pixel_values.clone()
pixel_values = torch.tensor(image.pixel_values, device="cuda")
image_grid_thws = torch.tensor(
np.array(image.image_grid_thws), device="cuda"
)
@@ -579,15 +565,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
image_grid_thws[idx]
)
left_idx = start_idx + (image_offset - prefix_len)
right_idx = (
start_idx + (image_offset - prefix_len) + num_image_tokens
)
left_idx = start_idx + (image_offset - prefix_len + 1)
right_idx = left_idx + num_image_tokens
inputs_embeds[left_idx:right_idx] = image_embeds[
image_embeds_offset : image_embeds_offset + num_image_tokens
]
image_embeds_offset += num_image_tokens
input_ids = None
hidden_states = self.model(
input_ids=input_ids,