refactor: bug fixes and refactor for vlm (#4661)

This commit is contained in:
Mick
2025-03-23 13:48:49 +08:00
committed by GitHub
parent ca75741e86
commit 11577cedb7
31 changed files with 770 additions and 735 deletions

View File

@@ -77,6 +77,7 @@ global_server_args_dict = {
"enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
}
logger = logging.getLogger(__name__)
@@ -160,7 +161,8 @@ class ImageInputs:
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# Qwen2-VL related
image_grid_thws: List[Tuple[int, int, int]] = None
# one (t, h, w) row per image, i.e. shape [num_of_images, 3]
image_grid_thws: torch.Tensor = None
mrope_position_delta: Optional[torch.Tensor] = None
# Qwen2-VL video related
video_token_id: Optional[int] = None
@@ -168,7 +170,7 @@ class ImageInputs:
second_per_grid_ts: Optional[List[torch.Tensor]] = None
# DeepSeek-VL2 related
image_seq_mask: Optional[List[torch.Tensor]] = None
images_emb_mask: Optional[List[torch.Tensor]] = None
image_spatial_crop: Optional[List[torch.Tensor]] = None
# The id of the single-image placeholder token
@@ -182,9 +184,6 @@ class ImageInputs:
slice_end_id: Optional[int] = None
tgt_sizes: Optional[list] = None
# denotes the number of valid image tokens in each image
images_emb_mask: Optional[torch.BoolTensor] = None
@staticmethod
def from_dict(obj: dict):
ret = ImageInputs(
@@ -204,7 +203,7 @@ class ImageInputs:
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"image_seq_mask",
"images_emb_mask",
"image_spatial_crop",
"im_token_id",
"im_start_id",
@@ -212,20 +211,58 @@ class ImageInputs:
"slice_start_id",
"slice_end_id",
"tgt_sizes",
"images_emb_mask",
]
for arg in optional_args:
if arg in obj:
setattr(ret, arg, obj[arg])
# validate
assert (
isinstance(ret.pixel_values, torch.Tensor)
or isinstance(ret.pixel_values, np.ndarray)
or isinstance(ret.pixel_values, list)
)
return ret
def merge(self, other):
def merge(self, other: ImageInputs):
"""
merge image inputs when requests are being merged
"""
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
if isinstance(self.pixel_values, list):
# In some rare cases, pixel values are a list of patches with different shapes,
# e.g. minicpm
self.pixel_values += other.pixel_values
else:
assert (
self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
# These args are concatenated along the first dim;
# usually they are already tensors.
stack_args = [
# TODO: merge with image_grid_thws, basically the same thing
"tgt_sizes",
"image_spatial_crop",
]
for arg in stack_args:
if getattr(self, arg, None) is None:
setattr(self, arg, getattr(other, arg, None))
elif getattr(other, arg, None) is not None:
# self and other both not None
setattr(
self,
arg,
torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
)
if self.image_grid_thws is None:
self.image_grid_thws = other.image_grid_thws
elif other.image_grid_thws is not None:
self.image_grid_thws = torch.concat(
[self.image_grid_thws, other.image_grid_thws]
)
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if `input_ids` is later used in the model forward,
@@ -233,7 +270,7 @@ class ImageInputs:
# errors in cuda kernels. See also llava.py for example.
self.image_hashes += other.image_hashes
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
# args that need to be merged
optional_args = [
"image_sizes",
"image_offsets",
@@ -241,13 +278,13 @@ class ImageInputs:
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"image_seq_mask",
"image_spatial_crop",
"images_emb_mask",
]
for arg in optional_args:
if getattr(self, arg, None) is not None:
setattr(self, arg, getattr(self, arg) + getattr(other, arg))
self_arg = getattr(self, arg, None)
if self_arg is not None:
setattr(self, arg, self_arg + getattr(other, arg))
# other args are kept intact
class Req: