[feat] Support session control for vision language models (#2210)
This commit is contained in:
@@ -131,6 +131,7 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
modalities = request_obj.modalities or ["image"]
|
||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
grid_pinpoints = (
|
||||
self.hf_config.image_grid_pinpoints
|
||||
@@ -139,9 +140,12 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
else None
|
||||
)
|
||||
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
if isinstance(image_data, list) and len(image_data) > 0:
|
||||
# Multiple images
|
||||
if len(image_data) > 1:
|
||||
if "multi-images" in modalities or "video" in modalities:
|
||||
# Multiple images
|
||||
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
|
||||
pixel_values, image_hashes, image_sizes = [], [], []
|
||||
res = []
|
||||
@@ -166,13 +170,6 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
)
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
elif isinstance(image_data, str):
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
else:
|
||||
raise ValueError(f"Invalid image data: {image_data}")
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ import dataclasses
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -167,6 +168,30 @@ class ImageInputs:
|
||||
|
||||
return ret
|
||||
|
||||
def merge(self, other, vocab_size):
|
||||
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
||||
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
||||
self.image_hashes += other.image_hashes
|
||||
|
||||
self.pad_values = [
|
||||
(self.image_hashes) % vocab_size,
|
||||
(self.image_hashes >> 16) % vocab_size,
|
||||
(self.image_hashes >> 32) % vocab_size,
|
||||
(self.image_hashes >> 64) % vocab_size,
|
||||
]
|
||||
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
"image_offsets",
|
||||
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"image_grid_thws",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if getattr(self, arg, None) is not None:
|
||||
setattr(self, arg, getattr(self, arg) + getattr(other, arg))
|
||||
|
||||
|
||||
class Req:
|
||||
"""The input and output status of a request."""
|
||||
@@ -177,6 +202,7 @@ class Req:
|
||||
origin_input_text: str,
|
||||
origin_input_ids: Tuple[int],
|
||||
sampling_params: SamplingParams,
|
||||
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
input_embeds: Optional[List[List[float]]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
@@ -184,7 +210,11 @@ class Req:
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
self.origin_input_text = origin_input_text
|
||||
self.origin_input_ids_unpadded = origin_input_ids # Before image padding
|
||||
self.origin_input_ids_unpadded = (
|
||||
origin_input_ids_unpadded
|
||||
if origin_input_ids_unpadded
|
||||
else origin_input_ids # Before image padding
|
||||
)
|
||||
self.origin_input_ids = origin_input_ids
|
||||
self.output_ids = [] # Each decode stage's output ids
|
||||
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
||||
@@ -260,6 +290,12 @@ class Req:
|
||||
# The number of cached tokens, that were already cached in the KV cache
|
||||
self.cached_tokens = 0
|
||||
|
||||
def extend_image_inputs(self, image_inputs, vocab_size):
|
||||
if self.image_inputs is None:
|
||||
self.image_inputs = image_inputs
|
||||
else:
|
||||
self.image_inputs.merge(image_inputs, vocab_size)
|
||||
|
||||
# whether request reached finished condition
|
||||
def finished(self) -> bool:
|
||||
return self.finished_reason is not None
|
||||
|
||||
@@ -559,12 +559,13 @@ class Scheduler:
|
||||
|
||||
# Image inputs
|
||||
if recv_req.image_inputs is not None:
|
||||
req.image_inputs = ImageInputs.from_dict(
|
||||
image_inputs = ImageInputs.from_dict(
|
||||
recv_req.image_inputs, self.model_config.vocab_size
|
||||
)
|
||||
req.origin_input_ids = self.pad_input_ids_func(
|
||||
req.origin_input_ids_unpadded, req.image_inputs
|
||||
req.origin_input_ids, image_inputs
|
||||
)
|
||||
req.extend_image_inputs(image_inputs, self.model_config.vocab_size)
|
||||
|
||||
if len(req.origin_input_ids) > self.max_req_input_len:
|
||||
req.finished_reason = FINISH_ABORT(
|
||||
|
||||
@@ -41,16 +41,27 @@ class Session:
|
||||
]
|
||||
+ req.input_ids
|
||||
)
|
||||
input_ids_unpadded = (
|
||||
self.reqs[-1].origin_input_ids_unpadded
|
||||
+ self.reqs[-1].output_ids[
|
||||
: self.reqs[-1].sampling_params.max_new_tokens
|
||||
]
|
||||
+ req.input_ids
|
||||
)
|
||||
else:
|
||||
input_ids = req.input_ids
|
||||
input_ids_unpadded = req.input_ids
|
||||
new_req = Req(
|
||||
req.rid,
|
||||
None,
|
||||
input_ids,
|
||||
req.sampling_params,
|
||||
rid=req.rid,
|
||||
origin_input_text=None,
|
||||
origin_input_ids=input_ids,
|
||||
origin_input_ids_unpadded=input_ids_unpadded,
|
||||
sampling_params=req.sampling_params,
|
||||
lora_path=req.lora_path,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
if len(self.reqs) > 0:
|
||||
new_req.image_inputs = self.reqs[-1].image_inputs
|
||||
new_req.tokenizer = tokenizer
|
||||
if req.session_rid is not None and len(self.reqs) == 0:
|
||||
new_req.finished_reason = FINISH_ABORT(
|
||||
|
||||
@@ -49,7 +49,13 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
|
||||
|
||||
# hardcode for spatial_unpad + anyres
|
||||
image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
|
||||
if image_inputs.modalities is not None and (
|
||||
"multi-images" in image_inputs.modalities
|
||||
or "video" in image_inputs.modalities
|
||||
):
|
||||
image_aspect_ratio = "pad"
|
||||
else:
|
||||
image_aspect_ratio = "anyres"
|
||||
offset_list = []
|
||||
for image_s in image_sizes:
|
||||
if len(image_sizes) > 16:
|
||||
|
||||
@@ -36,6 +36,7 @@ suites = {
|
||||
"test_triton_attention_backend.py",
|
||||
"test_update_weights.py",
|
||||
"test_vision_openai_server.py",
|
||||
"test_session_control.py",
|
||||
],
|
||||
"sampling/penaltylib": glob.glob(
|
||||
"sampling/penaltylib/**/test_*.py", recursive=True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest test_session_control.TestSessionControl.test_session_control
|
||||
python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm
|
||||
python3 -m unittest test_session_control.TestSessionControlVision.test_session_control
|
||||
"""
|
||||
|
||||
import unittest
|
||||
@@ -61,6 +61,8 @@ class TestSessionControl(unittest.TestCase):
|
||||
"max_new_tokens": (
|
||||
16 if i > 0 else 0
|
||||
), # prefill only for the first chunk
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
@@ -79,6 +81,8 @@ class TestSessionControl(unittest.TestCase):
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
@@ -93,6 +97,8 @@ class TestSessionControl(unittest.TestCase):
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
@@ -113,6 +119,8 @@ class TestSessionControl(unittest.TestCase):
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
@@ -133,13 +141,16 @@ class TestSessionControl(unittest.TestCase):
|
||||
"max_new_tokens": (
|
||||
16 if i > 0 else 0
|
||||
), # prefill only for the first chunk
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
if i > 0:
|
||||
input_ids += tokenizer.encode(response["text"])[
|
||||
1:
|
||||
] # drop the bos token
|
||||
output_ids = tokenizer.encode(response["text"])
|
||||
if output_ids[0] == tokenizer.bos_token_id:
|
||||
output_ids = output_ids[1:]
|
||||
input_ids += output_ids
|
||||
outputs_normal.append(response["text"])
|
||||
if i == 0:
|
||||
input_ids_first_req = input_ids.copy()
|
||||
@@ -152,6 +163,187 @@ class TestSessionControl(unittest.TestCase):
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
outputs_normal.append(response["text"])
|
||||
|
||||
print("outputs from chunked queries with session control:")
|
||||
print(outputs_from_session)
|
||||
print("outputs from normal queries:")
|
||||
print(outputs_normal)
|
||||
assert outputs_from_session == outputs_normal
|
||||
|
||||
|
||||
class TestSessionControlVision(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "lmms-lab/llava-onevision-qwen2-7b-ov"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
# other_args={"--disable-radix"},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_session_control(self):
|
||||
text_chunks = [
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
|
||||
"<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image_chunks = [
|
||||
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png",
|
||||
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
|
||||
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
|
||||
]
|
||||
assert len(text_chunks) == len(image_chunks) + 1
|
||||
tokenizer = get_tokenizer(self.model)
|
||||
text_input_ids = [tokenizer.encode(x) for x in text_chunks]
|
||||
|
||||
# 1. using session control
|
||||
session_id = requests.post(
|
||||
self.base_url + "/open_session",
|
||||
json={"capacity_of_str_len": 1000},
|
||||
).json()
|
||||
rid = None
|
||||
|
||||
first_rid = None
|
||||
outputs_from_session = []
|
||||
for i in range(len(text_input_ids)):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": text_input_ids[i],
|
||||
"image_data": image_chunks[i - 1] if i > 0 else None,
|
||||
"modalities": ["multi-images"],
|
||||
"session": [session_id, rid],
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": (
|
||||
16 if i > 0 else 0
|
||||
), # prefill only for the first chunk
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
rid = response["meta_info"]["id"]
|
||||
if i == 0:
|
||||
first_rid = rid
|
||||
if i > 0:
|
||||
outputs_from_session.append(response["text"])
|
||||
|
||||
# backtrack to the first request and regenerate
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": text_input_ids[-1],
|
||||
"image_data": image_chunks[-1:],
|
||||
"modalities": ["multi-images"],
|
||||
"session": [session_id, first_rid],
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
outputs_from_session.append(response["text"])
|
||||
|
||||
# query with a non-existing rid (the last one should be disappeared becuase of backtrack), should see abort
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": text_input_ids[-1],
|
||||
"image_data": image_chunks[-1:],
|
||||
"modalities": ["multi-images"],
|
||||
"session": [session_id, rid],
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||
|
||||
ret = requests.post(
|
||||
self.base_url + "/close_session",
|
||||
json={"session_id": session_id},
|
||||
)
|
||||
assert ret.status_code == 200
|
||||
|
||||
# send a request to a closed session, should see abort
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": text_input_ids[-1],
|
||||
"session": [session_id, first_rid],
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||
|
||||
# 2. not use session control
|
||||
input_ids_first_req = None
|
||||
input_ids = []
|
||||
outputs_normal = []
|
||||
for i in range(len(text_input_ids)):
|
||||
input_ids += text_input_ids[i]
|
||||
image_data = image_chunks[:i] if i > 0 else None
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": input_ids,
|
||||
"image_data": image_data,
|
||||
"modalities": ["multi-images"],
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": (
|
||||
16 if i > 0 else 0
|
||||
), # prefill only for the first chunk
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
if i > 0:
|
||||
output_ids = tokenizer.encode(response["text"])
|
||||
if output_ids[0] == tokenizer.bos_token_id:
|
||||
output_ids = output_ids[1:]
|
||||
input_ids += output_ids
|
||||
outputs_normal.append(response["text"])
|
||||
if i == 0:
|
||||
input_ids_first_req = input_ids.copy()
|
||||
|
||||
input_ids_first_req += text_input_ids[-1]
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": input_ids_first_req,
|
||||
"image_data": image_chunks[-1:],
|
||||
"modalities": ["multi-images"],
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
},
|
||||
).json()
|
||||
|
||||
Reference in New Issue
Block a user