[feat] Support session control for vision language models (#2210)

This commit is contained in:
Ying Sheng
2024-11-27 00:03:29 -08:00
committed by GitHub
parent c754652fcd
commit 37c8a5761f
7 changed files with 265 additions and 21 deletions

View File

@@ -31,6 +31,7 @@ import dataclasses
import logging
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import triton
import triton.language as tl
@@ -167,6 +168,30 @@ class ImageInputs:
return ret
def merge(self, other, vocab_size):
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
self.image_hashes += other.image_hashes
self.pad_values = [
(self.image_hashes) % vocab_size,
(self.image_hashes >> 16) % vocab_size,
(self.image_hashes >> 32) % vocab_size,
(self.image_hashes >> 64) % vocab_size,
]
optional_args = [
"image_sizes",
"image_offsets",
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
]
for arg in optional_args:
if getattr(self, arg, None) is not None:
setattr(self, arg, getattr(self, arg) + getattr(other, arg))
class Req:
"""The input and output status of a request."""
@@ -177,6 +202,7 @@ class Req:
origin_input_text: str,
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None,
@@ -184,7 +210,11 @@ class Req:
# Input and output info
self.rid = rid
self.origin_input_text = origin_input_text
self.origin_input_ids_unpadded = origin_input_ids # Before image padding
self.origin_input_ids_unpadded = (
origin_input_ids_unpadded
if origin_input_ids_unpadded
else origin_input_ids # Before image padding
)
self.origin_input_ids = origin_input_ids
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
@@ -260,6 +290,12 @@ class Req:
# The number of cached tokens, that were already cached in the KV cache
self.cached_tokens = 0
def extend_image_inputs(self, image_inputs, vocab_size):
if self.image_inputs is None:
self.image_inputs = image_inputs
else:
self.image_inputs.merge(image_inputs, vocab_size)
# whether request reached finished condition
def finished(self) -> bool:
return self.finished_reason is not None