From 0abbf289a8acd01cafd182da8d6a5cc0fccb6953 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 3 Nov 2024 12:25:39 -0800 Subject: [PATCH] Unify the model type checking (#1905) --- .../quick_start/local_example_llava_next.py | 2 +- python/sglang/bench_latency.py | 4 +- python/sglang/lang/chat_template.py | 87 ++++++++++++------- python/sglang/srt/configs/model_config.py | 64 +++++++++++--- python/sglang/srt/managers/image_processor.py | 2 +- python/sglang/srt/managers/scheduler.py | 16 ++-- .../sglang/srt/managers/tokenizer_manager.py | 37 +++----- python/sglang/srt/managers/tp_worker.py | 10 +-- .../sglang/srt/model_executor/model_runner.py | 23 ++--- python/sglang/srt/models/llama.py | 2 + python/sglang/srt/models/qwen2_vl.py | 8 +- python/sglang/srt/utils.py | 50 ----------- test/srt/test_cache_report.py | 1 - 13 files changed, 146 insertions(+), 160 deletions(-) diff --git a/examples/frontend_language/quick_start/local_example_llava_next.py b/examples/frontend_language/quick_start/local_example_llava_next.py index fc5a1d04c..c941a549e 100644 --- a/examples/frontend_language/quick_start/local_example_llava_next.py +++ b/examples/frontend_language/quick_start/local_example_llava_next.py @@ -50,7 +50,7 @@ if __name__ == "__main__": mp.set_start_method("spawn", force=True) runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b") - runtime.endpoint.chat_template = get_chat_template("llama-3-instruct") + runtime.endpoint.chat_template = get_chat_template("llama-3-instruct-llava") # Or you can use the 72B model # runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index d97b641ea..841ecf56d 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -129,9 +129,9 @@ def load_model(server_args, port_args, tp_rank): model_config = ModelConfig( server_args.model_path, - server_args.trust_remote_code, + trust_remote_code=server_args.trust_remote_code, context_length=server_args.context_length, - model_override_args=json.loads(server_args.json_model_override_args), + model_override_args=server_args.json_model_override_args, ) model_runner = ModelRunner( model_config=model_config, diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index ca5a7a261..3e5ac8dd5 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -116,6 +116,23 @@ register_chat_template( ) ) + +register_chat_template( + ChatTemplate( + name="chatml-llava", + default_system_prompt="You are a helpful assistant.", + role_prefix_and_suffix={ + "system": ("<|im_start|>system\n", "<|im_end|>\n"), + "user": ("<|im_start|>user\n", "<|im_end|>\n"), + "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"), + }, + style=ChatTemplateStyle.PLAIN, + stop_str=("<|im_end|>",), + image_token="\n", + ) +) + + # There is default system prompt for qwen # reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1 # The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" @@ -149,22 +166,6 @@ register_chat_template( ) ) - -register_chat_template( - ChatTemplate( - name="chatml-llava", - 
default_system_prompt="You are a helpful assistant.", - role_prefix_and_suffix={ - "system": ("<|im_start|>system\n", "<|im_end|>\n"), - "user": ("<|im_start|>user\n", "<|im_end|>\n"), - "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"), - }, - style=ChatTemplateStyle.PLAIN, - stop_str=("<|im_end|>",), - image_token="\n", - ) -) - # Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template register_chat_template( ChatTemplate( @@ -182,21 +183,6 @@ register_chat_template( ) ) -# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1 -register_chat_template( - ChatTemplate( - name="yi-1.5", - default_system_prompt=None, - role_prefix_and_suffix={ - "system": ("", ""), - "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"), - "assistant": ("", "<|im_end|>\n"), - }, - style=ChatTemplateStyle.PLAIN, - stop_str=("<|im_end|>",), - ) -) - register_chat_template( ChatTemplate( name="llama-2-chat", @@ -233,6 +219,45 @@ register_chat_template( ) ) +# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token. +register_chat_template( + ChatTemplate( + name="llama-3-instruct-llava", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ( + "<|start_header_id|>system<|end_header_id|>\n\n", + "<|eot_id|>", + ), + "user": ( + "<|start_header_id|>user<|end_header_id|>\n\n", + "<|eot_id|>", + ), + "assistant": ( + "<|start_header_id|>assistant<|end_header_id|>\n\n", + "<|eot_id|>", + ), + }, + stop_str=("<|eot_id|>",), + image_token="\n", + ) +) + +# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1 +register_chat_template( + ChatTemplate( + name="yi-1.5", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("", ""), + "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"), + "assistant": ("", "<|im_end|>\n"), + }, + style=ChatTemplateStyle.PLAIN, + stop_str=("<|im_end|>",), + ) +) + # Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava register_chat_template( ChatTemplate( diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index a74d240b4..c37cfefbd 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +import json import logging import os from enum import IntEnum, auto -from typing import Optional +from typing import List, Optional from transformers import PretrainedConfig @@ -38,18 +39,24 @@ class ModelConfig: revision: Optional[str] = None, context_length: Optional[int] = None, model_override_args: Optional[dict] = None, + is_embedding: Optional[bool] = None ) -> None: - self.path = path - self.trust_remote_code = trust_remote_code - self.revision = revision - self.model_override_args = model_override_args + # Parse args + self.model_override_args = json.loads(model_override_args) self.hf_config = get_config( - self.path, - trust_remote_code, - revision, - model_override_args=model_override_args, + path, + trust_remote_code=trust_remote_code, + revision=revision, + model_override_args=self.model_override_args, ) self.hf_text_config = get_hf_text_config(self.hf_config) + + # Check model type + self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding) + self.is_multimodal = is_multimodal_model(self.hf_config.architectures) + self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) + + # Derive context length derived_context_len = get_context_length(self.hf_text_config) allow_long_context = os.environ.get( "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None @@ -81,7 +88,7 @@ class ModelConfig: self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads, ) - # FIXME: temporary special judge for deepseek v2 MLA architecture + # FIXME: temporary special judge for MLA architecture if "DeepseekV2ForCausalLM" in self.hf_config.architectures: self.head_dim = 256 self.attention_arch = AttentionArch.MLA @@ -112,8 +119,6 @@ class ModelConfig: self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.vocab_size = self.hf_text_config.vocab_size - self.is_encoder_decoder = self.hf_config.model_type in ["mllama"] - # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" @@ -163,7 +168,6 @@ class ModelConfig: # equal to the number of attention heads. return self.hf_text_config.num_attention_heads - # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328 def get_num_kv_heads(self, tensor_parallel_size) -> int: """Returns the number of KV heads per GPU.""" total_num_kv_heads = self.get_total_num_kv_heads() @@ -192,3 +196,37 @@ def get_hf_text_config(config: PretrainedConfig): return config.text_config else: return config + + +def is_generation_model(model_architectures: List[str], is_embedding: bool = False): + # We have two ways to determine whether a model is a generative model. + # 1. Check the model architectue + # 2. 
check the `is_embedding` server args + + if ( + "LlamaEmbeddingModel" in model_architectures + or "MistralModel" in model_architectures + or "LlamaForSequenceClassification" in model_architectures + or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures + ): + return False + else: + return not is_embedding + + +def is_multimodal_model(model_architectures: List[str]): + if ( + "LlavaLlamaForCausalLM" in model_architectures + or "LlavaQwenForCausalLM" in model_architectures + or "LlavaMistralForCausalLM" in model_architectures + or "LlavaVidForCausalLM" in model_architectures + or "MllamaForConditionalGeneration" in model_architectures + or "Qwen2VLForConditionalGeneration" in model_architectures + ): + return True + else: + return False + + +def is_encoder_decoder_model(model_architectures: List[str]): + return "MllamaForConditionalGeneration" in model_architectures diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index b24b761a2..2af817319 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -180,7 +180,7 @@ class LlavaImageProcessor(BaseImageProcessor): "pixel_values": pixel_values, "image_hashes": image_hashes, "image_sizes": image_sizes, - "modalities": request_obj.modalities, + "modalities": request_obj.modalities or ["image"], } diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ea98aa696..dd6bba863 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -15,7 +15,6 @@ limitations under the License. """A scheduler that manages a tensor parallel GPU worker.""" -import json import logging import os import threading @@ -23,7 +22,7 @@ import time import warnings from collections import deque from types import SimpleNamespace -from typing import List, Optional, Union +from typing import List, Optional import torch import zmq @@ -68,8 +67,6 @@ from sglang.srt.utils import ( broadcast_pyobj, configure_logger, get_zmq_socket, - is_generation_model, - is_multimodal_model, kill_parent_process, set_random_seed, suppress_other_loggers, @@ -133,15 +130,17 @@ class Scheduler: # Init tokenizer self.model_config = ModelConfig( server_args.model_path, - server_args.trust_remote_code, + trust_remote_code=server_args.trust_remote_code, context_length=server_args.context_length, - model_override_args=json.loads(server_args.json_model_override_args), + model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, ) + self.is_generation = self.model_config.is_generation if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: - if is_multimodal_model(self.model_config.hf_config.architectures): + if self.model_config.is_multimodal: self.processor = get_processor( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, @@ -154,9 +153,6 @@ class Scheduler: tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, ) - self.is_generation = is_generation_model( - self.model_config.hf_config.architectures, self.server_args.is_embedding - ) # Launch a tensor parallel worker if self.enable_overlap: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 1a9dc5e2b..cc9e2bd26 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -18,7 +18,6 @@ 
limitations under the License. import asyncio import copy import dataclasses -import json import logging import os import signal @@ -31,12 +30,8 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks -from sglang.srt.hf_transformers_utils import ( - get_config, - get_context_length, - get_processor, - get_tokenizer, -) +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.image_processor import ( get_dummy_image_processor, get_image_processor, @@ -59,12 +54,7 @@ from sglang.srt.managers.io_struct import ( ) from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import ( - get_zmq_socket, - is_generation_model, - is_multimodal_model, - kill_child_process, -) +from sglang.srt.utils import get_zmq_socket, kill_child_process asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -103,18 +93,17 @@ class TokenizerManager: # Read model args self.model_path = server_args.model_path self.served_model_name = server_args.served_model_name - self.hf_config = get_config( - self.model_path, + self.model_config = ModelConfig( + server_args.model_path, trust_remote_code=server_args.trust_remote_code, - model_override_args=json.loads(server_args.json_model_override_args), - ) - self.is_generation = is_generation_model( - self.hf_config.architectures, self.server_args.is_embedding - ) - self.context_len = server_args.context_length or get_context_length( - self.hf_config + context_length=server_args.context_length, + model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, ) + self.is_generation = self.model_config.is_generation + self.context_len = self.model_config.context_len + # Create image processor placeholder self.image_processor = get_dummy_image_processor() @@ -122,7 +111,7 @@ class TokenizerManager: if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: - if is_multimodal_model(self.hf_config.architectures): + if self.model_config.is_multimodal: self.processor = get_processor( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, @@ -133,7 +122,7 @@ class TokenizerManager: # We want to parallelize the image pre-processing so we create an executor for it self.image_processor = get_image_processor( - self.hf_config, server_args, self.processor + self.model_config.hf_config, server_args, self.processor ) else: self.tokenizer = get_tokenizer( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 561bfd77c..8bec1a18c 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -15,7 +15,6 @@ limitations under the License. 
"""A tensor parallel worker.""" -import json import logging from typing import Optional @@ -26,7 +25,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed +from sglang.srt.utils import broadcast_pyobj, set_random_seed logger = logging.getLogger(__name__) @@ -48,9 +47,10 @@ class TpModelWorker: # Init model and tokenizer self.model_config = ModelConfig( server_args.model_path, - server_args.trust_remote_code, + trust_remote_code=server_args.trust_remote_code, context_length=server_args.context_length, - model_override_args=json.loads(server_args.json_model_override_args), + model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, ) self.model_runner = ModelRunner( model_config=self.model_config, @@ -64,7 +64,7 @@ class TpModelWorker: if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: - if is_multimodal_model(self.model_config.hf_config.architectures): + if self.model_config.is_multimodal: self.processor = get_processor( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 583cbd968..1dde62943 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -59,11 +59,6 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( enable_show_time_cost, get_available_gpu_memory, - is_attention_free_model, - is_embedding_model, - is_generation_model, - is_multimodal_model, - model_has_inner_state, monkey_patch_vllm_dummy_weight_loader, monkey_patch_vllm_p2p_access_check, ) @@ -93,9 +88,8 @@ class ModelRunner: self.tp_size = tp_size self.dist_port = nccl_port self.server_args = server_args - self.is_multimodal_model = is_multimodal_model( - self.model_config.hf_config.architectures - ) + self.is_generation = model_config.is_generation + self.is_multimodal = model_config.is_multimodal # Model-specific adjustment if ( @@ -119,7 +113,7 @@ class ModelRunner: self.server_args.ds_heavy_channel_type ) - if self.is_multimodal_model: + if self.is_multimodal: logger.warning( "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." ) @@ -270,9 +264,6 @@ class ModelRunner: if hasattr(self.model, "get_attention_sliding_window_size") else None ) - self.is_generation = is_generation_model( - self.model_config.hf_config.architectures, self.server_args.is_embedding - ) logger.info( f"Load weight end. 
" @@ -679,7 +670,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: # Monkey patch model loader setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt) -setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model) -setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model) -setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state) -setattr(ModelRegistry, "is_embedding_model", is_embedding_model) +setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False) +setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False) +setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False) +setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index ab86a55b9..a9eaa81c1 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -409,11 +409,13 @@ class LlamaForCausalLM(nn.Module): if ( hasattr(self.config, "tie_word_embeddings") and self.config.tie_word_embeddings + and "lm_head.weight" in params_dict ): # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing param = self.lm_head.weight weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, self.model.embed_tokens.weight) + apply_torchao_config_(self, params_dict, set(["proj.weight"])) diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 80afee557..59adb2ee7 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -23,7 +23,7 @@ # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import lru_cache, partial -from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union +from typing import Iterable, List, Optional, Tuple, Type, TypedDict import numpy as np import torch @@ -36,7 +36,6 @@ from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig from sglang.srt.hf_transformers_utils import get_processor @@ -486,7 +485,7 @@ class Qwen2VisionTransformer(nn.Module): cached_get_processor = lru_cache(get_processor) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): +class Qwen2VLForConditionalGeneration(nn.Module): def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]): processor = cached_get_processor(self.config._name_or_path) grid_t, grid_h, grid_w = image_grid_thw @@ -536,15 +535,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__( self, config: Qwen2VLConfig, - multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.multimodal_config = multimodal_config - self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 2be3a298e..0c3ae0c5a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ 
-204,56 +204,6 @@ def is_port_available(port): return False -def is_multimodal_model(model_architectures): - if ( - "LlavaLlamaForCausalLM" in model_architectures - or "LlavaQwenForCausalLM" in model_architectures - or "LlavaMistralForCausalLM" in model_architectures - or "LlavaVidForCausalLM" in model_architectures - or "MllamaForConditionalGeneration" in model_architectures - or "Qwen2VLForConditionalGeneration" in model_architectures - ): - return True - else: - return False - - -def is_attention_free_model(model_architectures): - return False - - -def model_has_inner_state(model_architectures): - return False - - -def is_embedding_model(model_architectures): - if ( - "LlamaEmbeddingModel" in model_architectures - or "MistralModel" in model_architectures - or "LlamaForSequenceClassification" in model_architectures - or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures - ): - return True - else: - return False - - -def is_generation_model(model_architectures, is_embedding: bool = False): - # We have two ways to determine whether a model is a generative model. - # 1. Check the model architectue - # 2. check the `is_embedding` server args - - if ( - "LlamaEmbeddingModel" in model_architectures - or "MistralModel" in model_architectures - or "LlamaForSequenceClassification" in model_architectures - or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures - ): - return False - else: - return not is_embedding - - def decode_video_base64(video_base64): from PIL import Image diff --git a/test/srt/test_cache_report.py b/test/srt/test_cache_report.py index dfc140d58..b790c3ae6 100644 --- a/test/srt/test_cache_report.py +++ b/test/srt/test_cache_report.py @@ -1,5 +1,4 @@ import asyncio -import json import unittest import openai
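
Editor's note on usage (not part of the patch): the changes above replace the scattered architecture checks in sglang.srt.utils (is_multimodal_model, is_generation_model, is_embedding_model, ...) with flags computed once inside ModelConfig, which the Scheduler, TokenizerManager, TpModelWorker, and ModelRunner then read. The snippet below is a minimal sketch of that new interface under stated assumptions: the model path is only a placeholder borrowed from the quick-start example at the top of the diff, and it assumes an environment where sglang (at this revision) is installed and the Hugging Face config for that path can be loaded.

    # Sketch only: illustrates the unified type checks added in
    # python/sglang/srt/configs/model_config.py by this patch.
    from sglang.srt.configs.model_config import ModelConfig

    model_config = ModelConfig(
        "lmms-lab/llama3-llava-next-8b",  # placeholder; substitute any model you actually serve
        trust_remote_code=True,           # mirrors --trust-remote-code; not required by every model
        context_length=None,              # fall back to the length derived from the HF config
        model_override_args="{}",         # now a JSON string; ModelConfig calls json.loads() itself
        is_embedding=False,               # mirrors server_args.is_embedding
    )

    # Components no longer call the removed helpers in sglang.srt.utils;
    # they read these precomputed flags instead.
    print(model_config.is_generation)       # False for embedding/classification architectures (or is_embedding=True)
    print(model_config.is_multimodal)       # True for LLaVA, Mllama, and Qwen2-VL architectures
    print(model_config.is_encoder_decoder)  # True only for MllamaForConditionalGeneration
    print(model_config.context_len)

Because the JSON parsing now happens inside ModelConfig, callers such as bench_latency.py, Scheduler, and TpModelWorker pass server_args.json_model_override_args through unchanged instead of calling json.loads themselves.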