refactor: bug fixes and refactor for vlm (#4661)
@@ -26,7 +26,6 @@ import logging
 from functools import lru_cache, partial
 from typing import Iterable, List, Optional, Tuple, Type
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -54,14 +53,15 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
+from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
@@ -326,13 +326,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
 
     def get_window_index(self, grid_thw):
+        window_index: list = []
         cu_window_seqlens: list = [0]
         window_index_id = 0
         vit_merger_window_size = (
             self.window_size // self.spatial_merge_size // self.patch_size
         )
-
-        window_index: list = []
         for grid_t, grid_h, grid_w in grid_thw:
             llm_grid_h, llm_grid_w = (
                 grid_h // self.spatial_merge_size,
@@ -369,7 +368,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
             cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
             window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
         window_index = torch.cat(window_index, dim=0)
-
         return window_index, cu_window_seqlens
 
     @property
@@ -382,8 +380,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
-        for t, h, w in grid_thw:
+        for i in range(grid_thw.size(0)):
+            t, h, w = grid_thw[i].tolist()
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+
             hpos_ids = hpos_ids.reshape(
                 h // self.spatial_merge_size,
                 self.spatial_merge_size,
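The rot_pos_emb change replaces tensor-row unpacking with explicit indexing plus .tolist(), so t, h, w become plain Python ints instead of 0-dim tensors before they are fed to arange, reshape, and repeat. A quick standalone illustration with assumed toy values:

    import torch

    grid_thw = torch.tensor([[1, 4, 6], [2, 3, 4]])

    # New style: explicit indexing + .tolist() yields plain Python ints.
    for i in range(grid_thw.size(0)):
        t, h, w = grid_thw[i].tolist()
        print(type(h), h)  # <class 'int'> 4, then <class 'int'> 3

    # Old style: unpacking a tensor row yields 0-dim tensors, not ints.
    for t, h, w in grid_thw:
        print(type(h), h)  # <class 'torch.Tensor'> tensor(4), then tensor(3)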
@@ -402,6 +402,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             )
             wpos_ids = wpos_ids.permute(0, 2, 1, 3)
             wpos_ids = wpos_ids.flatten()
+
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
@@ -443,9 +444,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         position_embeddings = (emb.cos(), emb.sin())
 
         # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
-        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = torch.cat(
+            [
+                torch.tensor([0], device=grid_thw.device),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+            ]
+        )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
 
         # transformers
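The new cu_seqlens formula changes how the sequence is segmented for attention: the old expression put a boundary after every temporal frame (h*w repeated t times), while the new one puts a single boundary per image or video and builds the leading zero in. A minimal sketch contrasting the two expressions, with an assumed grid of one video of t=2 frames, each 3x4 patches:

    import torch
    import torch.nn.functional as F

    grid_thw = torch.tensor([[2, 3, 4]])  # assumed: one video, 2 frames of 3x4 patches

    # Old: boundary per frame; F.pad prepends the leading 0.
    old = torch.repeat_interleave(
        grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
    ).cumsum(dim=0, dtype=torch.int32)
    old = F.pad(old, (1, 0), "constant", 0)
    print(old)  # tensor([ 0, 12, 24], dtype=torch.int32): frames attend separately

    # New: one boundary per image/video, leading 0 already included,
    # so full attention spans all frames of the same video.
    new = torch.cat(
        [
            torch.tensor([0], device=grid_thw.device),
            (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
        ]
    )
    print(new)  # tensor([ 0, 24])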
@@ -509,18 +513,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
-    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
-        processor = cached_get_processor(self.config._name_or_path)
-        grid_t, grid_h, grid_w = image_grid_thw
-        num_image_tokens = (
-            grid_t
-            * grid_h
-            * grid_w
-            // processor.image_processor.merge_size
-            // processor.image_processor.merge_size
-        )
-        return num_image_tokens
-
     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
         # Get all special token IDs
         im_start_id: int = image_inputs.im_start_id
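For reference, the removed calculate_num_image_tokens helper mapped a patch grid to the number of LLM-side image tokens by dividing out the spatial merge factor once per spatial axis. A worked example with assumed values:

    # Assumed: one image (t=1) with a 16x16 patch grid, merge_size=2.
    t, h, w = 1, 16, 16
    merge_size = 2  # processor.image_processor.merge_size in the removed code
    num_image_tokens = t * h * w // merge_size // merge_size
    print(num_image_tokens)  # 64: each 2x2 block of patches becomes one LLM token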
@@ -531,9 +523,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
         return pattern.pad_input_tokens(input_ids, image_inputs)
 
-    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
+    def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
         return image_embeds
 
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -543,6 +535,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         )
         return video_embeds
 
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -565,86 +560,26 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
             positions = forward_batch.mrope_positions
 
-        image_inputs = None
-        if forward_batch.image_inputs is not None:
-            image_inputs = [
-                img for img in forward_batch.image_inputs if img is not None
-            ]
-
-        if (
+        if not (
             forward_batch.forward_mode.is_decode()
-            or image_inputs is None
-            or len(image_inputs) == 0
+            or not forward_batch.contains_image_inputs()
         ):
-            inputs_embeds = self.model.embed_tokens(input_ids)
-        else:
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )
 
-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-            # [B, s, hidden_size]
-            inputs_embeds = self.model.embed_tokens(input_ids)
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-            for i, image in enumerate(forward_batch.image_inputs):
-                if image is None or image.pixel_values is None:
-                    continue
-                start_idx = extend_start_loc_cpu[i]
-                prefix_len = prefix_lens_cpu[i]
-
-                pixel_values = image.pixel_values.to(device="cuda")
-
-                image_grid_thws = torch.tensor(
-                    np.array(image.image_grid_thws), device="cuda"
-                )
-                image_offsets = image.image_offsets
-                image_input = Qwen2VLImageInputs(
-                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
-                )
-                image_embeds = self._process_image_input(image_input)
-
-                image_embeds_offset = 0
-                for idx, image_offset in enumerate(image_offsets):
-                    if image_offset < prefix_len:
-                        continue
-                    num_image_tokens = self.calculate_num_image_tokens(
-                        image_grid_thws[idx]
-                    )
-
-                    left_idx = start_idx + (image_offset - prefix_len)
-                    right_idx = left_idx + num_image_tokens
-
-                    tp_size = get_tensor_model_parallel_world_size()
-
-                    hidden_size = image_embeds.shape[-1]
-
-                    if hidden_size % tp_size != 0:
-                        padding_size = tp_size - (hidden_size % tp_size)
-                        image_embeds = F.pad(image_embeds, (0, padding_size))
-                        inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
-
-                    hidden_chunk_size = image_embeds.shape[-1] // tp_size
-                    rank = get_tensor_model_parallel_rank()
-                    start_dim = rank * hidden_chunk_size
-                    end_dim = (rank + 1) * hidden_chunk_size
-                    inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
-                        image_embeds[
-                            image_embeds_offset : image_embeds_offset
-                            + num_image_tokens,
-                            ...,
-                            start_dim:end_dim,
-                        ]
-                    )
-                    image_embeds_offset += num_image_tokens
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            image_embedding_func=self.get_image_feature,
+        )
 
         hidden_states = self.model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
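The net effect of the forward() rewrite: the hand-rolled loop that stitched vision embeddings into the text embeddings (including the tensor-parallel slicing above) moves into the shared general_mm_embed_routine, and the model now only supplies the get_input_embeddings and get_image_feature callbacks; the language model is then called with input_ids=None and the precomputed input_embeds. A rough sketch of the idea behind the shared routine, under assumed names and shapes, not the actual mm_utils implementation:

    import torch
    import torch.nn as nn

    def mm_embed_sketch(
        input_ids: torch.Tensor,         # [s]; image slots hold radix-cache hash ids
        image_token_mask: torch.Tensor,  # [s] bool; True where image tokens sit
        embed_tokens: nn.Embedding,
        image_embeds: torch.Tensor,      # [num_image_tokens, hidden], e.g. from get_image_feature
    ) -> torch.Tensor:
        # Clamp first: image slots carry out-of-vocab hash values used only for
        # radix-attention prefix matching, so their text embeddings are throwaway.
        safe_ids = input_ids.clamp(min=0, max=embed_tokens.num_embeddings - 1)
        inputs_embeds = embed_tokens(safe_ids)
        # Overwrite the image-token rows with the vision-tower output.
        inputs_embeds[image_token_mask] = image_embeds.to(inputs_embeds.dtype)
        return inputs_embeds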