refactor: move image processors to separate files (#4229)
This commit is contained in:
@@ -41,7 +41,6 @@ from torch import nn
|
||||
from torch.nn.init import trunc_normal_
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.activation import get_act_fn
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import (
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.multi_modality_padding import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||
@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
num_heads_per_partition = divide(self.num_heads, tp_size)
|
||||
self.self_attn = VisionAttention(
|
||||
embed_dim=config.hidden_size,
|
||||
num_heads=num_heads_per_partition,
|
||||
num_heads=self.num_heads,
|
||||
projection_size=config.intermediate_size,
|
||||
use_qkv_parallel=True,
|
||||
quant_config=quant_config,
|
||||
dropout=config.attention_dropout,
|
||||
use_context_forward=False,
|
||||
use_full_precision_softmax=True,
|
||||
softmax_in_single_precision=True,
|
||||
flatten_batch=False,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
@@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
pad_values: List[int],
|
||||
im_start_id: torch.Tensor,
|
||||
im_end_id: torch.Tensor,
|
||||
slice_start_id: Optional[torch.Tensor] = None,
|
||||
slice_end_id: Optional[torch.Tensor] = None,
|
||||
im_start_id: int,
|
||||
im_end_id: int,
|
||||
slice_start_id: Optional[int] = None,
|
||||
slice_end_id: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns a tensor indicating the bounds (start and end token ids) of the images
|
||||
"""
|
||||
# All the images in the batch should share the same special image
|
||||
# bound token ids.
|
||||
start_cond = input_ids == im_start_id[0]
|
||||
end_cond = input_ids == im_end_id[0]
|
||||
start_cond = input_ids == im_start_id
|
||||
end_cond = input_ids == im_end_id
|
||||
if slice_start_id is not None:
|
||||
start_cond |= input_ids == slice_start_id[0]
|
||||
end_cond |= input_ids == slice_end_id[0]
|
||||
start_cond |= input_ids == slice_start_id
|
||||
end_cond |= input_ids == slice_end_id
|
||||
|
||||
(image_start_tokens,) = torch.where(start_cond)
|
||||
image_start_tokens += 1
|
||||
@@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
if (
|
||||
len(image_start_tokens) + 1 == len(image_end_tokens)
|
||||
and input_ids[0] in pad_values
|
||||
and len(image_start_tokens) != 0
|
||||
and len(image_end_tokens) != 0
|
||||
and image_end_tokens[0] < image_start_tokens[0]
|
||||
):
|
||||
image_start_tokens = torch.cat(
|
||||
@@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
|
||||
None
|
||||
]:
|
||||
if (
|
||||
forward_batch.image_inputs is not None
|
||||
and len(forward_batch.image_inputs) > 0
|
||||
and forward_batch.image_inputs[0] is not None
|
||||
):
|
||||
# TODO: bath
|
||||
kwargs.update(
|
||||
{
|
||||
"pixel_values": (
|
||||
@@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
return self.resampler(vision_embedding, tgt_sizes)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
if not isinstance(image_inputs.im_start_id, list) or not isinstance(
|
||||
image_inputs.im_end_id, list
|
||||
):
|
||||
return input_ids
|
||||
|
||||
new_input_ids = []
|
||||
last_idx = 0
|
||||
image_idx = -1
|
||||
image_inputs.image_offsets = []
|
||||
|
||||
# Get all special token IDs
|
||||
im_start_id = (
|
||||
image_inputs.im_start_id[0].item()
|
||||
if isinstance(image_inputs.im_start_id[0], torch.Tensor)
|
||||
else image_inputs.im_start_id[0]
|
||||
)
|
||||
im_end_id = (
|
||||
image_inputs.im_end_id[0].item()
|
||||
if isinstance(image_inputs.im_end_id[0], torch.Tensor)
|
||||
else image_inputs.im_end_id[0]
|
||||
)
|
||||
slice_start_id = (
|
||||
image_inputs.slice_start_id[0].item()
|
||||
if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
|
||||
else image_inputs.slice_start_id[0]
|
||||
)
|
||||
slice_end_id = (
|
||||
image_inputs.slice_end_id[0].item()
|
||||
if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
|
||||
else image_inputs.slice_end_id[0]
|
||||
)
|
||||
im_start_id: int = image_inputs.im_start_id
|
||||
im_end_id: int = image_inputs.im_end_id
|
||||
slice_start_id: int = image_inputs.slice_start_id
|
||||
slice_end_id: int = image_inputs.slice_end_id
|
||||
|
||||
# Find all start and end positions for both types
|
||||
start_indices = [
|
||||
i
|
||||
for i, x in enumerate(input_ids)
|
||||
if x == im_start_id or x == slice_start_id
|
||||
]
|
||||
end_indices = [
|
||||
i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
|
||||
]
|
||||
media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
|
||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||
|
||||
if len(start_indices) != len(end_indices):
|
||||
return input_ids
|
||||
# Process each region (both image and slice)
|
||||
for start_idx, end_idx in zip(start_indices, end_indices):
|
||||
# Add non-image tokens before this region
|
||||
new_input_ids.extend(
|
||||
input_ids[last_idx : start_idx + 1]
|
||||
) # include start token
|
||||
|
||||
is_image_start = input_ids[start_idx] == im_start_id
|
||||
|
||||
if is_image_start:
|
||||
image_inputs.image_offsets += [start_idx]
|
||||
image_idx += 1
|
||||
|
||||
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
|
||||
|
||||
# Generate pad_ids
|
||||
pad_values = [image_inputs.pad_values[image_idx]]
|
||||
|
||||
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
|
||||
pad_ids = pad_ids[:num_tokens]
|
||||
|
||||
# Add pad_ids
|
||||
new_input_ids.extend(pad_ids)
|
||||
|
||||
# Update last_idx to after end token
|
||||
last_idx = end_idx
|
||||
|
||||
# Add remaining tokens after last region
|
||||
new_input_ids.extend(input_ids[last_idx:])
|
||||
assert len(input_ids) == len(new_input_ids)
|
||||
return new_input_ids
|
||||
return pattern.pad_input_tokens(input_ids, image_inputs)
|
||||
|
||||
|
||||
_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
|
||||
|
||||
Reference in New Issue
Block a user