refactor: move image processors to separate files (#4229)

This commit is contained in:
Mick
2025-03-12 03:35:35 +08:00
committed by GitHub
parent 0f2a2e3c19
commit ff2ce0b86f
22 changed files with 1085 additions and 955 deletions

View File

@@ -41,7 +41,6 @@ from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import (
@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype
@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module):
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
num_heads_per_partition = divide(self.num_heads, tp_size)
self.self_attn = VisionAttention(
embed_dim=config.hidden_size,
num_heads=num_heads_per_partition,
num_heads=self.num_heads,
projection_size=config.intermediate_size,
use_qkv_parallel=True,
quant_config=quant_config,
dropout=config.attention_dropout,
use_context_forward=False,
use_full_precision_softmax=True,
softmax_in_single_precision=True,
flatten_batch=False,
prefix=add_prefix("self_attn", prefix),
)
@@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module):
self,
input_ids: torch.Tensor,
pad_values: List[int],
im_start_id: torch.Tensor,
im_end_id: torch.Tensor,
slice_start_id: Optional[torch.Tensor] = None,
slice_end_id: Optional[torch.Tensor] = None,
im_start_id: int,
im_end_id: int,
slice_start_id: Optional[int] = None,
slice_end_id: Optional[int] = None,
) -> torch.Tensor:
"""
Returns a tensor indicating the bounds (start and end token ids) of the images
"""
# All the images in the batch should share the same special image
# bound token ids.
start_cond = input_ids == im_start_id[0]
end_cond = input_ids == im_end_id[0]
start_cond = input_ids == im_start_id
end_cond = input_ids == im_end_id
if slice_start_id is not None:
start_cond |= input_ids == slice_start_id[0]
end_cond |= input_ids == slice_end_id[0]
start_cond |= input_ids == slice_start_id
end_cond |= input_ids == slice_end_id
(image_start_tokens,) = torch.where(start_cond)
image_start_tokens += 1
@@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module):
if (
len(image_start_tokens) + 1 == len(image_end_tokens)
and input_ids[0] in pad_values
and len(image_start_tokens) != 0
and len(image_end_tokens) != 0
and image_end_tokens[0] < image_start_tokens[0]
):
image_start_tokens = torch.cat(
@@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module):
forward_batch: ForwardBatch,
**kwargs: Any,
) -> torch.Tensor:
if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
None
]:
if (
forward_batch.image_inputs is not None
and len(forward_batch.image_inputs) > 0
and forward_batch.image_inputs[0] is not None
):
# TODO: bath
kwargs.update(
{
"pixel_values": (
@@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
return self.resampler(vision_embedding, tgt_sizes)
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
if not isinstance(image_inputs.im_start_id, list) or not isinstance(
image_inputs.im_end_id, list
):
return input_ids
new_input_ids = []
last_idx = 0
image_idx = -1
image_inputs.image_offsets = []
# Get all special token IDs
im_start_id = (
image_inputs.im_start_id[0].item()
if isinstance(image_inputs.im_start_id[0], torch.Tensor)
else image_inputs.im_start_id[0]
)
im_end_id = (
image_inputs.im_end_id[0].item()
if isinstance(image_inputs.im_end_id[0], torch.Tensor)
else image_inputs.im_end_id[0]
)
slice_start_id = (
image_inputs.slice_start_id[0].item()
if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
else image_inputs.slice_start_id[0]
)
slice_end_id = (
image_inputs.slice_end_id[0].item()
if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
else image_inputs.slice_end_id[0]
)
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
slice_start_id: int = image_inputs.slice_start_id
slice_end_id: int = image_inputs.slice_end_id
# Find all start and end positions for both types
start_indices = [
i
for i, x in enumerate(input_ids)
if x == im_start_id or x == slice_start_id
]
end_indices = [
i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
]
media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
if len(start_indices) != len(end_indices):
return input_ids
# Process each region (both image and slice)
for start_idx, end_idx in zip(start_indices, end_indices):
# Add non-image tokens before this region
new_input_ids.extend(
input_ids[last_idx : start_idx + 1]
) # include start token
is_image_start = input_ids[start_idx] == im_start_id
if is_image_start:
image_inputs.image_offsets += [start_idx]
image_idx += 1
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
# Generate pad_ids
pad_values = [image_inputs.pad_values[image_idx]]
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
pad_ids = pad_ids[:num_tokens]
# Add pad_ids
new_input_ids.extend(pad_ids)
# Update last_idx to after end token
last_idx = end_idx
# Add remaining tokens after last region
new_input_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(new_input_ids)
return new_input_ids
return pattern.pad_input_tokens(input_ids, image_inputs)
_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}