[feat] Support session control for vision language models (#2210)

This commit is contained in:
Ying Sheng
2024-11-27 00:03:29 -08:00
committed by GitHub
parent c754652fcd
commit 37c8a5761f
7 changed files with 265 additions and 21 deletions

View File

@@ -131,6 +131,7 @@ class LlavaImageProcessor(BaseImageProcessor):
if not image_data: if not image_data:
return None return None
modalities = request_obj.modalities or ["image"]
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = ( grid_pinpoints = (
self.hf_config.image_grid_pinpoints self.hf_config.image_grid_pinpoints
@@ -139,9 +140,12 @@ class LlavaImageProcessor(BaseImageProcessor):
else None else None
) )
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(image_data, list) and len(image_data) > 0: if isinstance(image_data, list) and len(image_data) > 0:
if "multi-images" in modalities or "video" in modalities:
# Multiple images # Multiple images
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], [] pixel_values, image_hashes, image_sizes = [], [], []
res = [] res = []
@@ -166,13 +170,6 @@ class LlavaImageProcessor(BaseImageProcessor):
) )
image_hashes = [image_hash] image_hashes = [image_hash]
image_sizes = [image_size] image_sizes = [image_size]
elif isinstance(image_data, str):
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else: else:
raise ValueError(f"Invalid image data: {image_data}") raise ValueError(f"Invalid image data: {image_data}")

View File

@@ -31,6 +31,7 @@ import dataclasses
import logging import logging
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@@ -167,6 +168,30 @@ class ImageInputs:
return ret return ret
def merge(self, other, vocab_size):
    """Fold the image inputs of ``other`` into this object, in place.

    The pixel values of ``other`` are concatenated after this object's pixel
    values, the image hash is combined, the pad values are recomputed from
    the combined hash, and the optional per-image fields are appended.

    Args:
        other: Image inputs to append. Its ``pixel_values`` must agree with
            this object's ``pixel_values`` on every axis except the first.
        vocab_size: Modulus applied to each pad value derived from the hash.

    Raises:
        AssertionError: If the trailing ``pixel_values`` dimensions differ.
    """
    assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:], (
        "pixel_values of merged image inputs must match on all but the first axis"
    )
    self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

    # NOTE(review): image_hashes is used with `>>` and `%` below, so it is
    # treated as a single integer here despite the plural name -- confirm
    # against the constructor of this class.
    self.image_hashes += other.image_hashes
    combined_hash = self.image_hashes
    # Derive four pad values from different bit ranges of the combined hash.
    self.pad_values = [
        (combined_hash >> shift) % vocab_size for shift in (0, 16, 32, 64)
    ]

    optional_args = [
        "image_sizes",
        "image_offsets",
        # "modalities" is deliberately not merged: it should stay
        # ["multi-images"] (one entry) even for multiple images.
        "aspect_ratio_ids",
        "aspect_ratio_mask",
        "image_grid_thws",
    ]
    for arg in optional_args:
        mine = getattr(self, arg, None)
        theirs = getattr(other, arg, None)
        # Only append when both sides carry the field; adding None to an
        # existing value would raise a TypeError.
        if mine is None or theirs is None:
            continue
        setattr(self, arg, mine + theirs)
class Req: class Req:
"""The input and output status of a request.""" """The input and output status of a request."""
@@ -177,6 +202,7 @@ class Req:
origin_input_text: str, origin_input_text: str,
origin_input_ids: Tuple[int], origin_input_ids: Tuple[int],
sampling_params: SamplingParams, sampling_params: SamplingParams,
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None, input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
@@ -184,7 +210,11 @@ class Req:
# Input and output info # Input and output info
self.rid = rid self.rid = rid
self.origin_input_text = origin_input_text self.origin_input_text = origin_input_text
self.origin_input_ids_unpadded = origin_input_ids # Before image padding self.origin_input_ids_unpadded = (
origin_input_ids_unpadded
if origin_input_ids_unpadded
else origin_input_ids # Before image padding
)
self.origin_input_ids = origin_input_ids self.origin_input_ids = origin_input_ids
self.output_ids = [] # Each decode stage's output ids self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids self.fill_ids = None # fill_ids = origin_input_ids + output_ids
@@ -260,6 +290,12 @@ class Req:
# The number of cached tokens, that were already cached in the KV cache # The number of cached tokens, that were already cached in the KV cache
self.cached_tokens = 0 self.cached_tokens = 0
def extend_image_inputs(self, image_inputs, vocab_size):
    """Attach ``image_inputs`` to this request.

    If the request already holds image inputs, the new ones are merged into
    them via their ``merge`` method; otherwise the new ones are adopted as-is.
    """
    existing = self.image_inputs
    if existing is not None:
        existing.merge(image_inputs, vocab_size)
        return
    self.image_inputs = image_inputs
# whether request reached finished condition # whether request reached finished condition
def finished(self) -> bool: def finished(self) -> bool:
return self.finished_reason is not None return self.finished_reason is not None

View File

@@ -559,12 +559,13 @@ class Scheduler:
# Image inputs # Image inputs
if recv_req.image_inputs is not None: if recv_req.image_inputs is not None:
req.image_inputs = ImageInputs.from_dict( image_inputs = ImageInputs.from_dict(
recv_req.image_inputs, self.model_config.vocab_size recv_req.image_inputs, self.model_config.vocab_size
) )
req.origin_input_ids = self.pad_input_ids_func( req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids_unpadded, req.image_inputs req.origin_input_ids, image_inputs
) )
req.extend_image_inputs(image_inputs, self.model_config.vocab_size)
if len(req.origin_input_ids) > self.max_req_input_len: if len(req.origin_input_ids) > self.max_req_input_len:
req.finished_reason = FINISH_ABORT( req.finished_reason = FINISH_ABORT(

View File

@@ -41,16 +41,27 @@ class Session:
] ]
+ req.input_ids + req.input_ids
) )
input_ids_unpadded = (
self.reqs[-1].origin_input_ids_unpadded
+ self.reqs[-1].output_ids[
: self.reqs[-1].sampling_params.max_new_tokens
]
+ req.input_ids
)
else: else:
input_ids = req.input_ids input_ids = req.input_ids
input_ids_unpadded = req.input_ids
new_req = Req( new_req = Req(
req.rid, rid=req.rid,
None, origin_input_text=None,
input_ids, origin_input_ids=input_ids,
req.sampling_params, origin_input_ids_unpadded=input_ids_unpadded,
sampling_params=req.sampling_params,
lora_path=req.lora_path, lora_path=req.lora_path,
session_id=self.session_id, session_id=self.session_id,
) )
if len(self.reqs) > 0:
new_req.image_inputs = self.reqs[-1].image_inputs
new_req.tokenizer = tokenizer new_req.tokenizer = tokenizer
if req.session_rid is not None and len(self.reqs) == 0: if req.session_rid is not None and len(self.reqs) == 0:
new_req.finished_reason = FINISH_ABORT( new_req.finished_reason = FINISH_ABORT(

View File

@@ -49,7 +49,13 @@ class LlavaBaseForCausalLM(nn.Module):
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
# hardcode for spatial_unpad + anyres # hardcode for spatial_unpad + anyres
image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad" if image_inputs.modalities is not None and (
"multi-images" in image_inputs.modalities
or "video" in image_inputs.modalities
):
image_aspect_ratio = "pad"
else:
image_aspect_ratio = "anyres"
offset_list = [] offset_list = []
for image_s in image_sizes: for image_s in image_sizes:
if len(image_sizes) > 16: if len(image_sizes) > 16:

View File

@@ -36,6 +36,7 @@ suites = {
"test_triton_attention_backend.py", "test_triton_attention_backend.py",
"test_update_weights.py", "test_update_weights.py",
"test_vision_openai_server.py", "test_vision_openai_server.py",
"test_session_control.py",
], ],
"sampling/penaltylib": glob.glob( "sampling/penaltylib": glob.glob(
"sampling/penaltylib/**/test_*.py", recursive=True "sampling/penaltylib/**/test_*.py", recursive=True

View File

@@ -1,7 +1,7 @@
""" """
Usage: Usage:
python3 -m unittest test_session_control.TestSessionControl.test_session_control python3 -m unittest test_session_control.TestSessionControl.test_session_control
python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm python3 -m unittest test_session_control.TestSessionControlVision.test_session_control
""" """
import unittest import unittest
@@ -61,6 +61,8 @@ class TestSessionControl(unittest.TestCase):
"max_new_tokens": ( "max_new_tokens": (
16 if i > 0 else 0 16 if i > 0 else 0
), # prefill only for the first chunk ), # prefill only for the first chunk
"no_stop_trim": True,
"skip_special_tokens": False,
}, },
}, },
).json() ).json()
@@ -79,6 +81,8 @@ class TestSessionControl(unittest.TestCase):
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": 16,
"no_stop_trim": True,
"skip_special_tokens": False,
}, },
}, },
).json() ).json()
@@ -93,6 +97,8 @@ class TestSessionControl(unittest.TestCase):
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": 16,
"no_stop_trim": True,
"skip_special_tokens": False,
}, },
}, },
).json() ).json()
@@ -113,6 +119,8 @@ class TestSessionControl(unittest.TestCase):
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": 16,
"no_stop_trim": True,
"skip_special_tokens": False,
}, },
}, },
).json() ).json()
@@ -133,13 +141,16 @@ class TestSessionControl(unittest.TestCase):
"max_new_tokens": ( "max_new_tokens": (
16 if i > 0 else 0 16 if i > 0 else 0
), # prefill only for the first chunk ), # prefill only for the first chunk
"no_stop_trim": True,
"skip_special_tokens": False,
}, },
}, },
).json() ).json()
if i > 0: if i > 0:
input_ids += tokenizer.encode(response["text"])[ output_ids = tokenizer.encode(response["text"])
1: if output_ids[0] == tokenizer.bos_token_id:
] # drop the bos token output_ids = output_ids[1:]
input_ids += output_ids
outputs_normal.append(response["text"]) outputs_normal.append(response["text"])
if i == 0: if i == 0:
input_ids_first_req = input_ids.copy() input_ids_first_req = input_ids.copy()
@@ -152,6 +163,187 @@ class TestSessionControl(unittest.TestCase):
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": 16,
"no_stop_trim": True,
"skip_special_tokens": False,
},
},
).json()
outputs_normal.append(response["text"])
print("outputs from chunked queries with session control:")
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert outputs_from_session == outputs_normal
class TestSessionControlVision(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-7b-ov"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
# other_args={"--disable-radix"},
)
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True)
def test_session_control(self):
text_chunks = [
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
"<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image>\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image>\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n",
]
image_chunks = [
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png",
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
]
assert len(text_chunks) == len(image_chunks) + 1
tokenizer = get_tokenizer(self.model)
text_input_ids = [tokenizer.encode(x) for x in text_chunks]
# 1. using session control
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
).json()
rid = None
first_rid = None
outputs_from_session = []
for i in range(len(text_input_ids)):
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": text_input_ids[i],
"image_data": image_chunks[i - 1] if i > 0 else None,
"modalities": ["multi-images"],
"session": [session_id, rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": (
16 if i > 0 else 0
), # prefill only for the first chunk
"no_stop_trim": True,
"skip_special_tokens": False,
},
},
).json()
rid = response["meta_info"]["id"]
if i == 0:
first_rid = rid
if i > 0:
outputs_from_session.append(response["text"])
# backtrack to the first request and regenerate
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": text_input_ids[-1],
"image_data": image_chunks[-1:],
"modalities": ["multi-images"],
"session": [session_id, first_rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
"no_stop_trim": True,
"skip_special_tokens": False,
},
},
).json()
outputs_from_session.append(response["text"])
# query with a non-existing rid (the last request should have been dropped because of the backtrack); should see abort
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": text_input_ids[-1],
"image_data": image_chunks[-1:],
"modalities": ["multi-images"],
"session": [session_id, rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
"no_stop_trim": True,
"skip_special_tokens": False,
},
},
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"
ret = requests.post(
self.base_url + "/close_session",
json={"session_id": session_id},
)
assert ret.status_code == 200
# send a request to a closed session, should see abort
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": text_input_ids[-1],
"session": [session_id, first_rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
"no_stop_trim": True,
"skip_special_tokens": False,
},
},
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"
# 2. not use session control
input_ids_first_req = None
input_ids = []
outputs_normal = []
for i in range(len(text_input_ids)):
input_ids += text_input_ids[i]
image_data = image_chunks[:i] if i > 0 else None
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"image_data": image_data,
"modalities": ["multi-images"],
"sampling_params": {
"temperature": 0,
"max_new_tokens": (
16 if i > 0 else 0
), # prefill only for the first chunk
"no_stop_trim": True,
"skip_special_tokens": False,
},
},
).json()
if i > 0:
output_ids = tokenizer.encode(response["text"])
if output_ids[0] == tokenizer.bos_token_id:
output_ids = output_ids[1:]
input_ids += output_ids
outputs_normal.append(response["text"])
if i == 0:
input_ids_first_req = input_ids.copy()
input_ids_first_req += text_input_ids[-1]
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids_first_req,
"image_data": image_chunks[-1:],
"modalities": ["multi-images"],
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
"no_stop_trim": True,
"skip_special_tokens": False,
}, },
}, },
).json() ).json()