refactor: bug fixes and refactor for vlm (#4661)

This commit is contained in:
Mick
2025-03-23 13:48:49 +08:00
committed by GitHub
parent ca75741e86
commit 11577cedb7
31 changed files with 770 additions and 735 deletions

View File

@@ -1,34 +1,16 @@
import collections
import itertools
import math
import warnings
from enum import Enum
from functools import partial
from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from sglang.srt.configs import DeepseekVL2Config
from sglang.srt.configs.deepseekvl2 import (
DeepseekVL2Config,
DeepseekVL2MlpProjectorConfig,
)
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
LinearBase,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -233,11 +215,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
**kwargs: object,
):
input_embeds = self.language_model.model.embed_tokens(input_ids)
if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [
None
]:
if (
forward_batch.forward_mode.is_extend()
and forward_batch.contains_image_inputs()
):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
for idx, image in enumerate(forward_batch.image_inputs):
@@ -245,17 +227,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
continue
start_idx = extend_start_loc_cpu[idx]
end_idx = start_idx + extend_seq_lens_cpu[idx]
pixel_values = image.pixel_values.to(
device="cuda", dtype=torch.bfloat16
)
image_seq_mask = image.image_seq_mask.to(device="cuda")
image_spatial_crop = image.image_spatial_crop
input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds(
pixel_values,
image_seq_mask,
image_spatial_crop,
input_embeds[start_idx:end_idx],
)
images_emb_mask = image.images_emb_mask.to(device="cuda")
image_features = self.get_image_feature(image)
input_embeds[start_idx:end_idx] = input_embeds[
start_idx:end_idx
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
outputs = self.language_model.forward(
input_ids=input_ids,
@@ -289,20 +265,17 @@ class DeepseekVL2ForCausalLM(nn.Module):
# Pass-through hook: hands back the token id list it was given, unmodified.
#   input_ids    -- token ids for one request; returned as-is (same object).
#   image_inputs -- accepted by the signature but not read by this body.
# NOTE(review): no image-placeholder padding is performed here; presumably the
# image positions are handled elsewhere for this model -- confirm against callers.
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
return input_ids
def prepare_inputs_embeds(
self,
pixel_values,
images_seq_mask,
images_spatial_crop,
input_embeds,
):
def get_image_feature(self, image_input: ImageInputs):
pixel_values = image_input.pixel_values.type(
next(self.vision.parameters()).dtype
).to(device=next(self.vision.parameters()).device)
image_feature = self.vision.forward_features(pixel_values)
images_embeds = self.projector(image_feature)
_, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5)
tile_index = 0
images_in_this_batch = []
images_spatial_crop = image_input.image_spatial_crop
for jdx in range(images_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
@@ -379,13 +352,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
images_in_this_batch.append(global_local_features)
if len(images_in_this_batch) > 0:
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
input_embeds.masked_scatter_(
images_seq_mask.unsqueeze(-1), images_in_this_batch
)
return input_embeds
return torch.cat(images_in_this_batch, dim=0)
# Module-level name bound to the model class declared in this file.
# NOTE(review): presumably looked up by the model loader/registry to discover
# which class this module provides -- confirm against the loader.
EntryClass = DeepseekVL2ForCausalLM