[feat] Enable chunked prefill for llava-onevision (#2412)
This commit is contained in:
@@ -129,6 +129,7 @@ class ImageInputs:
|
|||||||
image_hashes: Optional[list] = None
|
image_hashes: Optional[list] = None
|
||||||
image_sizes: Optional[list] = None
|
image_sizes: Optional[list] = None
|
||||||
image_offsets: Optional[list] = None
|
image_offsets: Optional[list] = None
|
||||||
|
image_pad_len: Optional[list] = None
|
||||||
pad_values: Optional[list] = None
|
pad_values: Optional[list] = None
|
||||||
modalities: Optional[list] = None
|
modalities: Optional[list] = None
|
||||||
num_image_tokens: Optional[int] = None
|
num_image_tokens: Optional[int] = None
|
||||||
@@ -181,6 +182,7 @@ class ImageInputs:
|
|||||||
optional_args = [
|
optional_args = [
|
||||||
"image_sizes",
|
"image_sizes",
|
||||||
"image_offsets",
|
"image_offsets",
|
||||||
|
"image_pad_len",
|
||||||
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
|
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
|
||||||
"aspect_ratio_ids",
|
"aspect_ratio_ids",
|
||||||
"aspect_ratio_mask",
|
"aspect_ratio_mask",
|
||||||
|
|||||||
@@ -111,17 +111,20 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.is_multimodal:
|
if self.is_multimodal:
|
||||||
server_args.chunked_prefill_size = -1
|
|
||||||
self.mem_fraction_static *= 0.95
|
self.mem_fraction_static *= 0.95
|
||||||
logger.info(
|
if self.model_config.hf_config.architectures == [
|
||||||
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
|
"MllamaForConditionalGeneration"
|
||||||
f"and turn off chunked prefill "
|
]:
|
||||||
f"because this is a multimodal model."
|
logger.info("Automatically turn off --chunked-prefill-size for mllama.")
|
||||||
)
|
server_args.chunked_prefill_size = -1
|
||||||
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
|
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
|
||||||
if self.model_config.hf_config.architectures == [
|
if self.model_config.hf_config.architectures == [
|
||||||
"Qwen2VLForConditionalGeneration"
|
"Qwen2VLForConditionalGeneration"
|
||||||
]:
|
]:
|
||||||
|
logger.info(
|
||||||
|
"Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
|
||||||
|
)
|
||||||
|
server_args.chunked_prefill_size = -1
|
||||||
server_args.disable_radix_cache = True
|
server_args.disable_radix_cache = True
|
||||||
|
|
||||||
# Global vars
|
# Global vars
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
image_aspect_ratio = "anyres"
|
image_aspect_ratio = "anyres"
|
||||||
offset_list = []
|
offset_list = []
|
||||||
|
image_inputs.image_pad_len = []
|
||||||
for image_idx, image_s in enumerate(image_sizes):
|
for image_idx, image_s in enumerate(image_sizes):
|
||||||
if len(image_sizes) > 16:
|
if len(image_sizes) > 16:
|
||||||
# 2x2 pooling with stride 2
|
# 2x2 pooling with stride 2
|
||||||
@@ -103,6 +104,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
+ input_ids[offset + 1 :]
|
+ input_ids[offset + 1 :]
|
||||||
)
|
)
|
||||||
offset_list.append(offset)
|
offset_list.append(offset)
|
||||||
|
image_inputs.image_pad_len.append(new_image_feature_len)
|
||||||
|
|
||||||
image_inputs.image_offsets = offset_list
|
image_inputs.image_offsets = offset_list
|
||||||
return input_ids
|
return input_ids
|
||||||
@@ -134,6 +136,14 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
image_inputs = forward_batch.image_inputs
|
image_inputs = forward_batch.image_inputs
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_extend():
|
if forward_batch.forward_mode.is_extend():
|
||||||
|
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||||
|
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||||
|
# These values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||||
|
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||||
|
|
||||||
|
# Embed text inputs
|
||||||
|
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||||
|
|
||||||
# Got List[List[str]] extend it to List[str]
|
# Got List[List[str]] extend it to List[str]
|
||||||
# The length of the List should be equal to batch size
|
# The length of the List should be equal to batch size
|
||||||
modalities_list = []
|
modalities_list = []
|
||||||
@@ -142,18 +152,12 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
if im and im.modalities is not None:
|
if im and im.modalities is not None:
|
||||||
modalities_list.extend(im.modalities)
|
modalities_list.extend(im.modalities)
|
||||||
if im and im.image_offsets:
|
if im and im.image_offsets:
|
||||||
max_image_offset.append(max(im.image_offsets))
|
max_image_offset.append(
|
||||||
|
np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
max_image_offset.append(-1)
|
max_image_offset.append(-1)
|
||||||
|
|
||||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
|
||||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
|
||||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
|
||||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
|
||||||
|
|
||||||
# Embed text inputs
|
|
||||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
|
||||||
|
|
||||||
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
|
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
|
||||||
need_vision = start_positions <= np.array(max_image_offset)
|
need_vision = start_positions <= np.array(max_image_offset)
|
||||||
|
|
||||||
@@ -350,6 +354,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
|
|
||||||
# Fill in the placeholder for the image
|
# Fill in the placeholder for the image
|
||||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||||
|
extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
|
||||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||||
pt = 0
|
pt = 0
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
@@ -357,18 +362,36 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
start_idx = extend_start_loc_cpu[i]
|
start_idx = extend_start_loc_cpu[i]
|
||||||
|
seq_len = extend_seq_lens[i]
|
||||||
prefix_len = prefix_lens_cpu[i]
|
prefix_len = prefix_lens_cpu[i]
|
||||||
|
|
||||||
# Multiple images
|
# Multiple images
|
||||||
for j, image_offset in enumerate(image_inputs[i].image_offsets):
|
for image_idx, image_offset in enumerate(
|
||||||
if image_offset < prefix_len:
|
image_inputs[i].image_offsets
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
image_offset + image_inputs[i].image_pad_len[image_idx]
|
||||||
|
<= prefix_len
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
if image_offset >= prefix_len + seq_len:
|
||||||
|
break
|
||||||
|
|
||||||
tmp_image_feature = image_features[pt][j]
|
tmp_image_feature = image_features[pt][image_idx]
|
||||||
pad_len = tmp_image_feature.shape[0]
|
pad_len = tmp_image_feature.shape[0]
|
||||||
|
|
||||||
left_idx = start_idx + (image_offset - prefix_len)
|
input_offset = image_offset - prefix_len
|
||||||
right_idx = start_idx + (image_offset - prefix_len) + pad_len
|
left_idx = start_idx + input_offset
|
||||||
|
right_idx = left_idx + pad_len
|
||||||
|
assert right_idx > start_idx
|
||||||
|
if input_offset < 0:
|
||||||
|
left_idx = start_idx
|
||||||
|
tmp_image_feature = tmp_image_feature[-input_offset:]
|
||||||
|
if right_idx > start_idx + seq_len:
|
||||||
|
tmp_image_feature = tmp_image_feature[
|
||||||
|
: start_idx + seq_len - right_idx
|
||||||
|
]
|
||||||
|
right_idx = start_idx + seq_len
|
||||||
try:
|
try:
|
||||||
input_embeds[left_idx:right_idx] = tmp_image_feature
|
input_embeds[left_idx:right_idx] = tmp_image_feature
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ suites = {
|
|||||||
"test_triton_attention_kernels.py",
|
"test_triton_attention_kernels.py",
|
||||||
"test_triton_attention_backend.py",
|
"test_triton_attention_backend.py",
|
||||||
"test_update_weights_from_disk.py",
|
"test_update_weights_from_disk.py",
|
||||||
|
"test_vision_chunked_prefill.py",
|
||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
"test_session_control.py",
|
"test_session_control.py",
|
||||||
],
|
],
|
||||||
|
|||||||
173
test/srt/test_vision_chunked_prefill.py
Normal file
173
test/srt/test_vision_chunked_prefill.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python3 -m unittest test_vision_chunked_prefill.TestVisionChunkedPrefill.test_chunked_prefill
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
from decord import VideoReader, cpu
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestVisionChunkedPrefill(unittest.TestCase):
|
||||||
|
def prepare_video_messages(self, video_path, max_frames_num=8):
    """Build a one-turn chat message made of frames sampled from a video.

    The video at ``video_path`` is opened with ``VideoReader`` and
    ``max_frames_num`` frame indices are picked evenly between the first and
    the last frame. Each selected frame is JPEG-encoded, base64-encoded and
    wrapped in an ``image_url`` content entry tagged with the ``"video"``
    modality. A fixed text prompt is appended after the frames.

    Args:
        video_path: Path of the video file to read.
        max_frames_num: Number of frame indices to sample.

    Returns:
        A list holding a single ``{"role": "user", "content": [...]}`` message.
    """
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frame_num = len(vr)
    uniform_sampled_frames = np.linspace(
        0, total_frame_num - 1, max_frames_num, dtype=int
    )
    frame_idx = uniform_sampled_frames.tolist()
    frames = vr.get_batch(frame_idx).asnumpy()

    content = []
    for frame in frames:
        buff = io.BytesIO()
        Image.fromarray(frame).save(buff, format="JPEG")
        base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
        # Build a brand-new entry (nested dict included) for every frame.
        # Reusing one template and calling ``dict.copy()`` only copies the
        # outer dict, so every entry would share the same inner "image_url"
        # dict and all frames would end up pointing at the last frame's URL.
        content.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{base64_str}"},
                "modalities": "video",
            }
        )

    content.append({"type": "text", "text": "Please describe the video briefly."})

    return [{"role": "user", "content": content}]
|
||||||
|
|
||||||
|
def get_prompt_from_messages(self, messages):
    """Flatten a chat message into a raw prompt string plus its image URLs.

    Only the first message is inspected. Every ``image_url`` entry in its
    content contributes one ``<image>`` placeholder line to the prompt and its
    URL to the returned list, in order. Entries of any other type are ignored;
    the closing user text is always the fixed "describe the video" request.

    Args:
        messages: List of chat messages; ``messages[0]["content"]`` must be a
            list of dicts carrying a ``"type"`` key.

    Returns:
        A ``(text, image_data)`` tuple: the prompt string and the list of
        image URLs.
    """
    image_data = [
        part["image_url"]["url"]
        for part in messages[0]["content"]
        if part["type"] == "image_url"
    ]
    pieces = [
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
        "<|im_start|>user\n",
        # One placeholder line per image, matching the order of image_data.
        "<image>\n" * len(image_data),
        "Please describe the video briefly.<|im_end|>\n<|im_start|>assistant\n",
    ]
    return "".join(pieces), image_data
|
||||||
|
|
||||||
|
def generate(self, text, image_data):
    """Send one generation request to the server and return the output text.

    The request is posted to ``self.base_url + "/generate"`` with greedy
    sampling (``temperature`` 0), at most 32 new tokens, and the
    ``"multi-images"`` modality.

    Args:
        text: Prompt string.
        image_data: List of image URLs (data URLs in this test) for the prompt.

    Returns:
        The ``"text"`` field of the JSON response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.post(
        self.base_url + "/generate",
        json={
            "text": text,
            "image_data": image_data,
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 32,
                "no_stop_trim": True,
                "skip_special_tokens": False,
            },
            "modalities": ["multi-images"],
        },
    )
    # Fail with the real HTTP error instead of an opaque KeyError/JSON decode
    # error when the server rejects the request.
    response.raise_for_status()
    return response.json()["text"]
|
||||||
|
|
||||||
|
def generate_for_video(self, batch, num_frame) -> Union[str, list[str]]:
    """Run generation on the sample video, either once or as a concurrent batch.

    The sample clip is downloaded to ``~/.cache/jobs.mp4`` the first time and
    reused afterwards.

    Args:
        batch: When falsy, issue a single request; otherwise issue one request
            per entry of ``num_frame`` through a thread pool.
        num_frame: An ``int`` frame count when ``batch`` is falsy, or a list of
            frame counts when ``batch`` is truthy.

    Returns:
        The generated text for a single request, or the list of generated
        texts (in the order of ``num_frame``) for a batch.
    """
    # Make sure the sample video is available locally.
    url = "https://raw.githubusercontent.com/evolvinglmms-lab/sglang/dev/onevision_local/assets/jobs.mp4"
    cache_dir = os.path.expanduser("~/.cache")
    file_path = os.path.join(cache_dir, "jobs.mp4")
    os.makedirs(cache_dir, exist_ok=True)
    if not os.path.exists(file_path):
        download = requests.get(url)
        download.raise_for_status()
        with open(file_path, "wb") as video_file:
            video_file.write(download.content)

    if batch:
        assert isinstance(num_frame, list)
        # One (text, image_data) pair per requested frame count.
        request_args = [
            self.get_prompt_from_messages(
                self.prepare_video_messages(file_path, max_frames_num=frame_count)
            )
            for frame_count in num_frame
        ]
        with ThreadPoolExecutor(max_workers=10) as executor:
            return list(
                executor.map(lambda pair: self.generate(*pair), request_args)
            )

    assert isinstance(num_frame, int)
    single_messages = self.prepare_video_messages(file_path, max_frames_num=num_frame)
    prompt_text, prompt_images = self.get_prompt_from_messages(single_messages)
    return self.generate(prompt_text, prompt_images)
|
||||||
|
|
||||||
|
def run_generate(self, chunked_prefill_size, batch, num_frame):
    """Launch a server with the given chunked-prefill size and query it.

    The server is started for the llava-onevision model at
    ``DEFAULT_URL_FOR_TEST`` (also stored on ``self.base_url``), the video
    generation request(s) are issued, and the server process tree is always
    killed afterwards, even when generation fails.

    Args:
        chunked_prefill_size: Value passed to ``--chunked-prefill-size``.
        batch: Forwarded to ``generate_for_video``.
        num_frame: Forwarded to ``generate_for_video``.

    Returns:
        Whatever ``generate_for_video`` returns.
    """
    self.base_url = DEFAULT_URL_FOR_TEST
    launch_flags = [
        "--chunked-prefill-size",
        f"{chunked_prefill_size}",
    ]
    server = popen_launch_server(
        "lmms-lab/llava-onevision-qwen2-7b-ov",
        self.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=launch_flags,
    )
    try:
        outputs = self.generate_for_video(batch, num_frame)
    finally:
        # Tear the server down no matter how the request went.
        kill_process_tree(server.pid)
    return outputs
|
||||||
|
|
||||||
|
def test_chunked_prefill(self):
    """Check that chunked prefill does not change the generated output.

    For a single one-frame request and for a concurrent batch of requests
    with different frame counts, the output produced with
    ``--chunked-prefill-size 1024`` must equal the output produced with
    chunked prefill disabled (``-1``).
    """
    scenarios = [
        {"batch": False, "num_frame": 1},
        {"batch": True, "num_frame": [2, 6, 8, 10]},
    ]
    for scenario in scenarios:
        output_chunked = self.run_generate(chunked_prefill_size=1024, **scenario)
        output_no_chunked = self.run_generate(chunked_prefill_size=-1, **scenario)

        print("output with chunked prefill:")
        print(output_chunked)
        print("output without chunked prefill:")
        print(output_no_chunked)
        # Use the TestCase assertion: it is not stripped under ``python -O``
        # and reports a readable diff on mismatch.
        self.assertEqual(output_chunked, output_no_chunked)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user