refactor: multimodal data (#4754)

This commit is contained in:
Mick
2025-04-01 00:57:51 +08:00
committed by GitHub
parent c7457191a0
commit 5cb552b1d4
36 changed files with 989 additions and 1138 deletions

View File

@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
)
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternImageTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
return x
# todo
class DeepseekVL2ForCausalLM(nn.Module):
def __init__(
@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
**kwargs: object,
):
input_embeds = self.language_model.model.embed_tokens(input_ids)
if (
forward_batch.forward_mode.is_extend()
and forward_batch.contains_image_inputs()
):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
for idx, image in enumerate(forward_batch.mm_inputs):
if image is None:
continue
start_idx = extend_start_loc_cpu[idx]
end_idx = start_idx + extend_seq_lens_cpu[idx]
images_emb_mask = image.images_emb_mask.to(device="cuda")
image_features = self.get_image_feature(image)
input_embeds[start_idx:end_idx] = input_embeds[
start_idx:end_idx
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
outputs = self.language_model.forward(
hs = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
input_embeds=input_embeds,
image_data_embedding_func=self.get_image_feature,
language_model=self.language_model,
)
return outputs
return hs
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
weights_loader(param, loaded_weight)
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
return input_ids
helper = MultiModalityDataPaddingPatternImageTokens(
image_token_id=image_inputs.im_token_id
)
return helper.pad_input_tokens(input_ids, image_inputs)
def get_image_feature(self, image_input: MultimodalInputs):
pixel_values = image_input.pixel_values.type(
next(self.vision.parameters()).dtype
).to(device=next(self.vision.parameters()).device)
image_feature = self.vision.forward_features(pixel_values)
images_embeds = self.projector(image_feature)
_, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5)
tile_index = 0
def get_image_feature(self, items: List[MultimodalDataItem]):
images_spatial_crop = torch.cat(
[item.image_spatial_crop for item in items], dim=0
)
assert images_spatial_crop.dim() == 3
# TODO: can it be batched ?
images_in_this_batch = []
images_spatial_crop = image_input.image_spatial_crop
for jdx in range(images_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
break
num_tiles_in_image = num_width_tiles * num_height_tiles
# [hw, D]
global_features = images_embeds[tile_index]
# [num_height_tiles * num_width_tiles, hw, D]
local_features = images_embeds[
tile_index + 1 : tile_index + 1 + num_tiles_in_image
]
tile_index += num_tiles_in_image + 1
# format global and local features
# ----------------- global view add newline -----------------
# [hw, D] -> [h, w, D]
global_features = global_features.view(h, w, n_dim)
# [D] -> [h, 1, D]
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
global_features = torch.cat([global_features, new_lines_in_global], dim=1)
# [h, w + 1, D] -> [h * (w + 1), D]
global_features = global_features.view(-1, n_dim)
# ----------------- local view add newline -----------------
# [num_height_tiles * num_width_tiles, h * w, D] ->
# [num_height_tiles * h, num_width_tiles * w, D]
local_features = rearrange(
local_features,
"(th tw) (h w) d -> (th h) (tw w) d",
th=num_height_tiles,
tw=num_width_tiles,
h=h,
w=w,
)
# [D] -> [num_height_tiles * h, 1, D]
new_lines_in_local = repeat(
self.image_newline,
"d -> (th h) 1 d",
th=num_height_tiles,
h=h,
)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
local_features = local_features.view(-1, n_dim)
# merge global and local tiles
if self.global_view_pos == "head":
global_local_features = torch.cat(
[
global_features,
self.view_seperator[None, :],
local_features,
]
for item in items:
assert item.pixel_values.dim() == 4
image_feature = self.vision.forward_features(
item.pixel_values.type(next(self.vision.parameters()).dtype).to(
device=next(self.vision.parameters()).device
)
else:
global_local_features = torch.cat(
[
local_features,
self.view_seperator[None, :],
global_features,
]
)
images_embeds = self.projector(image_feature)
_, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5)
tile_index = 0
for jdx in range(item.image_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
break
num_tiles_in_image = num_width_tiles * num_height_tiles
# [hw, D]
global_features = images_embeds[tile_index]
# [num_height_tiles * num_width_tiles, hw, D]
local_features = images_embeds[
tile_index + 1 : tile_index + 1 + num_tiles_in_image
]
tile_index += num_tiles_in_image + 1
# format global and local features
# ----------------- global view add newline -----------------
# [hw, D] -> [h, w, D]
global_features = global_features.view(h, w, n_dim)
# [D] -> [h, 1, D]
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
global_features = torch.cat(
[global_features, new_lines_in_global], dim=1
)
images_in_this_batch.append(global_local_features)
# [h, w + 1, D] -> [h * (w + 1), D]
global_features = global_features.view(-1, n_dim)
# ----------------- local view add newline -----------------
# [num_height_tiles * num_width_tiles, h * w, D] ->
# [num_height_tiles * h, num_width_tiles * w, D]
local_features = rearrange(
local_features,
"(th tw) (h w) d -> (th h) (tw w) d",
th=num_height_tiles,
tw=num_width_tiles,
h=h,
w=w,
)
# [D] -> [num_height_tiles * h, 1, D]
new_lines_in_local = repeat(
self.image_newline,
"d -> (th h) 1 d",
th=num_height_tiles,
h=h,
)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
local_features = local_features.view(-1, n_dim)
# merge global and local tiles
if self.global_view_pos == "head":
global_local_features = torch.cat(
[
global_features,
self.view_seperator[None, :],
local_features,
]
)
else:
global_local_features = torch.cat(
[
local_features,
self.view_seperator[None, :],
global_features,
]
)
images_in_this_batch.append(global_local_features)
return torch.cat(images_in_this_batch, dim=0)