refactor: move image processors to separate files (#4229)
@@ -47,6 +47,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.multi_modality_padding import (
+    MultiModalityDataPaddingPatternTokenPairs,
+)
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -121,12 +124,12 @@ class Qwen2_5_VisionBlock(nn.Module):
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
         if attn_implementation == "sdpa":
             use_context_forward = False
-            use_full_precision_softmax = False
+            softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
-            use_full_precision_softmax = False
+            softmax_in_single_precision = False
             use_context_forward = True
         elif attn_implementation == "eager":
-            use_full_precision_softmax = True
+            softmax_in_single_precision = True
             use_context_forward = False

         self.attn = VisionAttention(
@@ -135,7 +138,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             projection_size=dim,
             use_qkv_parallel=False,
             use_context_forward=use_context_forward,
-            use_full_precision_softmax=use_full_precision_softmax,
+            softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
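Across these two hunks the `use_full_precision_softmax` flag is renamed to `softmax_in_single_precision` with the same truth values per `attn_implementation`, so only the name changes, not the behavior. The meaning of the flag is inferred here from its name; as an illustrative sketch only (not the actual `VisionAttention` code), such a flag usually upcasts the attention scores to float32 before the softmax and casts back afterwards:

```python
import torch


def softmax_maybe_fp32(scores: torch.Tensor, softmax_in_single_precision: bool) -> torch.Tensor:
    """Illustrative only: optionally run the softmax in fp32 for numerical stability."""
    if softmax_in_single_precision:
        return torch.softmax(scores.float(), dim=-1).to(scores.dtype)
    return torch.softmax(scores, dim=-1)
```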
@@ -149,12 +152,17 @@ class Qwen2_5_VisionBlock(nn.Module):
         )

     def forward(
-        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         hidden_states = self.norm1(x)
         hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
         attn = self.attn(
-            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            hidden_states,
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
         )
         attn = rearrange(attn, "b s ... -> s b ...")
         x = x + attn
@@ -443,6 +451,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())

         # compute cu_seqlens
         cu_seqlens = torch.repeat_interleave(
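Instead of handing each block the raw rotary frequencies, the vision transformer now precomputes the (cos, sin) pair once and threads it through as `position_embeddings`. The duplicate-then-cos/sin layout matches the common rotate-half RoPE formulation; the sketch below shows how such a pair is typically applied to queries and keys. This is an assumption about how `VisionAttention` consumes the tuple, not its actual code:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in two and swap the halves with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    # q, k: (..., seq, head_dim); cos, sin broadcast over the head dimension.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```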
@@ -457,7 +467,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 cu_seqlens_now = cu_seqlens
             else:
                 cu_seqlens_now = cu_window_seqlens
-            x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)
+            x = blk(
+                x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
+            )

         # adapter
         x = self.merger(x)
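Each block still receives `cu_seqlens_now`, which selects either the full-attention or the window-attention boundaries for that layer. As a reminder of the convention (a small made-up example, not values from this model), `cu_seqlens` is the cumulative-length vector used by variable-length attention kernels: segment i occupies rows `cu_seqlens[i]` to `cu_seqlens[i + 1]`.

```python
import torch
import torch.nn.functional as F

seqlens = torch.tensor([4, 7, 5])                       # made-up per-segment lengths
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0))
print(cu_seqlens)                                       # tensor([ 0,  4, 11, 16])
print(cu_seqlens[1:] - cu_seqlens[:-1])                 # recovers the lengths [4, 7, 5]
```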
@@ -522,50 +534,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return num_image_tokens

     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
-        new_input_ids = []
-        last_idx = 0
-        image_idx = -1
-        image_inputs.image_offsets = []
-
         # Get all special token IDs
-        im_start_id = image_inputs.im_start_id
-        im_end_id = image_inputs.im_end_id
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id

-        # Find all start and end positions for both types
-        start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
-        end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]
+        media_token_pairs = [(im_start_id, im_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

-        if len(start_indices) != len(end_indices):
-            return input_ids
-        # Process each region (both image and slice)
-        for start_idx, end_idx in zip(start_indices, end_indices):
-            # Add non-image tokens before this region
-            new_input_ids.extend(input_ids[last_idx : start_idx + 1])
-
-            is_image_start = input_ids[start_idx] == im_start_id
-
-            if is_image_start:
-                image_inputs.image_offsets += [start_idx]
-                image_idx += 1
-
-            num_tokens = end_idx - start_idx - 1  # exclude start and end tokens
-
-            # Generate pad_ids
-            pad_values = [image_inputs.pad_values[image_idx]]
-
-            pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
-            pad_ids = pad_ids[:num_tokens]
-
-            # Add pad_ids
-            new_input_ids.extend(pad_ids)
-
-            # Update last_idx to after end token
-            last_idx = end_idx
-
-        # Add remaining tokens after last region
-        new_input_ids.extend(input_ids[last_idx:])
-        assert len(input_ids) == len(new_input_ids)
-        return new_input_ids
+        return pattern.pad_input_tokens(input_ids, image_inputs)

     def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
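The hand-written padding loop is replaced by the shared `MultiModalityDataPaddingPatternTokenPairs` helper, which covers the same job: locate each (im_start_id, im_end_id) pair, record the image offsets, and overwrite the tokens between the markers with the image's pad value. Below is a simplified single-pad-value sketch of that idea, grounded in the removed loop; the real helper in `sglang.srt.managers.multi_modality_padding` also tracks per-image pad values and offsets:

```python
from typing import List, Tuple


def pad_between_token_pairs(
    input_ids: List[int], token_pairs: List[Tuple[int, int]], pad_value: int
) -> List[int]:
    """Replace every token strictly between a (start, end) marker pair with pad_value,
    keeping the markers and the overall sequence length unchanged."""
    starts = {s for s, _ in token_pairs}
    ends = {e for _, e in token_pairs}
    out, in_region = [], False
    for tok in input_ids:
        if tok in starts:
            in_region = True
            out.append(tok)
        elif tok in ends:
            in_region = False
            out.append(tok)
        else:
            out.append(pad_value if in_region else tok)
    return out


# e.g. pad_between_token_pairs([1, 10, 5, 5, 11, 2], [(10, 11)], 0) -> [1, 10, 0, 0, 11, 2]
```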
@@ -629,7 +605,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
            for i, image in enumerate(forward_batch.image_inputs):
-               if image is None:
+               if image is None or image.pixel_values is None:
                    continue
                start_idx = extend_start_loc_cpu[i]
                prefix_len = prefix_lens_cpu[i]