refactor: bug fixes and refactor for vlm (#4661)
This commit is contained in:
@@ -77,6 +77,7 @@ global_server_args_dict = {
|
||||
"enable_flashmla": ServerArgs.enable_flashmla,
|
||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -160,7 +161,8 @@ class ImageInputs:
|
||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# QWen2-VL related
|
||||
image_grid_thws: List[Tuple[int, int, int]] = None
|
||||
# [num_of_images, t, h, w]
|
||||
image_grid_thws: torch.Tensor = None
|
||||
mrope_position_delta: Optional[torch.Tensor] = None
|
||||
# Qwen2-VL video related
|
||||
video_token_id: Optional[int] = None
|
||||
@@ -168,7 +170,7 @@ class ImageInputs:
|
||||
second_per_grid_ts: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# deepseek vl2 related
|
||||
image_seq_mask: Optional[List[torch.Tensor]] = None
|
||||
images_emb_mask: Optional[List[torch.Tensor]] = None
|
||||
image_spatial_crop: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# The id of the single-image placeholder token
|
||||
@@ -182,9 +184,6 @@ class ImageInputs:
|
||||
slice_end_id: Optional[int] = None
|
||||
tgt_sizes: Optional[list] = None
|
||||
|
||||
# denotes the number of valid image tokens in each image
|
||||
images_emb_mask: Optional[torch.BoolTensor] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj: dict):
|
||||
ret = ImageInputs(
|
||||
@@ -204,7 +203,7 @@ class ImageInputs:
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"image_grid_thws",
|
||||
"image_seq_mask",
|
||||
"images_emb_mask",
|
||||
"image_spatial_crop",
|
||||
"im_token_id",
|
||||
"im_start_id",
|
||||
@@ -212,20 +211,58 @@ class ImageInputs:
|
||||
"slice_start_id",
|
||||
"slice_end_id",
|
||||
"tgt_sizes",
|
||||
"images_emb_mask",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if arg in obj:
|
||||
setattr(ret, arg, obj[arg])
|
||||
|
||||
# validate
|
||||
assert (
|
||||
isinstance(ret.pixel_values, torch.Tensor)
|
||||
or isinstance(ret.pixel_values, np.ndarray)
|
||||
or isinstance(ret.pixel_values, list)
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
def merge(self, other):
|
||||
def merge(self, other: ImageInputs):
|
||||
"""
|
||||
merge image inputs when requests are being merged
|
||||
"""
|
||||
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
||||
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
||||
if isinstance(self.pixel_values, list):
|
||||
# in some rare cases, pixel values are list of patches with different shapes
|
||||
# e.g. minicpm
|
||||
self.pixel_values += other.pixel_values
|
||||
else:
|
||||
assert (
|
||||
self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
||||
), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
|
||||
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
||||
|
||||
# args would be stacked along first dim
|
||||
# usually these are already tensors
|
||||
stack_args = [
|
||||
# TODO: merge with image_grid_thws, basically the same thing
|
||||
"tgt_sizes",
|
||||
"image_spatial_crop",
|
||||
]
|
||||
for arg in stack_args:
|
||||
if getattr(self, arg, None) is None:
|
||||
setattr(self, arg, getattr(other, arg, None))
|
||||
elif getattr(other, arg, None) is not None:
|
||||
# self and other both not None
|
||||
setattr(
|
||||
self,
|
||||
arg,
|
||||
torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
|
||||
)
|
||||
|
||||
if self.image_grid_thws is None:
|
||||
self.image_grid_thws = other.image_grid_thws
|
||||
elif other.image_grid_thws is not None:
|
||||
self.image_grid_thws = torch.concat(
|
||||
[self.image_grid_thws, other.image_grid_thws]
|
||||
)
|
||||
|
||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||
# Please note that if the `input_ids` is later used in the model forward,
|
||||
@@ -233,7 +270,7 @@ class ImageInputs:
|
||||
# errors in cuda kernels. See also llava.py for example.
|
||||
self.image_hashes += other.image_hashes
|
||||
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
|
||||
|
||||
# args needed to be merged
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
"image_offsets",
|
||||
@@ -241,13 +278,13 @@ class ImageInputs:
|
||||
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"image_grid_thws",
|
||||
"image_seq_mask",
|
||||
"image_spatial_crop",
|
||||
"images_emb_mask",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if getattr(self, arg, None) is not None:
|
||||
setattr(self, arg, getattr(self, arg) + getattr(other, arg))
|
||||
self_arg = getattr(self, arg, None)
|
||||
if self_arg is not None:
|
||||
setattr(self, arg, self_arg + getattr(other, arg))
|
||||
# other args would be kept intact
|
||||
|
||||
|
||||
class Req:
|
||||
|
||||
Reference in New Issue
Block a user