refactor: bug fixes and refactor for vlm (#4661)
@@ -1,34 +1,16 @@
-import collections
-import itertools
-import math
-import warnings
-from enum import Enum
-from functools import partial
-from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, List, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from torch import nn

-from sglang.srt.configs import DeepseekVL2Config
+from sglang.srt.configs.deepseekvl2 import (
+    DeepseekVL2Config,
+    DeepseekVL2MlpProjectorConfig,
+)
 from sglang.srt.layers.attention.vision import VisionAttention
-from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import (
-    ColumnParallelLinear,
-    LinearBase,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
+from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -233,11 +215,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: object,
     ):
-
         input_embeds = self.language_model.model.embed_tokens(input_ids)
-        if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [
-            None
-        ]:
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.contains_image_inputs()
+        ):
             extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
             extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
             for idx, image in enumerate(forward_batch.image_inputs):
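The rewritten guard is the bug fix here: comparing `forward_batch.image_inputs` against the literal `[None]` only detects the single-request case, while `contains_image_inputs()` asks whether any request in the batch carries an image. A minimal sketch of that predicate, with a hypothetical `ForwardBatchSketch` standing in for sglang's real `ForwardBatch` (whose implementation may differ):

```python
from typing import List, Optional


class ForwardBatchSketch:
    """Hypothetical stand-in for sglang's ForwardBatch, for illustration only."""

    def __init__(self, image_inputs: List[Optional[object]]):
        self.image_inputs = image_inputs

    def contains_image_inputs(self) -> bool:
        # True iff at least one request in the batch carries an image.
        return any(img is not None for img in self.image_inputs)


assert not ForwardBatchSketch([None]).contains_image_inputs()
assert not ForwardBatchSketch([None, None]).contains_image_inputs()
assert ForwardBatchSketch([None, "img"]).contains_image_inputs()
```

Note how a mixed batch like `[None, "img"]` compares unequal to `[None]` under both schemes, but an all-text multi-request batch `[None, None]` did not, which is the case the predicate handles cleanly.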
@@ -245,17 +227,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
                     continue
                 start_idx = extend_start_loc_cpu[idx]
                 end_idx = start_idx + extend_seq_lens_cpu[idx]
-                pixel_values = image.pixel_values.to(
-                    device="cuda", dtype=torch.bfloat16
-                )
-                image_seq_mask = image.image_seq_mask.to(device="cuda")
-                image_spatial_crop = image.image_spatial_crop
-                input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds(
-                    pixel_values,
-                    image_seq_mask,
-                    image_spatial_crop,
-                    input_embeds[start_idx:end_idx],
-                )
+                images_emb_mask = image.images_emb_mask.to(device="cuda")
+                image_features = self.get_image_feature(image)
+                input_embeds[start_idx:end_idx] = input_embeds[
+                    start_idx:end_idx
+                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)

         outputs = self.language_model.forward(
             input_ids=input_ids,
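Instead of mutating an embedding slice inside `prepare_inputs_embeds`, the refactor computes features via `get_image_feature` and writes them into the text embeddings with `Tensor.masked_scatter`, which fills exactly the positions where the mask is true, in order. A toy, self-contained illustration of that pattern (shapes invented for the example):

```python
import torch

input_embeds = torch.zeros(6, 4)                   # 6 token positions, hidden=4
mask = torch.tensor([False, True, True, False, True, False])
image_features = torch.arange(12.0).reshape(3, 4)  # one row per image token

# unsqueeze(-1) broadcasts the per-token mask across the hidden dim,
# mirroring images_emb_mask.unsqueeze(-1) in the diff above.
out = input_embeds.masked_scatter(mask.unsqueeze(-1), image_features)

assert torch.equal(out[mask], image_features)      # features landed in order
assert torch.equal(out[~mask], torch.zeros(3, 4))  # text positions untouched
```

Using the out-of-place `masked_scatter` (rather than the old in-place `masked_scatter_` buried in the helper) keeps the data flow explicit: feature extraction and embedding assembly are now separate steps.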
@@ -289,20 +265,17 @@ class DeepseekVL2ForCausalLM(nn.Module):
     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
         return input_ids

-    def prepare_inputs_embeds(
-        self,
-        pixel_values,
-        images_seq_mask,
-        images_spatial_crop,
-        input_embeds,
-    ):
+    def get_image_feature(self, image_input: ImageInputs):
+        pixel_values = image_input.pixel_values.type(
+            next(self.vision.parameters()).dtype
+        ).to(device=next(self.vision.parameters()).device)
         image_feature = self.vision.forward_features(pixel_values)
         images_embeds = self.projector(image_feature)
         _, hw, n_dim = images_embeds.shape
         h = w = int(hw**0.5)

         tile_index = 0
         images_in_this_batch = []
+        images_spatial_crop = image_input.image_spatial_crop
         for jdx in range(images_spatial_crop.shape[1]):
             num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
             if num_width_tiles == 0 or num_height_tiles == 0:
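The new `get_image_feature` casts pixel values to the vision tower's own parameter dtype and device (instead of hard-coding `cuda`/`bfloat16` at the call site), projects them, and then treats each tile's `hw` tokens as a square `h x w` patch grid. A small sketch of that square-grid assumption, with invented shapes:

```python
# Sketch only: the real projector output comes from the model; shapes here
# are chosen to mimic a 24 x 24 patch grid per tile.
import torch

num_tiles, hw, n_dim = 3, 576, 8
images_embeds = torch.randn(num_tiles, hw, n_dim)

h = w = int(hw**0.5)
assert h * w == hw, "token count per tile must be a perfect square"

# View each tile's flat tokens as a spatial grid, as the tile loop does
# before stitching global and local crops together.
grid = images_embeds.view(num_tiles, h, w, n_dim)
print(grid.shape)  # torch.Size([3, 24, 24, 8])
```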
@@ -379,13 +352,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
                 images_in_this_batch.append(global_local_features)

-        if len(images_in_this_batch) > 0:
-            images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
-            input_embeds.masked_scatter_(
-                images_seq_mask.unsqueeze(-1), images_in_this_batch
-            )
-
-        return input_embeds
+        return torch.cat(images_in_this_batch, dim=0)


 EntryClass = DeepseekVL2ForCausalLM