diff --git a/README.md b/README.md index c7d47d678..c118d6a1a 100644 --- a/README.md +++ b/README.md @@ -231,8 +231,13 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 30000` + - `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --host=127.0.0.1 --tp-size=1 --chat-template=llava_llama_3` + - `python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --host="127.0.0.1" --tp-size=8 --chat-template=chatml-llava` - LLaVA-NeXT-Video - see [examples/usage/llava_video](examples/usage/llava_video) +- [LLaVA-OneVision](https://arxiv.org/abs/2408.03326) + - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --host=127.0.0.1 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384` + - see [test/srt/test_llava_onevision_openai_server.py](test/srt/test_llava_onevision_openai_server.py) - Yi-VL - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py). - StableLM diff --git a/examples/usage/llava/http_llava_onevision_test.py b/examples/usage/llava/http_llava_onevision_test.py new file mode 100644 index 000000000..c32d52981 --- /dev/null +++ b/examples/usage/llava/http_llava_onevision_test.py @@ -0,0 +1,211 @@ +import base64 +import io +import os +import sys +import time + +import numpy as np +import openai +import requests +from decord import VideoReader, cpu +from PIL import Image + +# pip install httpx==0.23.3 +# pip install decord +# pip install protobuf==3.20.0 + + +def download_video(url, cache_dir): + file_path = os.path.join(cache_dir, "jobs.mp4") + os.makedirs(cache_dir, exist_ok=True) + + response = requests.get(url) + response.raise_for_status() + + with open(file_path, "wb") as f: + f.write(response.content) + + print(f"File downloaded and saved to: {file_path}") + return file_path + + +def create_openai_client(base_url): + return openai.Client(api_key="EMPTY", base_url=base_url) + + +def image_stream_request_test(client): + print("----------------------Image Stream Request Test----------------------") + stream_request = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "text", + "text": "Please describe this image. 
Please list the benchmarks and the models.",
                    },
                ],
            },
        ],
        temperature=0.7,
        max_tokens=1024,
        stream=True,
    )
    stream_response = ""

    for chunk in stream_request:
        if chunk.choices[0].delta.content is not None:
            content = chunk.choices[0].delta.content
            stream_response += content
            sys.stdout.write(content)
            sys.stdout.flush()

    print("-" * 30)


def video_stream_request_test(client, video_path):
    print("------------------------Video Stream Request Test----------------------")
    messages = prepare_video_messages(video_path)

    video_request = client.chat.completions.create(
        model="default",
        messages=messages,
        temperature=0,
        max_tokens=1024,
        stream=True,
    )
    print("-" * 30)
    video_response = ""

    for chunk in video_request:
        if chunk.choices[0].delta.content is not None:
            content = chunk.choices[0].delta.content
            video_response += content
            sys.stdout.write(content)
            sys.stdout.flush()
    print("-" * 30)


def image_speed_test(client):
    print("----------------------Image Speed Test----------------------")
    start_time = time.time()
    request = client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
                        },
                    },
                    {
                        "type": "text",
                        "text": "Please describe this image. Please list the benchmarks and the models.",
                    },
                ],
            },
        ],
        temperature=0,
        max_tokens=1024,
    )
    end_time = time.time()
    response = request.choices[0].message.content
    print(response)
    print("-" * 30)
    print_speed_test_results(request, start_time, end_time)


def video_speed_test(client, video_path):
    print("------------------------Video Speed Test------------------------")
    messages = prepare_video_messages(video_path)

    start_time = time.time()
    video_request = client.chat.completions.create(
        model="default",
        messages=messages,
        temperature=0,
        max_tokens=1024,
    )
    end_time = time.time()
    video_response = video_request.choices[0].message.content
    print(video_response)
    print("-" * 30)
    print_speed_test_results(video_request, start_time, end_time)


def prepare_video_messages(video_path):
    max_frames_num = 32
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frame_num = len(vr)
    uniform_sampled_frames = np.linspace(
        0, total_frame_num - 1, max_frames_num, dtype=int
    )
    frame_idx = uniform_sampled_frames.tolist()
    frames = vr.get_batch(frame_idx).asnumpy()

    base64_frames = []
    for frame in frames:
        pil_img = Image.fromarray(frame)
        buff = io.BytesIO()
        pil_img.save(buff, format="JPEG")
        base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
        base64_frames.append(base64_str)

    messages = [{"role": "user", "content": []}]

    for base64_frame in base64_frames:
        # Build a fresh dict per frame: appending shallow copies of one shared
        # template would leave every entry pointing at the same nested
        # image_url dict, so all frames would end up with the last frame's URL.
        frame_entry = {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{base64_frame}"},
        }
        messages[0]["content"].append(frame_entry)

    prompt = {"type": "text", "text": "Please describe the video in detail."}
    messages[0]["content"].append(prompt)

    return messages

{completion_tokens}") + print(f"Prompt tokens: {prompt_tokens}") + print(f"Time taken: {end_time - start_time} seconds") + print(f"Token per second: {total_tokens / (end_time - start_time)}") + print(f"Completion token per second: {completion_tokens / (end_time - start_time)}") + print(f"Prompt token per second: {prompt_tokens / (end_time - start_time)}") + + +def main(): + url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4" + cache_dir = os.path.expanduser("~/.cache") + video_path = download_video(url, cache_dir) + + client = create_openai_client("http://127.0.0.1:30000/v1") + + image_stream_request_test(client) + video_stream_request_test(client, video_path) + image_speed_test(client) + video_speed_test(client, video_path) + + +if __name__ == "__main__": + main() diff --git a/examples/usage/llava_video/srt_example_llava_v.py b/examples/usage/llava_video/srt_example_llava_v.py index 27ba862d3..7421dfcdf 100644 --- a/examples/usage/llava_video/srt_example_llava_v.py +++ b/examples/usage/llava_video/srt_example_llava_v.py @@ -121,6 +121,20 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= if __name__ == "__main__": + url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4" + + cache_dir = os.path.expanduser("~/.cache") + file_path = os.path.join(cache_dir, "jobs.mp4") + + os.makedirs(cache_dir, exist_ok=True) + + response = requests.get(url) + response.raise_for_status() # Raise an exception for bad responses + + with open(file_path, "wb") as f: + f.write(response.content) + + print(f"File downloaded and saved to: {file_path}") # Create the parser parser = argparse.ArgumentParser( description="Run video processing with specified port." 
diff --git a/examples/usage/llava_video/srt_example_llava_v.py b/examples/usage/llava_video/srt_example_llava_v.py
index 27ba862d3..7421dfcdf 100644
--- a/examples/usage/llava_video/srt_example_llava_v.py
+++ b/examples/usage/llava_video/srt_example_llava_v.py
@@ -121,6 +121,20 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=

 if __name__ == "__main__":
+    url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
+
+    cache_dir = os.path.expanduser("~/.cache")
+    file_path = os.path.join(cache_dir, "jobs.mp4")
+
+    os.makedirs(cache_dir, exist_ok=True)
+
+    response = requests.get(url)
+    response.raise_for_status()  # Raise an exception for bad responses
+
+    with open(file_path, "wb") as f:
+        f.write(response.content)
+
+    print(f"File downloaded and saved to: {file_path}")
     # Create the parser
     parser = argparse.ArgumentParser(
         description="Run video processing with specified port."
@@ -148,7 +162,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--video-dir",
         type=str,
-        default="./videos/Q98Z4OTh8RwmDonc.mp4",
+        default=os.path.expanduser("~/.cache/jobs.mp4"),
         help="The directory or path for the processed video files.",
     )
     parser.add_argument(
diff --git a/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 b/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4
deleted file mode 100644
index 32d912dbf..000000000
Binary files a/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 and /dev/null differ
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 7ba4b4c6b..6b1b032fd 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
 ]

 [project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
+srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
       "packaging", "pillow", "psutil", "pydantic", "python-multipart", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.4", "outlines>=0.0.44"]

diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index bfde4bbdb..92f717127 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -137,7 +137,7 @@ register_chat_template(
 register_chat_template(
     ChatTemplate(
         name="chatml-llava",
-        default_system_prompt="Answer the questions.",
+        default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
             "user": ("<|im_start|>user\n", "<|im_end|>\n"),
@@ -145,7 +145,7 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
-        image_token=" <image>\n",
+        image_token="<image>\n",
     )
 )

@@ -322,12 +322,17 @@ def match_chat_ml(model_path: str):
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     # Now the suffix for qwen2 chat model is "instruct"
-    if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path):
+    if (
+        "qwen" in model_path
+        and ("chat" in model_path or "instruct" in model_path)
+        and ("llava" not in model_path)
+    ):
         return get_chat_template("qwen")
     if (
         "llava-v1.6-34b" in model_path
         or "llava-v1.6-yi-34b" in model_path
         or "llava-next-video-34b" in model_path
+        or "llava-onevision-qwen2" in model_path
     ):
         return get_chat_template("chatml-llava")
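For reference, the `chatml-llava` template registered above renders a single-image turn roughly as follows. This is a sketch assembled from the role prefixes/suffixes and the `<image>\n` image token shown in the diff; the user text is a placeholder:

```python
# Rough rendering of one "chatml-llava" turn (user text is a placeholder).
prompt = (
    "<|im_start|>system\n"
    "You are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    "<image>\nPlease describe this image.<|im_end|>\n"
    "<|im_start|>assistant\n"
)
```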
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index 5ee121697..d5ca32770 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -34,6 +34,7 @@ class SeparatorStyle(IntEnum):
     NO_COLON_TWO = auto()
     ADD_NEW_LINE_SINGLE = auto()
     LLAMA2 = auto()
+    LLAMA3 = auto()
     CHATGLM = auto()
     CHATML = auto()
     CHATINTERN = auto()
@@ -137,6 +138,20 @@ class Conversation:
             else:
                 ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.LLAMA3:
+            ret = "<|begin_of_text|>"
+            if self.system_message:
+                ret += system_prompt
+            for role, message in self.messages:
+                if message:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+                    ret += f"{message.strip()}<|eot_id|>"
+                else:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+            return ret
         elif self.sep_style == SeparatorStyle.LLAMA2:
             seps = [self.sep, self.sep2]
             if self.system_message:
@@ -379,12 +394,23 @@ def generate_chat_conv(
                 conv.append_message(conv.roles[0], message.content)
             else:
                 real_content = ""
+                # Count how many image_url parts the message contains
+                num_image_url = 0
+                for content in message.content:
+                    if content.type == "image_url":
+                        num_image_url += 1
+                if num_image_url > 1:
+                    image_token = "<image>"
+                else:
+                    image_token = "<image>\n"
                 for content in message.content:
                     if content.type == "text":
+                        if num_image_url > 16:
+                            real_content += "\n"  # for video
                         real_content += content.text
                     elif content.type == "image_url":
                         # NOTE: Only works for llava
-                        real_content += "<image>\n"
+                        real_content += image_token
                         conv.append_image(content.image_url.url)
                 conv.append_message(conv.roles[0], real_content)
         elif msg_role == "assistant":
@@ -425,6 +451,18 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="chatml-llava",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are a helpful assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep_style=SeparatorStyle.CHATML,
+        sep="<|im_end|>",
+        stop_str=["<|endoftext|>", "<|im_end|>"],
+    )
+)
+
 register_conv_template(
     Conversation(
         name="vicuna_v1.1",
@@ -437,6 +475,17 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="llava_llama_3",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA3,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot_id|>"],
+    )
+)
+
 # Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442
 register_conv_template(
     Conversation(
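Similarly, the `llava_llama_3` template with `SeparatorStyle.LLAMA3` renders to the Llama-3 header format. A sketch (the long system message is shortened here, and the user text is a placeholder):

```python
# Rough rendering of one llava_llama_3 turn (system message abbreviated).
prompt = (
    "<|begin_of_text|>"
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful language and vision assistant. ...<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "<image>\nPlease describe this image.<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
```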
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 328519cb2..2d604d287 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -131,11 +131,49 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None

-    async def get_pixel_values(self, image_data):
-        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
-        grid_pinpoints = (
-            self.hf_config.image_grid_pinpoints if aspect_ratio == "anyres" else None
+    async def get_pixel_values(self, image_data, aspect_ratio=None):
+        aspect_ratio = (
+            getattr(self.hf_config, "image_aspect_ratio", None)
+            if aspect_ratio is None
+            else aspect_ratio
         )
+        grid_pinpoints = (
+            self.hf_config.image_grid_pinpoints
+            if hasattr(self.hf_config, "image_grid_pinpoints")
+            and aspect_ratio is not None
+            and "anyres" in aspect_ratio
+            else None
+        )
+
+        if isinstance(image_data, list) and len(image_data) > 0:
+            pixel_values, image_hash, image_size = [], [], []
+            if len(image_data) > 1:
+                # LLaVA-OneVision handling: more than one image means
+                # interleaved-image or video mode; anyres is not used there.
+                aspect_ratio = "pad"
+            for img_data in image_data:
+                pixel_v, image_h, image_s = await self._process_single_image(
+                    img_data, aspect_ratio, grid_pinpoints
+                )
+                pixel_values.append(pixel_v)
+                image_hash.append(image_h)
+                image_size.append(image_s)
+            pixel_values = np.stack(pixel_values, axis=0)
+        elif isinstance(image_data, str):
+            pixel_values, image_hash, image_size = await self._process_single_image(
+                image_data, aspect_ratio, grid_pinpoints
+            )
+            image_hash = [image_hash]
+            image_size = [image_size]
+        else:
+            pixel_values, image_hash, image_size = None, None, None
+
+        return pixel_values, image_hash, image_size
+
+    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
         if self.executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
@@ -194,8 +232,8 @@ class TokenizerManager:
             )

         if self.is_generation:
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data if not_use_index else obj.image_data[index]
+            pixel_values, image_hash, image_size = await self.get_pixel_values(
+                obj.image_data
             )
             return_logprob = (
                 obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
@@ -704,7 +742,7 @@ def get_pixel_values(
             tuple(int(x * 255) for x in processor.image_processor.image_mean),
         )
         pixel_values = processor.image_processor(image)["pixel_values"][0]
-    elif image_aspect_ratio == "anyres":
+    elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
         pixel_values = process_anyres_image(
             image, processor.image_processor, image_grid_pinpoints
         )
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 41f908301..fa79f8492 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -322,11 +322,16 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
+                image_hash = (
+                    hash(tuple(recv_req.image_hash))
+                    if isinstance(recv_req.image_hash, list)
+                    else recv_req.image_hash
+                )
                 req.pad_value = [
-                    (recv_req.image_hash) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+                    (image_hash) % self.model_config.vocab_size,
+                    (image_hash >> 16) % self.model_config.vocab_size,
+                    (image_hash >> 32) % self.model_config.vocab_size,
+                    (image_hash >> 64) % self.model_config.vocab_size,
                 ]
                 req.image_size = recv_req.image_size
                 (
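The `pad_value` logic above derives four vocab-range token ids from one image hash; for multi-image requests the per-image hashes are first combined by hashing the tuple. A standalone sketch (the vocabulary size is a made-up example):

```python
# Sketch of the pad_value derivation; vocab_size is a made-up example.
vocab_size = 152064
image_hash = hash(("hash-of-frame-0", "hash-of-frame-1"))  # list case
pad_value = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_value)  # four deterministic ids used to pad the image placeholder
```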
diff --git a/python/sglang/srt/mm_utils.py b/python/sglang/srt/mm_utils.py
index e09c8215c..7918f3f71 100644
--- a/python/sglang/srt/mm_utils.py
+++ b/python/sglang/srt/mm_utils.py
@@ -13,10 +13,25 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
+# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
+"""
+Utilities for multi-modal models.
+
+This file contains the image-processing utilities used by LLaVA-NeXT,
+including the anyres and anyres_max operations.
+
+The anyres and anyres_max operations are currently supported for the
+CLIP and SigLIP vision towers. For more information, see the paper and
+the blog post:
+
+LLaVA-NeXT: https://llava-vl.github.io/blog/2024-01-30-llava-next/
+LLaVA-OneVision: https://arxiv.org/pdf/2408.03326
+"""
 import ast
 import base64
 import math
+import re
 from io import BytesIO

 import numpy as np
@@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions):
     min_wasted_resolution = float("inf")

     for width, height in possible_resolutions:
+        # Calculate the downscaled size that keeps the aspect ratio
         scale = min(width / original_width, height / original_height)
         downscaled_width, downscaled_height = int(original_width * scale), int(
             original_height * scale
         )
+
+        # Calculate effective and wasted resolutions
         effective_resolution = min(
             downscaled_width * downscaled_height, original_width * original_height
         )
@@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     Returns:
         tuple: The shape of the image patch grid in the format (width, height).
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use a regex to extract the grid range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate every (width, height) grid pair from range_start to range_end
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Scale every grid pair by the patch size to get pixel resolutions
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
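To make the string form concrete: a `grid_pinpoints` value such as `"(1x1),...,(6x6)"` expands into every grid from 1x1 to 6x6, scaled by the patch size. A minimal sketch of the same parsing (the pinpoint string and patch size are assumed example values):

```python
import re

grid_pinpoints = "(1x1),...,(6x6)"  # assumed example value
patch_size = 384

matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
start, end = tuple(map(int, matches[0])), tuple(map(int, matches[-1]))
grids = [
    (i * patch_size, j * patch_size)
    for i in range(start[0], end[0] + 1)
    for j in range(start[1], end[1] + 1)
]
print(len(grids), grids[0], grids[-1])  # 36 (384, 384) (2304, 2304)
```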
""" + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + try: + patch_size = processor.size[0] + except Exception as e: + patch_size = processor.size["shortest_edge"] + assert patch_size in [ + 224, + 336, + 384, + 448, + 512, + ], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [ + (i, j) + for i in range(range_start[0], range_end[0] + 1) + for j in range(range_start[1], range_end[1] + 1) + ] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints else: @@ -156,15 +219,24 @@ def process_anyres_image(image, processor, grid_pinpoints): best_resolution = select_best_resolution(image.size, possible_resolutions) image_padded = resize_and_pad_image(image, best_resolution) - patches = divide_to_patches(image_padded, processor.crop_size["height"]) - - image_original_resize = image.resize( - (processor.size["shortest_edge"], processor.size["shortest_edge"]) + # For Siglip processor, only have size but no crop size + crop_size = ( + processor.crop_size["height"] + if "crop_size" in processor.__dict__ + else processor.size["height"] ) + shortest_edge = ( + processor.size["shortest_edge"] + if "shortest_edge" in processor.size + else processor.size["height"] + ) + patches = divide_to_patches(image_padded, crop_size) + + image_original_resize = image.resize((shortest_edge, shortest_edge)) image_patches = [image_original_resize] + patches image_patches = [ - processor.preprocess(image_patch)["pixel_values"][0] + processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0] for image_patch in image_patches ] return np.stack(image_patches, axis=0) @@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg): ) image = image_processor.preprocess(image)["pixel_values"][0] new_images.append(image) - elif image_aspect_ratio == "anyres": + elif "anyres" in image_aspect_ratio: for image in images: image = process_anyres_image( image, image_processor, model_cfg.image_grid_pinpoints diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index bac0a0537..98daeaece 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -88,14 +88,19 @@ class InputMetadata: reqs = batch.reqs self.pixel_values = [r.pixel_values for r in reqs] self.image_sizes = [r.image_size for r in reqs] - self.image_offsets = [ - ( - (r.image_offset - batch.prefix_lens_cpu[i]) - if r.image_offset is not None - else 0 - ) - for i, r in enumerate(reqs) - ] + self.image_offsets = [] + for r in reqs: + if isinstance(r.image_offset, list): + self.image_offsets.append( + [ + (image_offset - len(r.prefix_indices)) + for image_offset in r.image_offset + ] + ) + elif isinstance(r.image_offset, int): + self.image_offsets.append(r.image_offset - len(r.prefix_indices)) + elif r.image_offset is None: + self.image_offsets.append(0) def compute_positions(self, batch: ScheduleBatch): position_ids_offsets = batch.position_ids_offsets diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 
index a885a6e59..76a0630fc 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -15,6 +15,8 @@ limitations under the License.

 """Inference-only LLaVa model compatible with HuggingFace weights."""

+import math
+import re
 from typing import Iterable, List, Optional, Tuple

 import numpy as np
@@ -26,6 +28,8 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
+    SiglipVisionConfig,
+    SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
@@ -63,34 +67,61 @@ class LlavaLlamaForCausalLM(nn.Module):
         )

     def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-        new_image_feature_len = self.image_feature_len
-        # now only support spatial_unpad + anyres
-        if self.mm_patch_merge_type.startswith("spatial"):
-            height = width = self.num_patches_per_side
-            if pt_shape[0] > 1:
-                if self.image_aspect_ratio == "anyres":
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_size,
-                        self.image_grid_pinpoints,
-                        self.vision_tower.config.image_size,
-                    )
-                if "unpad" in self.mm_patch_merge_type:
-                    h = num_patch_height * height
-                    w = num_patch_width * width
-                    new_h, new_w = unpad_image_shape(h, w, image_size)
-                    new_image_feature_len += new_h * (new_w + 1)
-            pad_ids = pad_value * (
-                (new_image_feature_len + len(pad_value)) // len(pad_value)
-            )
-            offset = input_ids.index(self.config.image_token_index)
-            # old_len + pad_len - 1, because we need to remove image_token_id
-            new_input_ids = (
-                input_ids[:offset]
-                + pad_ids[:new_image_feature_len]
-                + input_ids[offset + 1 :]
-            )
-            return new_input_ids, offset
+        # hardcoded for spatial_unpad + anyres
+        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        offset_list = []
+        for image_s in image_size:
+            if len(image_size) > 16:
+                # video: 2x2 pooling with stride 2 shrinks each frame's features
+                new_image_feature_len = (
+                    math.ceil(self.image_size / self.patch_size / 2) ** 2
+                )
+            else:
+                # single image or interleaved multi-image
+                new_image_feature_len = self.image_feature_len
+
+            height = width = self.num_patches_per_side
+            if "anyres" in image_aspect_ratio:
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                    image_s,
+                    self.image_grid_pinpoints,
+                    self.vision_tower.config.image_size,
+                )
+                h = num_patch_height * height
+                w = num_patch_width * width
+                new_h, new_w = unpad_image_shape(h, w, image_s)
+
+                if "anyres_max" in self.config.image_aspect_ratio:
+                    matched_anyres_max_num_patches = re.match(
+                        r"anyres_max_(\d+)", self.config.image_aspect_ratio
+                    )
+                    if matched_anyres_max_num_patches:
+                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
+                        times = math.sqrt(
+                            new_h * new_w / (max_num_patches * self.image_feature_len)
+                        )
+                        if times > 1.1:
+                            new_h = int(new_h // times)
+                            new_w = int(new_w // times)
+                new_image_feature_len += new_h * (new_w + 1)
+
+            pad_ids = pad_value * (
+                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            )
+            try:
+                offset = input_ids.index(self.config.image_token_index)
+            except ValueError:
+                offset = 0
+            # old_len + pad_len - 1, because the image_token_id itself is removed
+            input_ids = (
+                input_ids[:offset]
+                + pad_ids[:new_image_feature_len]
+                + input_ids[offset + 1 :]
+            )
+            offset_list.append(offset)
+        return input_ids, offset_list

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -124,7 +155,6 @@ class LlavaLlamaForCausalLM(nn.Module):

         # Embed text input
         input_embeds = self.language_model.model.embed_tokens(input_ids)
-
         # Embed vision input
         need_vision = (
             (positions[input_metadata.extend_start_loc] < self.image_feature_len)
@@ -163,27 +193,73 @@ class LlavaLlamaForCausalLM(nn.Module):

             if self.mm_patch_merge_type.startswith("spatial"):
                 new_image_features = []
+                height = width = self.num_patches_per_side
                 for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
+                    if len(image_sizes[image_idx]) == 1:
+                        # single image
+                        image_aspect_ratio = self.config.image_aspect_ratio
+                    else:
+                        # multiple images: interleaved images or video frames
+                        image_aspect_ratio = "pad"
+                    if (
+                        image_feature.shape[0] > 1
+                        and "anyres" in image_aspect_ratio
+                    ):
                         base_image_feature = image_feature[0]
                         image_feature = image_feature[1:]
-                        height = width = self.num_patches_per_side
                         assert height * width == base_image_feature.shape[0]
-                        if self.image_aspect_ratio == "anyres":
-                            (
-                                num_patch_width,
-                                num_patch_height,
-                            ) = get_anyres_image_grid_shape(
-                                image_sizes[image_idx],
-                                self.image_grid_pinpoints,
-                                self.vision_tower.config.image_size,
-                            )
+
+                        if "anyres_max" in image_aspect_ratio:
+                            matched_anyres_max_num_patches = re.match(
+                                r"anyres_max_(\d+)", image_aspect_ratio
+                            )
+                            if matched_anyres_max_num_patches:
+                                max_num_patches = int(
+                                    matched_anyres_max_num_patches.group(1)
+                                )
+
+                        if (
+                            image_aspect_ratio == "anyres"
+                            or "anyres_max" in image_aspect_ratio
+                        ):
+                            vision_tower_image_size = self.image_size
+                            try:
+                                num_patch_width, num_patch_height = (
+                                    get_anyres_image_grid_shape(
+                                        image_sizes[image_idx][0],
+                                        self.config.image_grid_pinpoints,
+                                        vision_tower_image_size,
+                                    )
+                                )
+                            except Exception as e:
+                                print(f"Error: {e}")
+                                num_patch_width, num_patch_height = 2, 2
                             image_feature = image_feature.view(
                                 num_patch_height, num_patch_width, height, width, -1
                             )
                         else:
-                            raise NotImplementedError()
+                            image_feature = image_feature.view(
+                                2, 2, height, width, -1
+                            )
+
                         if "unpad" in self.mm_patch_merge_type:
+                            unit = image_feature.shape[2]
                             image_feature = image_feature.permute(
                                 4, 0, 2, 1, 3
                             ).contiguous()
                             image_feature = image_feature.flatten(1, 2).flatten(
                                 2, 3
                             )
                             image_feature = unpad_image(
-                                image_feature, image_sizes[image_idx]
+                                image_feature, image_sizes[image_idx][0]
                             )
+                            if (
+                                "anyres_max" in image_aspect_ratio
+                                and matched_anyres_max_num_patches
+                            ):
+                                c, h, w = image_feature.shape
+                                times = math.sqrt(
+                                    h * w / (max_num_patches * unit**2)
+                                )
+                                if times > 1.1:
+                                    image_feature = image_feature[None]
+                                    image_feature = nn.functional.interpolate(
+                                        image_feature,
+                                        [int(h // times), int(w // times)],
+                                        mode="bilinear",
+                                    )[0]
                             image_feature = torch.cat(
                                 (
                                     image_feature,
@@ -213,16 +304,31 @@ class LlavaLlamaForCausalLM(nn.Module):
                             image_feature = torch.cat(
                                 (base_image_feature, image_feature), dim=0
                             )
+                        image_feature = image_feature.unsqueeze(0)
                     else:
-                        image_feature = image_feature[0]
-                        if "unpad" in self.mm_patch_merge_type:
-                            image_feature = torch.cat(
-                                (
-                                    image_feature,
-                                    self.language_model.model.image_newline[None],
-                                ),
-                                dim=0,
+                        if image_feature.shape[0] > 16:  # video
+                            # 2x2 pooling
+                            num_of_frames = image_feature.shape[0]
+                            image_feature = image_feature.view(
+                                num_of_frames, height, width, -1
                             )
+                            image_feature = image_feature.permute(
+                                0, 3, 1, 2
+                            ).contiguous()  # N, C, H, W
+                            height, width = image_feature.shape[2:]
+                            scaled_shape = [
+                                math.ceil(height / 2),
+                                math.ceil(width / 2),
+                            ]
+                            image_feature = nn.functional.interpolate(
+                                image_feature, size=scaled_shape, mode="bilinear"
+                            )
+                            image_feature = (
+                                image_feature.flatten(2)
+                                .transpose(1, 2)
+                                .contiguous()
+                            )  # N, H*W, C
+
                         new_image_features.append(image_feature)
                 image_features = new_image_features

@@ -233,21 +339,22 @@ class LlavaLlamaForCausalLM(nn.Module):
                     continue

                 start_idx = extend_start_loc_cpu[i]
-                pad_len, pad_dim = image_features[pt].shape  # 576, 4096
+                pad_dim = image_features[pt].shape[-1]  # e.g. 4096
                 dim = input_embeds.shape[1]
                 assert (
                     pad_dim == dim
                 ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                 # Fill in the placeholder for the image
                 try:
-                    input_embeds[
-                        start_idx
-                        + image_offsets[i] : start_idx
-                        + image_offsets[i]
-                        + pad_len
-                    ] = image_features[pt]
+                    for j, image_off in enumerate(image_offsets[i]):
+                        pad_len = image_features[pt][j].shape[0]
+                        input_embeds[
+                            start_idx + image_off : start_idx + image_off + pad_len
+                        ] = image_features[pt][j]
                 except RuntimeError as e:
                     print(f"RuntimeError in llava image encoding: {e}")
+                    print(image_features[pt].shape)
                     print(input_embeds.shape)
                     print(start_idx, image_offsets[i])
                 pt += 1
@@ -262,9 +369,16 @@ class LlavaLlamaForCausalLM(nn.Module):
         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
-        self.vision_tower = CLIPVisionModel.from_pretrained(
-            vision_path, torch_dtype=torch.float16
-        ).cuda()
+        if "clip" in vision_path:
+            self.vision_tower = CLIPVisionModel.from_pretrained(
+                vision_path, torch_dtype=torch.float16
+            ).cuda()
+        elif "siglip" in vision_path:
+            self.vision_tower = SiglipVisionModel.from_pretrained(
+                vision_path, torch_dtype=torch.float16
+            ).cuda()
+            # SigLIP has no CLS token, so keep all feature tokens
+            self.config.mm_vision_select_feature = "full"
         self.vision_tower.eval()

         self.vision_feature_layer = self.config.mm_vision_select_layer
@@ -276,8 +390,11 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
         self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

-        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
-        if self.vision_feature_select_strategy == "patch":
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
             pass
         elif self.vision_feature_select_strategy == "cls_patch":
             self.image_feature_len += 1
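The video branch above halves each spatial side of the frame features with bilinear pooling. A standalone sketch of the shape bookkeeping (the 27x27 grid assumes a 384px SigLIP tower with patch size 14, and the hidden size 3584 is illustrative):

```python
import math
import torch
import torch.nn.functional as F

# Illustrative numbers: 32 frames of 27x27 = 729 patch features, dim 3584.
frames, side, dim = 32, 27, 3584
feats = torch.randn(frames, side * side, dim)

x = feats.view(frames, side, side, dim).permute(0, 3, 1, 2)  # N, C, H, W
scaled = [math.ceil(side / 2), math.ceil(side / 2)]          # 14 x 14
x = F.interpolate(x, size=scaled, mode="bilinear")
x = x.flatten(2).transpose(1, 2)                             # N, H*W, C
print(x.shape)  # torch.Size([32, 196, 3584]) -- 729 tokens/frame becomes 196
```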
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
index c599d8b36..3e858dfa7 100644
--- a/test/srt/test_vision_openai_server.py
+++ b/test/srt/test_vision_openai_server.py
@@ -1,17 +1,27 @@
+import base64
+import io
 import json
+import os
+import sys
+import time
 import unittest

+import numpy as np
 import openai
+import requests
+from decord import VideoReader, cpu
+from PIL import Image

 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import DEFAULT_URL_FOR_UNIT_TEST, popen_launch_server

+# python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --tokenizer-path lmms-lab/llavanext-qwen-siglip-tokenizer --port=30000 --host=127.0.0.1 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384

 class TestOpenAIVisionServer(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = "liuhaotian/llava-v1.6-vicuna-7b"
+        cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
         cls.base_url = DEFAULT_URL_FOR_UNIT_TEST
         cls.api_key = "sk-123456"
         cls.process = popen_launch_server(
@@ -21,9 +31,11 @@ class TestOpenAIVisionServer(unittest.TestCase):
             api_key=cls.api_key,
             other_args=[
                 "--chat-template",
-                "vicuna_v1.1",
+                "chatml-llava",
                 "--tokenizer-path",
-                "llava-hf/llava-1.5-7b-hf",
+                "lmms-lab/llavanext-qwen-siglip-tokenizer",
+                "--chunked-prefill-size",
+                "16384",
                 "--log-requests",
             ],
         )
@@ -68,6 +80,81 @@ class TestOpenAIVisionServer(unittest.TestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

+    def prepare_video_messages(self, video_path):
+        max_frames_num = 32
+        vr = VideoReader(video_path, ctx=cpu(0))
+        total_frame_num = len(vr)
+        uniform_sampled_frames = np.linspace(
+            0, total_frame_num - 1, max_frames_num, dtype=int
+        )
+        frame_idx = uniform_sampled_frames.tolist()
+        frames = vr.get_batch(frame_idx).asnumpy()
+
+        base64_frames = []
+        for frame in frames:
+            pil_img = Image.fromarray(frame)
+            buff = io.BytesIO()
+            pil_img.save(buff, format="JPEG")
+            base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
+            base64_frames.append(base64_str)
+
+        messages = [{"role": "user", "content": []}]
+
+        for base64_frame in base64_frames:
+            # Build a fresh dict per frame; shallow copies of a shared template
+            # would make every entry share the same nested image_url dict.
+            frame_entry = {
+                "type": "image_url",
+                "image_url": {"url": f"data:image/jpeg;base64,{base64_frame}"},
+            }
+            messages[0]["content"].append(frame_entry)
+
+        prompt = {"type": "text", "text": "Please describe the video in detail."}
+        messages[0]["content"].append(prompt)
+
+        return messages
+
+    def test_video_chat_completion(self):
+        url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
+        cache_dir = os.path.expanduser("~/.cache")
+        file_path = os.path.join(cache_dir, "jobs.mp4")
+        os.makedirs(cache_dir, exist_ok=True)
+
+        if not os.path.exists(file_path):
+            response = requests.get(url)
+            response.raise_for_status()
+
+            with open(file_path, "wb") as f:
+                f.write(response.content)
+
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+        messages = self.prepare_video_messages(file_path)
+
+        video_request = client.chat.completions.create(
+            model="default",
+            messages=messages,
+            temperature=0,
+            max_tokens=1024,
+            stream=True,
+        )
+        print("-" * 30)
+        video_response = ""
+
+        for chunk in video_request:
+            if chunk.choices[0].delta.content is not None:
+                content = chunk.choices[0].delta.content
+                video_response += content
+                sys.stdout.write(content)
+                sys.stdout.flush()
+        print("-" * 30)
+
+        # Validate the video response
+        self.assertIsNotNone(video_response)
+        self.assertGreater(len(video_response), 0)
+
     def test_regex(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)