refactor: multimodal data (#4754)
This commit is contained in:
@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
|
||||
)
|
||||
from sglang.srt.layers.linear import ReplicatedLinear
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternImageTokens,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
||||
@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# todo
|
||||
class DeepseekVL2ForCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
**kwargs: object,
|
||||
):
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and forward_batch.contains_image_inputs()
|
||||
):
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
|
||||
for idx, image in enumerate(forward_batch.mm_inputs):
|
||||
if image is None:
|
||||
continue
|
||||
start_idx = extend_start_loc_cpu[idx]
|
||||
end_idx = start_idx + extend_seq_lens_cpu[idx]
|
||||
images_emb_mask = image.images_emb_mask.to(device="cuda")
|
||||
image_features = self.get_image_feature(image)
|
||||
input_embeds[start_idx:end_idx] = input_embeds[
|
||||
start_idx:end_idx
|
||||
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
|
||||
|
||||
outputs = self.language_model.forward(
|
||||
hs = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
input_embeds=input_embeds,
|
||||
image_data_embedding_func=self.get_image_feature,
|
||||
language_model=self.language_model,
|
||||
)
|
||||
|
||||
return outputs
|
||||
return hs
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
weights_loader(param, loaded_weight)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
return input_ids
|
||||
helper = MultiModalityDataPaddingPatternImageTokens(
|
||||
image_token_id=image_inputs.im_token_id
|
||||
)
|
||||
return helper.pad_input_tokens(input_ids, image_inputs)
|
||||
|
||||
def get_image_feature(self, image_input: MultimodalInputs):
|
||||
pixel_values = image_input.pixel_values.type(
|
||||
next(self.vision.parameters()).dtype
|
||||
).to(device=next(self.vision.parameters()).device)
|
||||
image_feature = self.vision.forward_features(pixel_values)
|
||||
images_embeds = self.projector(image_feature)
|
||||
_, hw, n_dim = images_embeds.shape
|
||||
h = w = int(hw**0.5)
|
||||
tile_index = 0
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]):
|
||||
|
||||
images_spatial_crop = torch.cat(
|
||||
[item.image_spatial_crop for item in items], dim=0
|
||||
)
|
||||
|
||||
assert images_spatial_crop.dim() == 3
|
||||
|
||||
# TODO: can it be batched ?
|
||||
images_in_this_batch = []
|
||||
images_spatial_crop = image_input.image_spatial_crop
|
||||
for jdx in range(images_spatial_crop.shape[1]):
|
||||
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
|
||||
if num_width_tiles == 0 or num_height_tiles == 0:
|
||||
break
|
||||
num_tiles_in_image = num_width_tiles * num_height_tiles
|
||||
|
||||
# [hw, D]
|
||||
global_features = images_embeds[tile_index]
|
||||
|
||||
# [num_height_tiles * num_width_tiles, hw, D]
|
||||
local_features = images_embeds[
|
||||
tile_index + 1 : tile_index + 1 + num_tiles_in_image
|
||||
]
|
||||
tile_index += num_tiles_in_image + 1
|
||||
|
||||
# format global and local features
|
||||
# ----------------- global view add newline -----------------
|
||||
# [hw, D] -> [h, w, D]
|
||||
global_features = global_features.view(h, w, n_dim)
|
||||
|
||||
# [D] -> [h, 1, D]
|
||||
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
|
||||
|
||||
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
|
||||
global_features = torch.cat([global_features, new_lines_in_global], dim=1)
|
||||
|
||||
# [h, w + 1, D] -> [h * (w + 1), D]
|
||||
global_features = global_features.view(-1, n_dim)
|
||||
|
||||
# ----------------- local view add newline -----------------
|
||||
# [num_height_tiles * num_width_tiles, h * w, D] ->
|
||||
# [num_height_tiles * h, num_width_tiles * w, D]
|
||||
local_features = rearrange(
|
||||
local_features,
|
||||
"(th tw) (h w) d -> (th h) (tw w) d",
|
||||
th=num_height_tiles,
|
||||
tw=num_width_tiles,
|
||||
h=h,
|
||||
w=w,
|
||||
)
|
||||
|
||||
# [D] -> [num_height_tiles * h, 1, D]
|
||||
new_lines_in_local = repeat(
|
||||
self.image_newline,
|
||||
"d -> (th h) 1 d",
|
||||
th=num_height_tiles,
|
||||
h=h,
|
||||
)
|
||||
|
||||
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
||||
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
|
||||
|
||||
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
||||
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
|
||||
local_features = local_features.view(-1, n_dim)
|
||||
|
||||
# merge global and local tiles
|
||||
if self.global_view_pos == "head":
|
||||
global_local_features = torch.cat(
|
||||
[
|
||||
global_features,
|
||||
self.view_seperator[None, :],
|
||||
local_features,
|
||||
]
|
||||
for item in items:
|
||||
assert item.pixel_values.dim() == 4
|
||||
image_feature = self.vision.forward_features(
|
||||
item.pixel_values.type(next(self.vision.parameters()).dtype).to(
|
||||
device=next(self.vision.parameters()).device
|
||||
)
|
||||
else:
|
||||
global_local_features = torch.cat(
|
||||
[
|
||||
local_features,
|
||||
self.view_seperator[None, :],
|
||||
global_features,
|
||||
]
|
||||
)
|
||||
images_embeds = self.projector(image_feature)
|
||||
_, hw, n_dim = images_embeds.shape
|
||||
h = w = int(hw**0.5)
|
||||
tile_index = 0
|
||||
for jdx in range(item.image_spatial_crop.shape[1]):
|
||||
num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
|
||||
if num_width_tiles == 0 or num_height_tiles == 0:
|
||||
break
|
||||
num_tiles_in_image = num_width_tiles * num_height_tiles
|
||||
|
||||
# [hw, D]
|
||||
global_features = images_embeds[tile_index]
|
||||
|
||||
# [num_height_tiles * num_width_tiles, hw, D]
|
||||
local_features = images_embeds[
|
||||
tile_index + 1 : tile_index + 1 + num_tiles_in_image
|
||||
]
|
||||
tile_index += num_tiles_in_image + 1
|
||||
|
||||
# format global and local features
|
||||
# ----------------- global view add newline -----------------
|
||||
# [hw, D] -> [h, w, D]
|
||||
global_features = global_features.view(h, w, n_dim)
|
||||
|
||||
# [D] -> [h, 1, D]
|
||||
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
|
||||
|
||||
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
|
||||
global_features = torch.cat(
|
||||
[global_features, new_lines_in_global], dim=1
|
||||
)
|
||||
|
||||
images_in_this_batch.append(global_local_features)
|
||||
# [h, w + 1, D] -> [h * (w + 1), D]
|
||||
global_features = global_features.view(-1, n_dim)
|
||||
|
||||
# ----------------- local view add newline -----------------
|
||||
# [num_height_tiles * num_width_tiles, h * w, D] ->
|
||||
# [num_height_tiles * h, num_width_tiles * w, D]
|
||||
local_features = rearrange(
|
||||
local_features,
|
||||
"(th tw) (h w) d -> (th h) (tw w) d",
|
||||
th=num_height_tiles,
|
||||
tw=num_width_tiles,
|
||||
h=h,
|
||||
w=w,
|
||||
)
|
||||
|
||||
# [D] -> [num_height_tiles * h, 1, D]
|
||||
new_lines_in_local = repeat(
|
||||
self.image_newline,
|
||||
"d -> (th h) 1 d",
|
||||
th=num_height_tiles,
|
||||
h=h,
|
||||
)
|
||||
|
||||
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
||||
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
|
||||
|
||||
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
||||
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
|
||||
local_features = local_features.view(-1, n_dim)
|
||||
|
||||
# merge global and local tiles
|
||||
if self.global_view_pos == "head":
|
||||
global_local_features = torch.cat(
|
||||
[
|
||||
global_features,
|
||||
self.view_seperator[None, :],
|
||||
local_features,
|
||||
]
|
||||
)
|
||||
else:
|
||||
global_local_features = torch.cat(
|
||||
[
|
||||
local_features,
|
||||
self.view_seperator[None, :],
|
||||
global_features,
|
||||
]
|
||||
)
|
||||
|
||||
images_in_this_batch.append(global_local_features)
|
||||
|
||||
return torch.cat(images_in_this_batch, dim=0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user