From bf53bf5142bd3393d495608e58c86f6d8c991664 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 28 Aug 2024 06:33:05 -0700 Subject: [PATCH] [Fix] Fix llava on multi images (#1247) --- README.md | 2 +- .../usage/llava_video/srt_example_llava_v.py | 13 +- python/sglang/launch_server_llavavid.py | 26 +++ python/sglang/srt/hf_transformers_utils.py | 149 ------------------ python/sglang/srt/managers/io_struct.py | 9 +- python/sglang/srt/managers/schedule_batch.py | 10 +- .../sglang/srt/managers/tokenizer_manager.py | 135 +++++++++------- python/sglang/srt/managers/tp_worker.py | 19 ++- .../srt/model_executor/forward_batch_info.py | 30 ++-- .../sglang/srt/model_executor/model_runner.py | 21 ++- python/sglang/srt/models/chatglm.py | 2 +- python/sglang/srt/models/grok.py | 12 +- python/sglang/srt/models/llama2.py | 7 +- .../sglang/srt/models/llama_classification.py | 4 - python/sglang/srt/models/llama_embedding.py | 7 +- python/sglang/srt/models/llava.py | 111 +++++-------- python/sglang/srt/models/llavavid.py | 126 +++++---------- python/sglang/srt/models/qwen2.py | 7 +- python/sglang/srt/models/yivl.py | 9 +- python/sglang/srt/server.py | 8 +- python/sglang/srt/utils.py | 51 +++--- test/srt/test_vision_openai_server.py | 2 - 22 files changed, 272 insertions(+), 488 deletions(-) create mode 100644 python/sglang/launch_server_llavavid.py diff --git a/README.md b/README.md index 223f9624f..9d795ce43 100644 --- a/README.md +++ b/README.md @@ -240,7 +240,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct - Qwen / Qwen 2 / Qwen 2 MoE - DeepSeek / DeepSeek 2 - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/) - - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384` + - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava` - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). 
See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py) - LLaVA 1.5 / 1.6 / NeXT - `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3` diff --git a/examples/frontend_language/usage/llava_video/srt_example_llava_v.py b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py index 085bcea5a..1f2931a5a 100644 --- a/examples/frontend_language/usage/llava_video/srt_example_llava_v.py +++ b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py @@ -184,13 +184,9 @@ if __name__ == "__main__": # Parse the arguments args = parser.parse_args() - cur_port = args.port - cur_chunk = args.chunk_idx - num_chunks = args.num_chunks - num_frames = args.num_frames if "34b" in args.model_path.lower(): @@ -202,7 +198,6 @@ if __name__ == "__main__": exit() model_overide_args = {} - model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride model_overide_args["architectures"] = ["LlavaVidForCausalLM"] model_overide_args["num_frames"] = args.num_frames @@ -235,7 +230,6 @@ if __name__ == "__main__": print(f"chat template: {runtime.endpoint.chat_template.name}") # Run a single request - # try: print("\n========== single ==========\n") root = args.video_dir if os.path.isfile(root): @@ -257,13 +251,10 @@ if __name__ == "__main__": ) # Calculate the average processing time print(f"Average processing time per video: {average_time:.2f} seconds") runtime.shutdown() - # except Exception as e: - # print(e) - runtime.shutdown() - # # # Run a batch of requests + # # Run a batch of requests # print("\n========== batch ==========\n") # if not os.path.exists(args.save_dir): # os.makedirs(args.save_dir) - # batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks) + # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks) # runtime.shutdown() diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py new file mode 100644 index 000000000..797ad07a4 --- /dev/null +++ b/python/sglang/launch_server_llavavid.py @@ -0,0 +1,26 @@ +"""Launch the inference server for Llava-video model.""" + +import argparse + +from sglang.srt.server import ServerArgs, launch_server + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + + model_overide_args = {} + model_overide_args["mm_spatial_pool_stride"] = 2 + model_overide_args["architectures"] = ["LlavaVidForCausalLM"] + model_overide_args["num_frames"] = 16 + model_overide_args["model_type"] = "llavavid" + if model_overide_args["num_frames"] == 32: + model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} + model_overide_args["max_sequence_length"] = 4096 * 2 + model_overide_args["tokenizer_model_max_length"] = 4096 * 2 + model_overide_args["model_max_length"] = 4096 * 2 + if "34b" in args.model_path.lower(): + model_overide_args["image_token_index"] = 64002 + + launch_server(server_args, model_overide_args, None) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 4f6e3d071..2be416914 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -119,24 +119,7 @@ def get_tokenizer( tokenizer_revision: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - if 
tokenizer_name.endswith(".json"): - return TiktokenTokenizer(tokenizer_name) - - if tokenizer_name.endswith(".model"): - return SentencePieceTokenizer(tokenizer_name) - """Gets a tokenizer for the given model name via Huggingface.""" - if is_multimodal_model(tokenizer_name): - processor = get_processor( - tokenizer_name, - *args, - trust_remote_code=trust_remote_code, - tokenizer_revision=tokenizer_revision, - **kwargs, - ) - tokenizer = processor.tokenizer - return tokenizer - if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") @@ -199,135 +182,3 @@ def get_processor( **kwargs, ) return processor - - -class TiktokenTokenizer: - def __init__(self, tokenizer_path): - import tiktoken - from jinja2 import Template - - PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" - - # Read JSON - name = "tmp-json" - with open(tokenizer_path, "rb") as fin: - tok_dict = json.load(fin) - - mergeable_ranks = { - bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"] - } - special_tokens = { - bytes(item["bytes"]).decode(): item["token"] - for item in tok_dict["special_tokens"] - } - assert tok_dict["word_split"] == "V1" - - default_allowed_special = None - - kwargs = { - "name": name, - "pat_str": tok_dict.get("pat_str", PAT_STR_B), - "mergeable_ranks": mergeable_ranks, - "special_tokens": special_tokens, - } - if "default_allowed_special" in tok_dict: - default_allowed_special = set( - [ - bytes(bytes_list).decode() - for bytes_list in tok_dict["default_allowed_special"] - ] - ) - if "vocab_size" in tok_dict: - kwargs["explicit_n_vocab"] = tok_dict["vocab_size"] - - PAD = "<|pad|>" - EOS = "<|eos|>" - SEP = "<|separator|>" - - DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP} - - tokenizer = tiktoken.Encoding(**kwargs) - tokenizer._default_allowed_special = default_allowed_special or set() - tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS - - def encode_patched( - self, - text: str, - *, - allowed_special: Union[ - Literal["all"], AbstractSet[str] - ] = set(), # noqa: B006 - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - ) -> List[int]: - if isinstance(allowed_special, set): - allowed_special |= self._default_allowed_special - return tiktoken.Encoding.encode( - self, - text, - allowed_special=allowed_special, - disallowed_special=(), - ) - - tokenizer.encode = functools.partial(encode_patched, tokenizer) - - # Convert to HF interface - self.tokenizer = tokenizer - self.eos_token_id = tokenizer._special_tokens[EOS] - self.vocab_size = tokenizer.n_vocab - self.chat_template = Template( - "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" - ) - - def encode(self, x, add_special_tokens=False): - return self.tokenizer.encode(x) - - def decode(self, x): - return self.tokenizer.decode(x) - - def batch_decode( - self, batch, skip_special_tokens=True, spaces_between_special_tokens=False - ): - if isinstance(batch[0], int): - batch = [[x] for x in batch] - return self.tokenizer.decode_batch(batch) - - def apply_chat_template(self, 
messages, tokenize, add_generation_prompt): - ret = self.chat_template.render( - messages=messages, add_generation_prompt=add_generation_prompt - ) - return self.encode(ret) if tokenize else ret - - -class SentencePieceTokenizer: - def __init__(self, tokenizer_path): - import sentencepiece as spm - from jinja2 import Template - - tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path) - - # Convert to HF interface - self.tokenizer = tokenizer - self.eos_token_id = tokenizer.eos_id() - self.vocab_size = tokenizer.vocab_size() - self.chat_template = Template( - "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" - ) - - def encode(self, x, add_special_tokens=False): - return self.tokenizer.encode(x) - - def decode(self, x): - return self.tokenizer.decode(x) - - def batch_decode( - self, batch, skip_special_tokens=True, spaces_between_special_tokens=False - ): - if isinstance(batch[0], int): - batch = [[x] for x in batch] - return self.tokenizer.decode(batch) - - def apply_chat_template(self, messages, tokenize, add_generation_prompt): - ret = self.chat_template.render( - messages=messages, add_generation_prompt=add_generation_prompt - ) - return self.encode(ret) if tokenize else ret diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 56e3d8f79..3f80c64cf 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -55,6 +55,7 @@ class GenerateReqInput: self.text is not None and self.input_ids is not None ): raise ValueError("Either text or input_ids should be provided.") + if ( isinstance(self.sampling_params, dict) and self.sampling_params.get("n", 1) != 1 @@ -161,10 +162,10 @@ class TokenizedGenerateReqInput: input_ids: List[int] # The pixel values for input images pixel_values: List[float] - # The hash of input images - image_hash: int - # The image size - image_size: List[int] + # The hash values of input images + image_hashes: List[int] + # The image sizes + image_sizes: List[List[int]] # The sampling parameters sampling_params: SamplingParams # Whether to return the logprobs diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f3af821e4..5554170a3 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -121,8 +121,8 @@ class Req: # For vision input self.pixel_values = None - self.image_size = None - self.image_offset = None + self.image_sizes = None + self.image_offsets = None self.pad_value = None # Prefix info @@ -600,12 +600,12 @@ class ScheduleBatch: if req.pixel_values is not None: ( req.origin_input_ids, - req.image_offset, + req.image_offsets, ) = model_runner.model.pad_input_ids( req.origin_input_ids_unpadded, req.pad_value, - req.pixel_values.shape, - req.image_size, + req.pixel_values, + req.image_sizes, ) jump_forward_reqs.append(req) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index c74251947..5ad4152ea 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -23,6 +23,7 @@ 
import multiprocessing as mp import os from typing import Dict, List, Optional, Tuple, Union +import fastapi import numpy as np import transformers import uvloop @@ -96,21 +97,18 @@ class TokenizerManager: trust_remote_code=server_args.trust_remote_code, model_overide_args=model_overide_args, ) - self.is_generation = is_generation_model( self.hf_config.architectures, self.server_args.is_embedding ) - - if server_args.context_length is not None: - self.context_len = server_args.context_length - else: - self.context_len = get_context_length(self.hf_config) + self.context_len = server_args.context_length or get_context_length( + self.hf_config + ) # Create tokenizer if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: - if is_multimodal_model(self.model_path): + if is_multimodal_model(self.hf_config.architectures): self.processor = get_processor( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, @@ -118,6 +116,9 @@ class TokenizerManager: ) self.tokenizer = self.processor.tokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # We want to parallelize the image pre-processing so we + # create an executor for it self.executor = concurrent.futures.ProcessPoolExecutor( initializer=init_global_processor, mp_context=mp.get_context("fork"), @@ -134,12 +135,14 @@ class TokenizerManager: self.to_create_loop = True self.rid_to_state: Dict[str, ReqState] = {} - # for update model weights + # For update model weights self.model_update_lock = asyncio.Lock() self.model_update_result = None async def generate_request( - self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, ): if self.to_create_loop: self.create_handle_loop() @@ -160,7 +163,7 @@ class TokenizerManager: async def _handle_single_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], - request, + request: Optional[fastapi.Request] = None, index: Optional[int] = None, is_cache_for_prefill: Optional[bool] = False, ): @@ -182,8 +185,8 @@ class TokenizerManager: ) if self.is_generation: - pixel_values, image_hash, image_size = await self._get_pixel_values( - obj.image_data + pixel_values, image_hashes, image_sizes = await self._get_pixel_values( + obj.image_data if not_use_index else obj.image_data[index] ) return_logprob = ( obj.return_logprob if not_use_index else obj.return_logprob[index] @@ -195,7 +198,6 @@ class TokenizerManager: ) if return_logprob and logprob_start_len == -1: logprob_start_len = len(input_ids) - 1 - top_logprobs_num = ( obj.top_logprobs_num if not_use_index @@ -238,13 +240,14 @@ class TokenizerManager: sampling_params = SamplingParams(**obj.sampling_params[0]) sampling_params.max_new_tokens = 0 - pixel_values, image_hash, image_size = await self._get_pixel_values( + pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data[0] ) return_logprob = obj.return_logprob[0] logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] + # Send to the controller if self.is_generation: if return_logprob and logprob_start_len == -1: logprob_start_len = len(input_ids) - 1 @@ -253,8 +256,8 @@ class TokenizerManager: input_text, input_ids, pixel_values, - image_hash, - image_size, + image_hashes, + image_sizes, sampling_params, return_logprob, logprob_start_len, @@ -268,24 +271,24 @@ class TokenizerManager: input_ids, sampling_params, ) - self.send_to_router.send_pyobj(tokenized_obj) + # Recv results event 
= asyncio.Event() state = ReqState([], False, event) self.rid_to_state[rid] = state if not is_cache_for_prefill: - async for response in self._wait_for_response( - event, state, obj, rid, request - ): + async for response in self._wait_for_response(state, obj, rid, request): yield response else: assert self.is_generation - await self._wait_for_cache_prefill_response(event, state, obj, rid, request) + await self._wait_for_cache_prefill_response(state, obj, rid, request) yield input_ids async def _handle_batch_request( - self, obj: Union[GenerateReqInput, EmbeddingReqInput], request + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, ): batch_size = obj.batch_size if self.is_generation: @@ -340,8 +343,8 @@ class TokenizerManager: if self.is_generation: if obj.return_logprob[index] and obj.logprob_start_len[index] == -1: obj.logprob_start_len[index] = len(input_ids) - 1 - pixel_values, image_hash, image_size = await self._get_pixel_values( - obj.image_data[index] + pixel_values, image_hashes, image_sizes = ( + await self._get_pixel_values(obj.image_data[index]) ) tokenized_obj = TokenizedGenerateReqInput( @@ -349,8 +352,8 @@ class TokenizerManager: input_text, input_ids, pixel_values, - image_hash, - image_size, + image_hashes, + image_sizes, sampling_params, obj.return_logprob[index], obj.logprob_start_len[index], @@ -372,7 +375,6 @@ class TokenizerManager: generators.append( self._wait_for_response( - event, state, obj, rid, @@ -388,6 +390,7 @@ class TokenizerManager: tasks = [asyncio.create_task(gen.__anext__()) for gen in generators] output_list = [None] * len(tasks) + # Recv results while tasks: done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) @@ -426,25 +429,18 @@ class TokenizerManager: sampling_params.verify() return sampling_params - async def _get_pixel_values(self, image_data): - if image_data is None: - return None, None, None - else: - return await self._get_pixel_values_internal(image_data) - async def _wait_for_response( self, - event: asyncio.Event, state: ReqState, obj: Union[GenerateReqInput, EmbeddingReqInput], rid: str, - request, - index: int = None, + request: Optional[fastapi.Request] = None, + index: Optional[int] = None, response_index: int = 0, ): while True: try: - await asyncio.wait_for(event.wait(), timeout=4) + await asyncio.wait_for(state.event.wait(), timeout=4) except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): for rid in [obj.rid] if obj.is_single else obj.rid: @@ -478,16 +474,15 @@ class TokenizerManager: yield out break - event.clear() + state.event.clear() yield out async def _wait_for_cache_prefill_response( self, - event: asyncio.Event, state: ReqState, obj: GenerateReqInput, rid: str, - request, + request: Optional[fastapi.Request] = None, ): while True: try: @@ -514,7 +509,9 @@ class TokenizerManager: req = AbortReq(rid) self.send_to_router.send_pyobj(req) - async def update_weights(self, obj: UpdateWeightReqInput, request): + async def update_weights( + self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None + ): if self.to_create_loop: self.create_handle_loop() @@ -659,12 +656,11 @@ class TokenizerManager: ) return top_logprobs - async def _get_pixel_values_internal(self, image_data, aspect_ratio=None): - aspect_ratio = ( - getattr(self.hf_config, "image_aspect_ratio", None) - if aspect_ratio is None - else aspect_ratio - ) + async def _get_pixel_values(self, image_data: List[Union[str, bytes]]): + if not image_data: + 
return None, None, None + + aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) grid_pinpoints = ( self.hf_config.image_grid_pinpoints if hasattr(self.hf_config, "image_grid_pinpoints") @@ -673,35 +669,42 @@ class TokenizerManager: ) if isinstance(image_data, list) and len(image_data) > 0: - pixel_values, image_hash, image_size = [], [], [] + # Multiple images if len(image_data) > 1: aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres + pixel_values, image_hashes, image_sizes = [], [], [] for img_data in image_data: pixel_v, image_h, image_s = await self._process_single_image( img_data, aspect_ratio, grid_pinpoints ) pixel_values.append(pixel_v) - image_hash.append(image_h) - image_size.append(image_s) - pixel_values = np.stack(pixel_values, axis=0) + image_hashes.append(image_h) + image_sizes.append(image_s) + + if isinstance(pixel_values[0], np.ndarray): + pixel_values = np.stack(pixel_values, axis=0) else: + # A single image pixel_values, image_hash, image_size = await self._process_single_image( image_data[0], aspect_ratio, grid_pinpoints ) - image_hash = [image_hash] - image_size = [image_size] + image_hashes = [image_hash] + image_sizes = [image_size] elif isinstance(image_data, str): + # A single image pixel_values, image_hash, image_size = await self._process_single_image( image_data, aspect_ratio, grid_pinpoints ) - image_hash = [image_hash] - image_size = [image_size] + image_hashes = [image_hash] + image_sizes = [image_size] else: - pixel_values, image_hash, image_size = None, None, None + raise ValueError(f"Invalid image data: {image_data}") - return pixel_values, image_hash, image_size + return pixel_values, image_hashes, image_sizes - async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints): + async def _process_single_image( + self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str + ): if self.executor is not None: loop = asyncio.get_event_loop() return await loop.run_in_executor( @@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs): def _process_single_image_task( - image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None + image_data: Union[str, bytes], + image_aspect_ratio: Optional[str] = None, + image_grid_pinpoints: Optional[str] = None, + processor=None, ): try: processor = processor or global_processor image, image_size = load_image(image_data) if image_size is not None: + # It is a video with multiple images image_hash = hash(image_data) pixel_values = processor.image_processor(image)["pixel_values"] for _ in range(len(pixel_values)): @@ -745,6 +752,7 @@ def _process_single_image_task( pixel_values = np.stack(pixel_values, axis=0) return pixel_values, image_hash, image_size else: + # It is an image image_hash = hash(image_data) if image_aspect_ratio == "pad": image = expand2square( @@ -754,13 +762,18 @@ def _process_single_image_task( pixel_values = processor.image_processor(image.convert("RGB"))[ "pixel_values" ][0] - elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: + elif image_aspect_ratio == "anyres" or ( + image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio + ): pixel_values = process_anyres_image( image, processor.image_processor, image_grid_pinpoints ) else: pixel_values = processor.image_processor(image)["pixel_values"][0] - pixel_values = pixel_values.astype(np.float16) + + if isinstance(pixel_values, np.ndarray): + pixel_values = 
pixel_values.astype(np.float16) + return pixel_values, image_hash, image.size except Exception: logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 65daed43b..cd1b58064 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -108,7 +108,7 @@ class ModelTpServer: if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: - if is_multimodal_model(server_args.model_path): + if is_multimodal_model(self.model_config.hf_config.architectures): self.processor = get_processor( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, @@ -333,26 +333,24 @@ class ModelTpServer: if self.model_runner.is_generation: req.pixel_values = recv_req.pixel_values if req.pixel_values is not None: - image_hash = ( - hash(tuple(recv_req.image_hash)) - if isinstance(recv_req.image_hash, list) - else recv_req.image_hash - ) + # Use image hash as fake token_ids, which is then used + # for prefix matching + image_hash = hash(tuple(recv_req.image_hashes)) req.pad_value = [ (image_hash) % self.model_config.vocab_size, (image_hash >> 16) % self.model_config.vocab_size, (image_hash >> 32) % self.model_config.vocab_size, (image_hash >> 64) % self.model_config.vocab_size, ] - req.image_size = recv_req.image_size + req.image_sizes = recv_req.image_sizes ( req.origin_input_ids, - req.image_offset, + req.image_offsets, ) = self.model_runner.model.pad_input_ids( req.origin_input_ids_unpadded, req.pad_value, - req.pixel_values.shape, - req.image_size, + req.pixel_values, + req.image_sizes, ) req.return_logprob = recv_req.return_logprob req.logprob_start_len = recv_req.logprob_start_len @@ -368,6 +366,7 @@ class ModelTpServer: req.jump_forward_map = self.jump_forward_cache.query( computed_regex_string ) + # Init regex fsm elif req.sampling_params.regex is not None: req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index c107b3bc8..f24cdf6b7 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -16,7 +16,7 @@ limitations under the License. 
"""ModelRunner runs the forward passes of the models.""" from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List import numpy as np import torch @@ -58,6 +58,7 @@ class InputMetadata: # For extend extend_seq_lens: torch.Tensor = None + extend_prefix_lens: torch.Tensor = None extend_start_loc: torch.Tensor = None extend_no_prefix: bool = None @@ -69,8 +70,8 @@ class InputMetadata: # For multimodal pixel_values: List[torch.Tensor] = None - image_sizes: List[List[int]] = None - image_offsets: List[int] = None + image_sizes: List[List[List[int]]] = None + image_offsets: List[List[int]] = None # Trition attention backend triton_max_seq_len: int = 0 @@ -87,20 +88,8 @@ class InputMetadata: def init_multimuldal_info(self, batch: ScheduleBatch): reqs = batch.reqs self.pixel_values = [r.pixel_values for r in reqs] - self.image_sizes = [r.image_size for r in reqs] - self.image_offsets = [] - for r in reqs: - if isinstance(r.image_offset, list): - self.image_offsets.append( - [ - (image_offset - len(r.prefix_indices)) - for image_offset in r.image_offset - ] - ) - elif isinstance(r.image_offset, int): - self.image_offsets.append(r.image_offset - len(r.prefix_indices)) - elif r.image_offset is None: - self.image_offsets.append(0) + self.image_sizes = [r.image_sizes for r in reqs] + self.image_offsets = [r.image_offsets for r in reqs] def compute_positions(self, batch: ScheduleBatch): position_ids_offsets = batch.position_ids_offsets @@ -153,6 +142,7 @@ class InputMetadata: for i, r in enumerate(batch.reqs) ] self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda") + self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") self.extend_start_loc = torch.zeros_like(self.seq_lens) self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) @@ -238,10 +228,10 @@ class InputMetadata: prefix_lens_cpu, flashinfer_use_ragged, ): - if self.forward_mode != ForwardMode.DECODE: - prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda") - else: + if self.forward_mode == ForwardMode.DECODE: prefix_lens = None + else: + prefix_lens = self.extend_prefix_lens update_flashinfer_indices( self.forward_mode, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index abee152d6..8ef47a530 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -50,7 +50,7 @@ from sglang.srt.mem_cache.memory_pool import ( MLATokenToKVPool, ReqToTokenPool, ) -from sglang.srt.model_config import AttentionArch +from sglang.srt.model_config import AttentionArch, ModelConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -69,7 +69,7 @@ logger = logging.getLogger(__name__) class ModelRunner: def __init__( self, - model_config, + model_config: ModelConfig, mem_fraction_static: float, gpu_id: int, tp_rank: int, @@ -85,7 +85,9 @@ class ModelRunner: self.tp_size = tp_size self.nccl_port = nccl_port self.server_args = server_args - self.is_multimodal_model = is_multimodal_model(self.model_config) + self.is_multimodal_model = is_multimodal_model( + self.model_config.hf_config.architectures + ) global_server_args_dict.update( { "disable_flashinfer": server_args.disable_flashinfer, @@ -95,6 +97,13 @@ class 
ModelRunner: } ) + if self.is_multimodal_model: + logger.info( + "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." + ) + server_args.chunked_prefill_size = None + server_args.mem_fraction_static *= 0.95 + min_per_gpu_memory = self.init_torch_distributed() self.load_model() self.init_memory_pool( @@ -507,9 +516,9 @@ class ModelRunner: raise Exception( f"Capture cuda graph failed: {e}\n" "Possible solutions:\n" - "1. disable torch compile by not using --enable-torch-compile\n" - "2. disable cuda graph by --disable-cuda-graph\n" - "3. set --mem-fraction-static to a smaller value\n" + "1. disable cuda graph by --disable-cuda-graph\n" + "2. set --mem-fraction-static to a smaller value\n" + "3. disable torch compile by not using --enable-torch-compile\n" "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" ) diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 0a22f994b..b38b62faf 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -17,7 +17,7 @@ limitations under the License. # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 4a0a08bf8..daf6f25da 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -273,9 +273,9 @@ class Grok1Model(nn.Module): ) -> torch.Tensor: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) + hidden_states.mul_(self.config.embedding_multiplier_scale) else: hidden_states = input_embeds - hidden_states.mul_(self.config.embedding_multiplier_scale) for i in range(len(self.layers)): hidden_states = self.layers[i](positions, hidden_states, input_metadata) @@ -284,7 +284,7 @@ class Grok1Model(nn.Module): return hidden_states -class Grok1ModelForCausalLM(nn.Module): +class Grok1ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, @@ -415,4 +415,10 @@ def _prepare_presharded_weights( return hf_folder, hf_weights_files, use_safetensors -EntryClass = Grok1ModelForCausalLM +class Grok1ModelForCausalLM(Grok1ForCausalLM): + """An alias for backward-compatibility.""" + + pass + + +EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM] diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 9de8d33c5..fe75916a4 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -357,6 +357,9 @@ class LlamaForCausalLM(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. return + if name.startswith("model.vision_tower") and name not in params_dict: + return + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -364,8 +367,6 @@ # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -374,8 +375,6 @@ # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: return - if name.startswith("model.vision_tower") and name not in params_dict: - return param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index 02224971d..c5effbfc9 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py index dfff53cbc..e4e9174f1 100644 --- a/python/sglang/srt/models/llama_embedding.py +++ b/python/sglang/srt/models/llama_embedding.py @@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. return + if name.startswith("model.vision_tower") and name not in params_dict: + return + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: return - if name.startswith("model.vision_tower") and name not in params_dict: - return param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 76a0630fc..bc522bec9 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -28,7 +28,6 @@ from transformers import ( LlavaConfig, MistralConfig, Qwen2Config, - SiglipVisionConfig, SiglipVisionModel, ) from transformers.models.llava.modeling_llava import LlavaMultiModalProjector @@ -66,13 +65,18 @@ class LlavaLlamaForCausalLM(nn.Module): torch.empty(config.text_config.hidden_size, dtype=torch.float16) ) - def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): - + def pad_input_ids( + self, + input_ids: List[int], + pad_value: List[int], + pixel_values: List, + image_sizes: List[List[int]], + ): # hardcode for spatial_unpad + anyres - image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad" + image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad" offset_list = [] - for image_s in image_size: - if len(image_size) > 16: + for image_s in image_sizes: + if len(image_sizes) > 16: # 2x2 pooling with stride 2 new_image_feature_len = ( math.ceil(self.image_size / self.patch_size / 2) ** 2 @@ -153,17 +157,15 @@ class LlavaLlamaForCausalLM(nn.Module): if input_metadata.forward_mode == ForwardMode.EXTEND: bs = input_metadata.batch_size - # Embed text input + # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) - # Embed vision input - need_vision = ( - (positions[input_metadata.extend_start_loc] < self.image_feature_len) - .cpu() - .numpy() + + # Whether the requests need vision inputs + max_image_offset = np.array( + [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)] ) - # FIXME: We need to substract the length of the system prompt - has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) - need_vision = need_vision & has_pixel + start_positions = positions[input_metadata.extend_start_loc].cpu().numpy() + need_vision = start_positions <= max_image_offset if need_vision.any(): pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] @@ -332,31 +334,35 @@ class LlavaLlamaForCausalLM(nn.Module): new_image_features.append(image_feature) image_features = new_image_features + # Fill in the placeholder for the image extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() + prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy() pt = 0 for i in range(bs): if not need_vision[i]: continue start_idx = extend_start_loc_cpu[i] - pad_dim = image_features[pt].shape[-1] # 576, 4096 - dim = input_embeds.shape[1] - assert ( - pad_dim == dim - ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) - # Fill in the placeholder for the image - try: - for j, image_off in enumerate(image_offsets[i]): - # print("actual image_features length: ", image_features[pt][j].shape[0]) - pad_len = image_features[pt][j].shape[0] - input_embeds[ - start_idx + image_off : start_idx + image_off + pad_len - ] = image_features[pt][j] - except RuntimeError as e: - print(f"RuntimeError in llava image encoding: {e}") - print(image_features[pt].shape) - print(input_embeds.shape) - print(start_idx, image_offsets[i]) + prefix_len = prefix_lens_cpu[i] + + # Multiple images + for j, image_offset in 
enumerate(image_offsets[i]): + if image_offset < prefix_len: + continue + + tmp_image_feature = image_features[pt][j] + pad_len = tmp_image_feature.shape[0] + + left_idx = start_idx + (image_offset - prefix_len) + right_idx = start_idx + (image_offset - prefix_len) + pad_len + try: + input_embeds[left_idx:right_idx] = tmp_image_feature + except RuntimeError as e: + print(f"RuntimeError in image encoding: {e}") + print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}") + print( + f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}" + ) pt += 1 return self.language_model( @@ -366,8 +372,9 @@ return self.language_model(input_ids, positions, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # load clip vision model by cfg['mm_vision_tower']: - # huggingface_name or path_of_clip_relative_to_llava_model_dir + # Load clip vision model by cfg['mm_vision_tower']: + # huggingface_name or path_of_clip_relative_to_llava_model_dir + # We put the initialization here instead of __init__ to allow it to be reused by other subclasses. vision_path = self.config.mm_vision_tower if "clip" in vision_path: self.vision_tower = CLIPVisionModel.from_pretrained( @@ -422,8 +429,6 @@ # load language model self.language_model.load_weights(weights) - monkey_path_clip_vision_embed_forward() - @property def num_patches_per_side(self): return self.image_size // self.patch_size @@ -495,36 +500,4 @@ ) -first_call = True - - -def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] - - # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
- global first_call - if first_call: - self.patch_embedding.cpu().float() - first_call = False - pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") - patch_embeds = self.patch_embedding(pixel_values).cuda().half() - - patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - - class_embeds = self.class_embedding.expand(batch_size, 1, -1) - embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -def monkey_path_clip_vision_embed_forward(): - import transformers - - setattr( - transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, - "forward", - clip_vision_embed_forward, - ) - - EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM] diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 8b81251d6..44e400ff6 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -26,11 +26,6 @@ from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.mm_utils import ( - get_anyres_image_grid_shape, - unpad_image, - unpad_image_shape, -) from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.models.llama2 import LlamaForCausalLM @@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module): torch.empty(config.text_config.hidden_size, dtype=torch.float16) ) - def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): + def pad_input_ids( + self, + input_ids: List[int], + pad_value: List[int], + pixel_values: List, + image_sizes: List[List[int]], + ): new_image_feature_len = self.image_feature_len - # now only support spatial_unpad + anyres - # if self.mm_patch_merge_type.startswith("spatial"): - # height = width = self.num_patches_per_side - # if pt_shape[0] > 1: - # if self.image_aspect_ratio == "anyres": - # num_patch_width, num_patch_height = get_anyres_image_grid_shape( - # image_size, - # self.image_grid_pinpoints, - # self.vision_tower.config.image_size, - # ) - # if "unpad" in self.mm_patch_merge_type: - # h = num_patch_height * height - # w = num_patch_width * width - # new_h, new_w = unpad_image_shape(h, w, image_size) - # new_image_feature_len += new_h * (new_w + 1) pad_ids = pad_value * ( (new_image_feature_len + len(pad_value)) // len(pad_value) @@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module): + pad_ids[:new_image_feature_len] + input_ids[offset + 1 :] ) - return new_input_ids, offset + return new_input_ids, [offset] def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module): if input_metadata.forward_mode == ForwardMode.EXTEND: bs = input_metadata.batch_size - # Embed text input + # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) - # Embed vision input - need_vision = ( - (positions[input_metadata.extend_start_loc] < self.image_feature_len) - .cpu() - .numpy() + # Whether the requests need vision inputs + max_image_offset = np.array( + [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)] ) - # FIXME: We need to substract the length of the system prompt - has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) - need_vision = need_vision & 
has_pixel + start_positions = positions[input_metadata.extend_start_loc].cpu().numpy() + need_vision = start_positions <= max_image_offset if need_vision.any(): pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] - image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]] ########## Encode Image ######## @@ -183,31 +165,36 @@ new_image_features.append(image_feature.flatten(0, 1)) image_features = new_image_features + # Fill in the placeholder for the image extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() + prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy() pt = 0 for i in range(bs): if not need_vision[i]: continue start_idx = extend_start_loc_cpu[i] - pad_len, pad_dim = image_features[pt].shape # 576, 4096 - dim = input_embeds.shape[1] - assert ( - pad_dim == dim - ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) - # Fill in the placeholder for the image - try: - input_embeds[ - start_idx - + image_offsets[i] : start_idx - + image_offsets[i] - + pad_len - ] = image_features[pt] - except RuntimeError as e: - print(f"RuntimeError in llava image encoding: {e}") - print(input_embeds.shape) - print(start_idx, image_offsets[i]) - pt += 1 + prefix_len = prefix_lens_cpu[i] + + # Multiple images + for image_offset in image_offsets[i]: + if image_offset < prefix_len: + continue + + tmp_image_feature = image_features[pt] + pad_len = tmp_image_feature.shape[0] + + left_idx = start_idx + (image_offset - prefix_len) + right_idx = start_idx + (image_offset - prefix_len) + pad_len + try: + input_embeds[left_idx:right_idx] = tmp_image_feature + except RuntimeError as e: + print(f"RuntimeError in image encoding: {e}") + print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}") + print( + f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}" + ) + pt += 1 return self.language_model( input_ids, positions, input_metadata, input_embeds=input_embeds @@ -216,8 +203,9 @@ return self.language_model(input_ids, positions, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # load clip vision model by cfg['mm_vision_tower']: - # huggingface_name or path_of_clip_relative_to_llava_model_dir + # Load clip vision model by cfg['mm_vision_tower']: + # huggingface_name or path_of_clip_relative_to_llava_model_dir + # We put the initialization here instead of __init__ to allow it to be reused by other subclasses. vision_path = self.config.mm_vision_tower self.vision_tower = CLIPVisionModel.from_pretrained( vision_path, torch_dtype=torch.float16 @@ -271,43 +259,9 @@ # load language model self.language_model.load_weights(weights) - monkey_path_clip_vision_embed_forward() - @property def num_patches_per_side(self): return self.image_size // self.patch_size -first_call = True - - -def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] - - # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
- global first_call - if first_call: - self.patch_embedding.cpu().float() - first_call = False - pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") - patch_embeds = self.patch_embedding(pixel_values).cuda().half() - - patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - - class_embeds = self.class_embedding.expand(batch_size, 1, -1) - embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -def monkey_path_clip_vision_embed_forward(): - import transformers - - setattr( - transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, - "forward", - clip_vision_embed_forward, - ) - - EntryClass = LlavaVidForCausalLM diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index fcf083e1b..a0c54f691 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -312,6 +312,9 @@ class Qwen2ForCausalLM(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -319,8 +322,6 @@ class Qwen2ForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -329,8 +330,6 @@ class Qwen2ForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py index 11d4cda1c..0f86206d8 100644 --- a/python/sglang/srt/models/yivl.py +++ b/python/sglang/srt/models/yivl.py @@ -24,10 +24,7 @@ from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.models.llava import ( - LlavaLlamaForCausalLM, - monkey_path_clip_vision_embed_forward, -) +from sglang.srt.models.llava import LlavaLlamaForCausalLM class YiVLForCausalLM(LlavaLlamaForCausalLM): @@ -50,7 +47,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM): self.config._name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder, - ).cuda() + ).to("cuda") self.vision_tower.eval() @@ -94,8 +91,6 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM): # load language model self.language_model.load_weights(weights) - monkey_path_clip_vision_embed_forward() - class YiVLMultiModalProjector(nn.Module): def __init__(self, config: LlavaConfig): diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index f3d1ab0f9..9c36216ed 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -335,12 +335,12 @@ def launch_server( pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) if server_args.dp_size == 1: - start_process = start_controller_process_single + start_controller_process = start_controller_process_single else: - start_process = 
start_controller_process_multi proc_controller = mp.Process( - target=start_process, + target=start_controller_process, args=(server_args, port_args, pipe_controller_writer, model_overide_args), ) proc_controller.start() @@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs): if not server_args.disable_flashinfer: assert_pkg_version( "flashinfer", - "0.1.6", + "0.1.5", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index a6e710009..b7bb65730 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -26,7 +26,7 @@ import struct import time from importlib.metadata import PackageNotFoundError, version from io import BytesIO -from typing import List, Optional +from typing import List, Optional, Union import numpy as np import psutil @@ -193,35 +193,16 @@ def allocate_init_ports( return ret_ports[0], ret_ports[1:num_ports_needed] -def get_int_token_logit_bias(tokenizer, vocab_size): - """Get the logit bias for integer-only tokens.""" - # a bug when model's vocab size > tokenizer.vocab_size - if tokenizer == None: - return [-1e5] * vocab_size - vocab_size = tokenizer.vocab_size - logit_bias = np.zeros(vocab_size, dtype=np.float32) - for t_id in range(vocab_size): - ss = tokenizer.decode([t_id]).strip() - if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id): - logit_bias[t_id] = -1e5 - - return logit_bias - - -def is_multimodal_model(model): - from sglang.srt.model_config import ModelConfig - - if isinstance(model, str): - model = model.lower() - return "llava" in model or "yi-vl" in model or "llava-next" in model - - if isinstance(model, ModelConfig): - model_path = model.path.lower() - return ( - "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path - ) - - raise ValueError("unrecognized type") +def is_multimodal_model(model_architectures): + if ( + "LlavaLlamaForCausalLM" in model_architectures + or "LlavaQwenForCausalLM" in model_architectures + or "LlavaMistralForCausalLM" in model_architectures + or "LlavaVidForCausalLM" in model_architectures + ): + return True + else: + return False def is_generation_model(model_architectures, is_embedding: bool = False): @@ -317,12 +298,14 @@ def decode_video_base64(video_base64): ) # Return an empty array and size tuple if no frames were found -def load_image(image_file): +def load_image(image_file: Union[str, bytes]): from PIL import Image image = image_size = None - if image_file.startswith("http://") or image_file.startswith("https://"): + if isinstance(image_file, bytes): + image = Image.open(BytesIO(image_file)) + elif image_file.startswith("http://") or image_file.startswith("https://"): timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) response = requests.get(image_file, timeout=timeout) image = Image.open(BytesIO(response.content)) @@ -334,8 +317,10 @@ elif image_file.startswith("video:"): image_file = image_file.replace("video:", "") image, image_size = decode_video_base64(image_file) - else: + elif isinstance(image_file, str): image = Image.open(BytesIO(base64.b64decode(image_file))) + else: + raise ValueError(f"Invalid image: {image_file}") return image, image_size diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 0003e4776..cf29c0e81 100644 ---
a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -32,8 +32,6 @@ class TestOpenAIVisionServer(unittest.TestCase): other_args=[ "--chat-template", "chatml-llava", - "--chunked-prefill-size", - "16384", # "--log-requests", ], )
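
As a quick way to exercise the multi-image path this patch fixes, a request like the one below can be sent to a server launched with one of the README commands above. This is a minimal sketch, assuming a local server on port 30000; the model alias, image URLs, and sampling settings are illustrative placeholders, not part of the patch.

import openai

# Point the OpenAI-compatible client at the local SGLang server
# (assumed to be listening on port 30000, as in the README examples).
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

# Two images in a single user message: the interleaved multi-image
# request shape whose handling this patch repairs.
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                # Hypothetical image URLs -- substitute your own.
                {"type": "image_url", "image_url": {"url": "https://example.com/first.png"}},
                {"type": "image_url", "image_url": {"url": "https://example.com/second.png"}},
                {"type": "text", "text": "Describe the difference between these two images."},
            ],
        }
    ],
    temperature=0,
    max_tokens=128,
)
print(response.choices[0].message.content)

The same request shape is exercised end-to-end by test/srt/test_vision_openai_server.py.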