Unify the model type checking (#1905)

This commit is contained in:
Lianmin Zheng
2024-11-03 12:25:39 -08:00
committed by GitHub
parent c17c578108
commit 0abbf289a8
13 changed files with 146 additions and 160 deletions

View File

@@ -50,7 +50,7 @@ if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b")
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct-llava")
# Or you can use the 72B model
# runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8)

View File

@@ -129,9 +129,9 @@ def load_model(server_args, port_args, tp_rank):
model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
trust_remote_code=server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=json.loads(server_args.json_model_override_args),
model_override_args=server_args.json_model_override_args,
)
model_runner = ModelRunner(
model_config=model_config,

View File

@@ -116,6 +116,23 @@ register_chat_template(
)
)
register_chat_template(
ChatTemplate(
name="chatml-llava",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token="<image>\n",
)
)
# There is a default system prompt for Qwen.
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
@@ -149,22 +166,6 @@ register_chat_template(
)
)
register_chat_template(
ChatTemplate(
name="chatml-llava",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token="<image>\n",
)
)
# Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
register_chat_template(
ChatTemplate(
@@ -182,21 +183,6 @@ register_chat_template(
)
)
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
register_chat_template(
ChatTemplate(
name="yi-1.5",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
"assistant": ("", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
)
)
register_chat_template(
ChatTemplate(
name="llama-2-chat",
@@ -233,6 +219,45 @@ register_chat_template(
)
)
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
ChatTemplate(
name="llama-3-instruct-llava",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"<|start_header_id|>system<|end_header_id|>\n\n",
"<|eot_id|>",
),
"user": (
"<|start_header_id|>user<|end_header_id|>\n\n",
"<|eot_id|>",
),
"assistant": (
"<|start_header_id|>assistant<|end_header_id|>\n\n",
"<|eot_id|>",
),
},
stop_str=("<|eot_id|>",),
image_token="<image>\n",
)
)
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
register_chat_template(
ChatTemplate(
name="yi-1.5",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
"assistant": ("", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
)
)
# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
register_chat_template(
ChatTemplate(

View File

@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import os
from enum import IntEnum, auto
from typing import Optional
from typing import List, Optional
from transformers import PretrainedConfig
@@ -38,18 +39,24 @@ class ModelConfig:
revision: Optional[str] = None,
context_length: Optional[int] = None,
model_override_args: Optional[dict] = None,
is_embedding: Optional[bool] = None
) -> None:
self.path = path
self.trust_remote_code = trust_remote_code
self.revision = revision
self.model_override_args = model_override_args
# Parse args
self.model_override_args = json.loads(model_override_args)
self.hf_config = get_config(
self.path,
trust_remote_code,
revision,
model_override_args=model_override_args,
path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
# Check model type
self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
# Derive context length
derived_context_len = get_context_length(self.hf_text_config)
allow_long_context = os.environ.get(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
@@ -81,7 +88,7 @@ class ModelConfig:
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
)
# FIXME: temporary special judge for deepseek v2 MLA architecture
# FIXME: temporary special judge for MLA architecture
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
@@ -112,8 +119,6 @@ class ModelConfig:
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
@@ -163,7 +168,6 @@ class ModelConfig:
# equal to the number of attention heads.
return self.hf_text_config.num_attention_heads
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
def get_num_kv_heads(self, tensor_parallel_size) -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads = self.get_total_num_kv_heads()
@@ -192,3 +196,37 @@ def get_hf_text_config(config: PretrainedConfig):
return config.text_config
else:
return config
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    """Return True if the model should be served as a generative model.

    We have two ways to determine whether a model is a generative model:
    1. Check the model architecture.
    2. Check the `is_embedding` server argument.
    """
    # Architectures that are embedding/classification models and therefore
    # never generative, regardless of the server flag.
    embedding_architectures = {
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
    }
    if not embedding_architectures.isdisjoint(model_architectures):
        return False
    # Otherwise fall back to the explicit `--is-embedding` server argument.
    return not is_embedding
def is_multimodal_model(model_architectures: List[str]):
    """Return True if any of the architectures is a known multimodal model."""
    multimodal_architectures = (
        "LlavaLlamaForCausalLM",
        "LlavaQwenForCausalLM",
        "LlavaMistralForCausalLM",
        "LlavaVidForCausalLM",
        "MllamaForConditionalGeneration",
        "Qwen2VLForConditionalGeneration",
    )
    return any(arch in model_architectures for arch in multimodal_architectures)
def is_encoder_decoder_model(model_architectures: List[str]):
    """Return True if the model uses an encoder-decoder architecture."""
    # Mllama is currently the only encoder-decoder architecture supported.
    encoder_decoder_arch = "MllamaForConditionalGeneration"
    return encoder_decoder_arch in model_architectures

View File

@@ -180,7 +180,7 @@ class LlavaImageProcessor(BaseImageProcessor):
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities,
"modalities": request_obj.modalities or ["image"],
}

View File

@@ -15,7 +15,6 @@ limitations under the License.
"""A scheduler that manages a tensor parallel GPU worker."""
import json
import logging
import os
import threading
@@ -23,7 +22,7 @@ import time
import warnings
from collections import deque
from types import SimpleNamespace
from typing import List, Optional, Union
from typing import List, Optional
import torch
import zmq
@@ -68,8 +67,6 @@ from sglang.srt.utils import (
broadcast_pyobj,
configure_logger,
get_zmq_socket,
is_generation_model,
is_multimodal_model,
kill_parent_process,
set_random_seed,
suppress_other_loggers,
@@ -133,15 +130,17 @@ class Scheduler:
# Init tokenizer
self.model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
trust_remote_code=server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=json.loads(server_args.json_model_override_args),
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
)
self.is_generation = self.model_config.is_generation
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if is_multimodal_model(self.model_config.hf_config.architectures):
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
@@ -154,9 +153,6 @@ class Scheduler:
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
# Launch a tensor parallel worker
if self.enable_overlap:

View File

@@ -18,7 +18,6 @@ limitations under the License.
import asyncio
import copy
import dataclasses
import json
import logging
import os
import signal
@@ -31,12 +30,8 @@ import zmq
import zmq.asyncio
from fastapi import BackgroundTasks
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_processor,
get_tokenizer,
)
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
get_dummy_image_processor,
get_image_processor,
@@ -59,12 +54,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_zmq_socket,
is_generation_model,
is_multimodal_model,
kill_child_process,
)
from sglang.srt.utils import get_zmq_socket, kill_child_process
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -103,18 +93,17 @@ class TokenizerManager:
# Read model args
self.model_path = server_args.model_path
self.served_model_name = server_args.served_model_name
self.hf_config = get_config(
self.model_path,
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
model_override_args=json.loads(server_args.json_model_override_args),
)
self.is_generation = is_generation_model(
self.hf_config.architectures, self.server_args.is_embedding
)
self.context_len = server_args.context_length or get_context_length(
self.hf_config
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
)
self.is_generation = self.model_config.is_generation
self.context_len = self.model_config.context_len
# Create image processor placeholder
self.image_processor = get_dummy_image_processor()
@@ -122,7 +111,7 @@ class TokenizerManager:
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if is_multimodal_model(self.hf_config.architectures):
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
@@ -133,7 +122,7 @@ class TokenizerManager:
# We want to parallelize the image pre-processing so we create an executor for it
self.image_processor = get_image_processor(
self.hf_config, server_args, self.processor
self.model_config.hf_config, server_args, self.processor
)
else:
self.tokenizer = get_tokenizer(

View File

@@ -15,7 +15,6 @@ limitations under the License.
"""A tensor parallel worker."""
import json
import logging
from typing import Optional
@@ -26,7 +25,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
from sglang.srt.utils import broadcast_pyobj, set_random_seed
logger = logging.getLogger(__name__)
@@ -48,9 +47,10 @@ class TpModelWorker:
# Init model and tokenizer
self.model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
trust_remote_code=server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=json.loads(server_args.json_model_override_args),
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
)
self.model_runner = ModelRunner(
model_config=self.model_config,
@@ -64,7 +64,7 @@ class TpModelWorker:
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if is_multimodal_model(self.model_config.hf_config.architectures):
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,

View File

@@ -59,11 +59,6 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
enable_show_time_cost,
get_available_gpu_memory,
is_attention_free_model,
is_embedding_model,
is_generation_model,
is_multimodal_model,
model_has_inner_state,
monkey_patch_vllm_dummy_weight_loader,
monkey_patch_vllm_p2p_access_check,
)
@@ -93,9 +88,8 @@ class ModelRunner:
self.tp_size = tp_size
self.dist_port = nccl_port
self.server_args = server_args
self.is_multimodal_model = is_multimodal_model(
self.model_config.hf_config.architectures
)
self.is_generation = model_config.is_generation
self.is_multimodal = model_config.is_multimodal
# Model-specific adjustment
if (
@@ -119,7 +113,7 @@ class ModelRunner:
self.server_args.ds_heavy_channel_type
)
if self.is_multimodal_model:
if self.is_multimodal:
logger.warning(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
)
@@ -270,9 +264,6 @@ class ModelRunner:
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
logger.info(
f"Load weight end. "
@@ -679,7 +670,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
setattr(ModelRegistry, "is_embedding_model", is_embedding_model)
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)

View File

@@ -409,11 +409,13 @@ class LlamaForCausalLM(nn.Module):
if (
hasattr(self.config, "tie_word_embeddings")
and self.config.tie_word_embeddings
and "lm_head.weight" in params_dict
):
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, self.model.embed_tokens.weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))

View File

@@ -23,7 +23,7 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
from typing import Iterable, List, Optional, Tuple, Type, TypedDict
import numpy as np
import torch
@@ -36,7 +36,6 @@ from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.hf_transformers_utils import get_processor
@@ -486,7 +485,7 @@ class Qwen2VisionTransformer(nn.Module):
cached_get_processor = lru_cache(get_processor)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
class Qwen2VLForConditionalGeneration(nn.Module):
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
processor = cached_get_processor(self.config._name_or_path)
grid_t, grid_h, grid_w = image_grid_thw
@@ -536,15 +535,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(
self,
config: Qwen2VLConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),

View File

@@ -204,56 +204,6 @@ def is_port_available(port):
return False
def is_multimodal_model(model_architectures):
    """Return True if any of the architectures is a known multimodal model."""
    multimodal_architectures = (
        "LlavaLlamaForCausalLM",
        "LlavaQwenForCausalLM",
        "LlavaMistralForCausalLM",
        "LlavaVidForCausalLM",
        "MllamaForConditionalGeneration",
        "Qwen2VLForConditionalGeneration",
    )
    return any(arch in model_architectures for arch in multimodal_architectures)
def is_attention_free_model(model_architectures):
    # Stub for the vLLM ModelRegistry hook: sglang currently supports no
    # attention-free architectures, so this always reports False.
    return False
def model_has_inner_state(model_architectures):
    # Stub for the vLLM ModelRegistry hook: none of the supported
    # architectures carries recurrent inner state, so always False.
    return False
def is_embedding_model(model_architectures):
    """Return True if any of the architectures is an embedding/classification model."""
    embedding_architectures = (
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
    )
    return any(arch in model_architectures for arch in embedding_architectures)
def is_generation_model(model_architectures, is_embedding: bool = False):
    """Return True if the model should be served as a generative model.

    We have two ways to determine whether a model is a generative model:
    1. Check the model architecture.
    2. Check the `is_embedding` server argument.
    """
    # Architectures that are embedding/classification models and therefore
    # never generative, regardless of the server flag.
    embedding_architectures = {
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
    }
    if not embedding_architectures.isdisjoint(model_architectures):
        return False
    # Otherwise fall back to the explicit `--is-embedding` server argument.
    return not is_embedding
def decode_video_base64(video_base64):
from PIL import Image

View File

@@ -1,5 +1,4 @@
import asyncio
import json
import unittest
import openai