Unify the model type checking (#1905)
This commit is contained in:
@@ -50,7 +50,7 @@ if __name__ == "__main__":
|
|||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b")
|
runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b")
|
||||||
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
|
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct-llava")
|
||||||
|
|
||||||
# Or you can use the 72B model
|
# Or you can use the 72B model
|
||||||
# runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8)
|
# runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8)
|
||||||
|
|||||||
@@ -129,9 +129,9 @@ def load_model(server_args, port_args, tp_rank):
|
|||||||
|
|
||||||
model_config = ModelConfig(
|
model_config = ModelConfig(
|
||||||
server_args.model_path,
|
server_args.model_path,
|
||||||
server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
context_length=server_args.context_length,
|
context_length=server_args.context_length,
|
||||||
model_override_args=json.loads(server_args.json_model_override_args),
|
model_override_args=server_args.json_model_override_args,
|
||||||
)
|
)
|
||||||
model_runner = ModelRunner(
|
model_runner = ModelRunner(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
|||||||
@@ -116,6 +116,23 @@ register_chat_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_chat_template(
|
||||||
|
ChatTemplate(
|
||||||
|
name="chatml-llava",
|
||||||
|
default_system_prompt="You are a helpful assistant.",
|
||||||
|
role_prefix_and_suffix={
|
||||||
|
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
||||||
|
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
||||||
|
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
||||||
|
},
|
||||||
|
style=ChatTemplateStyle.PLAIN,
|
||||||
|
stop_str=("<|im_end|>",),
|
||||||
|
image_token="<image>\n",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# There is default system prompt for qwen
|
# There is default system prompt for qwen
|
||||||
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
|
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
|
||||||
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||||
@@ -149,22 +166,6 @@ register_chat_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_chat_template(
|
|
||||||
ChatTemplate(
|
|
||||||
name="chatml-llava",
|
|
||||||
default_system_prompt="You are a helpful assistant.",
|
|
||||||
role_prefix_and_suffix={
|
|
||||||
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
|
||||||
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
|
||||||
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
|
||||||
},
|
|
||||||
style=ChatTemplateStyle.PLAIN,
|
|
||||||
stop_str=("<|im_end|>",),
|
|
||||||
image_token="<image>\n",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
|
# Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
|
||||||
register_chat_template(
|
register_chat_template(
|
||||||
ChatTemplate(
|
ChatTemplate(
|
||||||
@@ -182,21 +183,6 @@ register_chat_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
|
|
||||||
register_chat_template(
|
|
||||||
ChatTemplate(
|
|
||||||
name="yi-1.5",
|
|
||||||
default_system_prompt=None,
|
|
||||||
role_prefix_and_suffix={
|
|
||||||
"system": ("", ""),
|
|
||||||
"user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
|
|
||||||
"assistant": ("", "<|im_end|>\n"),
|
|
||||||
},
|
|
||||||
style=ChatTemplateStyle.PLAIN,
|
|
||||||
stop_str=("<|im_end|>",),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
register_chat_template(
|
register_chat_template(
|
||||||
ChatTemplate(
|
ChatTemplate(
|
||||||
name="llama-2-chat",
|
name="llama-2-chat",
|
||||||
@@ -233,6 +219,45 @@ register_chat_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
|
||||||
|
register_chat_template(
|
||||||
|
ChatTemplate(
|
||||||
|
name="llama-3-instruct-llava",
|
||||||
|
default_system_prompt=None,
|
||||||
|
role_prefix_and_suffix={
|
||||||
|
"system": (
|
||||||
|
"<|start_header_id|>system<|end_header_id|>\n\n",
|
||||||
|
"<|eot_id|>",
|
||||||
|
),
|
||||||
|
"user": (
|
||||||
|
"<|start_header_id|>user<|end_header_id|>\n\n",
|
||||||
|
"<|eot_id|>",
|
||||||
|
),
|
||||||
|
"assistant": (
|
||||||
|
"<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||||
|
"<|eot_id|>",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
stop_str=("<|eot_id|>",),
|
||||||
|
image_token="<image>\n",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
|
||||||
|
register_chat_template(
|
||||||
|
ChatTemplate(
|
||||||
|
name="yi-1.5",
|
||||||
|
default_system_prompt=None,
|
||||||
|
role_prefix_and_suffix={
|
||||||
|
"system": ("", ""),
|
||||||
|
"user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
|
||||||
|
"assistant": ("", "<|im_end|>\n"),
|
||||||
|
},
|
||||||
|
style=ChatTemplateStyle.PLAIN,
|
||||||
|
stop_str=("<|im_end|>",),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
|
# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
|
||||||
register_chat_template(
|
register_chat_template(
|
||||||
ChatTemplate(
|
ChatTemplate(
|
||||||
|
|||||||
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
@@ -38,18 +39,24 @@ class ModelConfig:
|
|||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
context_length: Optional[int] = None,
|
context_length: Optional[int] = None,
|
||||||
model_override_args: Optional[dict] = None,
|
model_override_args: Optional[dict] = None,
|
||||||
|
is_embedding: Optional[bool] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
self.path = path
|
# Parse args
|
||||||
self.trust_remote_code = trust_remote_code
|
self.model_override_args = json.loads(model_override_args)
|
||||||
self.revision = revision
|
|
||||||
self.model_override_args = model_override_args
|
|
||||||
self.hf_config = get_config(
|
self.hf_config = get_config(
|
||||||
self.path,
|
path,
|
||||||
trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
revision,
|
revision=revision,
|
||||||
model_override_args=model_override_args,
|
model_override_args=self.model_override_args,
|
||||||
)
|
)
|
||||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
|
|
||||||
|
# Check model type
|
||||||
|
self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
|
||||||
|
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
|
||||||
|
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||||
|
|
||||||
|
# Derive context length
|
||||||
derived_context_len = get_context_length(self.hf_text_config)
|
derived_context_len = get_context_length(self.hf_text_config)
|
||||||
allow_long_context = os.environ.get(
|
allow_long_context = os.environ.get(
|
||||||
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
|
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
|
||||||
@@ -81,7 +88,7 @@ class ModelConfig:
|
|||||||
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
|
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME: temporary special judge for deepseek v2 MLA architecture
|
# FIXME: temporary special judge for MLA architecture
|
||||||
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
|
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
|
||||||
self.head_dim = 256
|
self.head_dim = 256
|
||||||
self.attention_arch = AttentionArch.MLA
|
self.attention_arch = AttentionArch.MLA
|
||||||
@@ -112,8 +119,6 @@ class ModelConfig:
|
|||||||
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
|
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
|
||||||
self.vocab_size = self.hf_text_config.vocab_size
|
self.vocab_size = self.hf_text_config.vocab_size
|
||||||
|
|
||||||
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
|
|
||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
||||||
def get_total_num_kv_heads(self) -> int:
|
def get_total_num_kv_heads(self) -> int:
|
||||||
"""Returns the total number of KV heads."""
|
"""Returns the total number of KV heads."""
|
||||||
@@ -163,7 +168,6 @@ class ModelConfig:
|
|||||||
# equal to the number of attention heads.
|
# equal to the number of attention heads.
|
||||||
return self.hf_text_config.num_attention_heads
|
return self.hf_text_config.num_attention_heads
|
||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
|
|
||||||
def get_num_kv_heads(self, tensor_parallel_size) -> int:
|
def get_num_kv_heads(self, tensor_parallel_size) -> int:
|
||||||
"""Returns the number of KV heads per GPU."""
|
"""Returns the number of KV heads per GPU."""
|
||||||
total_num_kv_heads = self.get_total_num_kv_heads()
|
total_num_kv_heads = self.get_total_num_kv_heads()
|
||||||
@@ -192,3 +196,37 @@ def get_hf_text_config(config: PretrainedConfig):
|
|||||||
return config.text_config
|
return config.text_config
|
||||||
else:
|
else:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
|
||||||
|
# We have two ways to determine whether a model is a generative model.
|
||||||
|
# 1. Check the model architectue
|
||||||
|
# 2. check the `is_embedding` server args
|
||||||
|
|
||||||
|
if (
|
||||||
|
"LlamaEmbeddingModel" in model_architectures
|
||||||
|
or "MistralModel" in model_architectures
|
||||||
|
or "LlamaForSequenceClassification" in model_architectures
|
||||||
|
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return not is_embedding
|
||||||
|
|
||||||
|
|
||||||
|
def is_multimodal_model(model_architectures: List[str]):
|
||||||
|
if (
|
||||||
|
"LlavaLlamaForCausalLM" in model_architectures
|
||||||
|
or "LlavaQwenForCausalLM" in model_architectures
|
||||||
|
or "LlavaMistralForCausalLM" in model_architectures
|
||||||
|
or "LlavaVidForCausalLM" in model_architectures
|
||||||
|
or "MllamaForConditionalGeneration" in model_architectures
|
||||||
|
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_encoder_decoder_model(model_architectures: List[str]):
|
||||||
|
return "MllamaForConditionalGeneration" in model_architectures
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ class LlavaImageProcessor(BaseImageProcessor):
|
|||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"image_hashes": image_hashes,
|
"image_hashes": image_hashes,
|
||||||
"image_sizes": image_sizes,
|
"image_sizes": image_sizes,
|
||||||
"modalities": request_obj.modalities,
|
"modalities": request_obj.modalities or ["image"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ limitations under the License.
|
|||||||
|
|
||||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
@@ -23,7 +22,7 @@ import time
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import zmq
|
import zmq
|
||||||
@@ -68,8 +67,6 @@ from sglang.srt.utils import (
|
|||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
is_generation_model,
|
|
||||||
is_multimodal_model,
|
|
||||||
kill_parent_process,
|
kill_parent_process,
|
||||||
set_random_seed,
|
set_random_seed,
|
||||||
suppress_other_loggers,
|
suppress_other_loggers,
|
||||||
@@ -133,15 +130,17 @@ class Scheduler:
|
|||||||
# Init tokenizer
|
# Init tokenizer
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig(
|
||||||
server_args.model_path,
|
server_args.model_path,
|
||||||
server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
context_length=server_args.context_length,
|
context_length=server_args.context_length,
|
||||||
model_override_args=json.loads(server_args.json_model_override_args),
|
model_override_args=server_args.json_model_override_args,
|
||||||
|
is_embedding=server_args.is_embedding,
|
||||||
)
|
)
|
||||||
|
self.is_generation = self.model_config.is_generation
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
self.tokenizer = self.processor = None
|
self.tokenizer = self.processor = None
|
||||||
else:
|
else:
|
||||||
if is_multimodal_model(self.model_config.hf_config.architectures):
|
if self.model_config.is_multimodal:
|
||||||
self.processor = get_processor(
|
self.processor = get_processor(
|
||||||
server_args.tokenizer_path,
|
server_args.tokenizer_path,
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
@@ -154,9 +153,6 @@ class Scheduler:
|
|||||||
tokenizer_mode=server_args.tokenizer_mode,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
)
|
)
|
||||||
self.is_generation = is_generation_model(
|
|
||||||
self.model_config.hf_config.architectures, self.server_args.is_embedding
|
|
||||||
)
|
|
||||||
|
|
||||||
# Launch a tensor parallel worker
|
# Launch a tensor parallel worker
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ limitations under the License.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
@@ -31,12 +30,8 @@ import zmq
|
|||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
from fastapi import BackgroundTasks
|
from fastapi import BackgroundTasks
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import (
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
get_config,
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
get_context_length,
|
|
||||||
get_processor,
|
|
||||||
get_tokenizer,
|
|
||||||
)
|
|
||||||
from sglang.srt.managers.image_processor import (
|
from sglang.srt.managers.image_processor import (
|
||||||
get_dummy_image_processor,
|
get_dummy_image_processor,
|
||||||
get_image_processor,
|
get_image_processor,
|
||||||
@@ -59,12 +54,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import get_zmq_socket, kill_child_process
|
||||||
get_zmq_socket,
|
|
||||||
is_generation_model,
|
|
||||||
is_multimodal_model,
|
|
||||||
kill_child_process,
|
|
||||||
)
|
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
@@ -103,18 +93,17 @@ class TokenizerManager:
|
|||||||
# Read model args
|
# Read model args
|
||||||
self.model_path = server_args.model_path
|
self.model_path = server_args.model_path
|
||||||
self.served_model_name = server_args.served_model_name
|
self.served_model_name = server_args.served_model_name
|
||||||
self.hf_config = get_config(
|
self.model_config = ModelConfig(
|
||||||
self.model_path,
|
server_args.model_path,
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
model_override_args=json.loads(server_args.json_model_override_args),
|
context_length=server_args.context_length,
|
||||||
)
|
model_override_args=server_args.json_model_override_args,
|
||||||
self.is_generation = is_generation_model(
|
is_embedding=server_args.is_embedding,
|
||||||
self.hf_config.architectures, self.server_args.is_embedding
|
|
||||||
)
|
|
||||||
self.context_len = server_args.context_length or get_context_length(
|
|
||||||
self.hf_config
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.is_generation = self.model_config.is_generation
|
||||||
|
self.context_len = self.model_config.context_len
|
||||||
|
|
||||||
# Create image processor placeholder
|
# Create image processor placeholder
|
||||||
self.image_processor = get_dummy_image_processor()
|
self.image_processor = get_dummy_image_processor()
|
||||||
|
|
||||||
@@ -122,7 +111,7 @@ class TokenizerManager:
|
|||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
self.tokenizer = self.processor = None
|
self.tokenizer = self.processor = None
|
||||||
else:
|
else:
|
||||||
if is_multimodal_model(self.hf_config.architectures):
|
if self.model_config.is_multimodal:
|
||||||
self.processor = get_processor(
|
self.processor = get_processor(
|
||||||
server_args.tokenizer_path,
|
server_args.tokenizer_path,
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
@@ -133,7 +122,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
# We want to parallelize the image pre-processing so we create an executor for it
|
# We want to parallelize the image pre-processing so we create an executor for it
|
||||||
self.image_processor = get_image_processor(
|
self.image_processor = get_image_processor(
|
||||||
self.hf_config, server_args, self.processor
|
self.model_config.hf_config, server_args, self.processor
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.tokenizer = get_tokenizer(
|
self.tokenizer = get_tokenizer(
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ limitations under the License.
|
|||||||
|
|
||||||
"""A tensor parallel worker."""
|
"""A tensor parallel worker."""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -26,7 +25,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
|
from sglang.srt.utils import broadcast_pyobj, set_random_seed
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -48,9 +47,10 @@ class TpModelWorker:
|
|||||||
# Init model and tokenizer
|
# Init model and tokenizer
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig(
|
||||||
server_args.model_path,
|
server_args.model_path,
|
||||||
server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
context_length=server_args.context_length,
|
context_length=server_args.context_length,
|
||||||
model_override_args=json.loads(server_args.json_model_override_args),
|
model_override_args=server_args.json_model_override_args,
|
||||||
|
is_embedding=server_args.is_embedding,
|
||||||
)
|
)
|
||||||
self.model_runner = ModelRunner(
|
self.model_runner = ModelRunner(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
@@ -64,7 +64,7 @@ class TpModelWorker:
|
|||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
self.tokenizer = self.processor = None
|
self.tokenizer = self.processor = None
|
||||||
else:
|
else:
|
||||||
if is_multimodal_model(self.model_config.hf_config.architectures):
|
if self.model_config.is_multimodal:
|
||||||
self.processor = get_processor(
|
self.processor = get_processor(
|
||||||
server_args.tokenizer_path,
|
server_args.tokenizer_path,
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
|
|||||||
@@ -59,11 +59,6 @@ from sglang.srt.server_args import ServerArgs
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
is_attention_free_model,
|
|
||||||
is_embedding_model,
|
|
||||||
is_generation_model,
|
|
||||||
is_multimodal_model,
|
|
||||||
model_has_inner_state,
|
|
||||||
monkey_patch_vllm_dummy_weight_loader,
|
monkey_patch_vllm_dummy_weight_loader,
|
||||||
monkey_patch_vllm_p2p_access_check,
|
monkey_patch_vllm_p2p_access_check,
|
||||||
)
|
)
|
||||||
@@ -93,9 +88,8 @@ class ModelRunner:
|
|||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.dist_port = nccl_port
|
self.dist_port = nccl_port
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.is_multimodal_model = is_multimodal_model(
|
self.is_generation = model_config.is_generation
|
||||||
self.model_config.hf_config.architectures
|
self.is_multimodal = model_config.is_multimodal
|
||||||
)
|
|
||||||
|
|
||||||
# Model-specific adjustment
|
# Model-specific adjustment
|
||||||
if (
|
if (
|
||||||
@@ -119,7 +113,7 @@ class ModelRunner:
|
|||||||
self.server_args.ds_heavy_channel_type
|
self.server_args.ds_heavy_channel_type
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
|
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
|
||||||
)
|
)
|
||||||
@@ -270,9 +264,6 @@ class ModelRunner:
|
|||||||
if hasattr(self.model, "get_attention_sliding_window_size")
|
if hasattr(self.model, "get_attention_sliding_window_size")
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.is_generation = is_generation_model(
|
|
||||||
self.model_config.hf_config.architectures, self.server_args.is_embedding
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Load weight end. "
|
f"Load weight end. "
|
||||||
@@ -679,7 +670,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
|
|||||||
|
|
||||||
# Monkey patch model loader
|
# Monkey patch model loader
|
||||||
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
|
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
|
||||||
setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
|
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
|
||||||
setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
|
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
|
||||||
setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
|
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
|
||||||
setattr(ModelRegistry, "is_embedding_model", is_embedding_model)
|
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
|
||||||
|
|||||||
@@ -409,11 +409,13 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
if (
|
if (
|
||||||
hasattr(self.config, "tie_word_embeddings")
|
hasattr(self.config, "tie_word_embeddings")
|
||||||
and self.config.tie_word_embeddings
|
and self.config.tie_word_embeddings
|
||||||
|
and "lm_head.weight" in params_dict
|
||||||
):
|
):
|
||||||
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
|
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
|
||||||
param = self.lm_head.weight
|
param = self.lm_head.weight
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, self.model.embed_tokens.weight)
|
weight_loader(param, self.model.embed_tokens.weight)
|
||||||
|
|
||||||
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||||
from functools import lru_cache, partial
|
from functools import lru_cache, partial
|
||||||
from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
|
from typing import Iterable, List, Optional, Tuple, Type, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -36,7 +36,6 @@ from vllm.distributed import utils as dist_utils
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import QuickGELU
|
from vllm.model_executor.layers.activation import QuickGELU
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
|
||||||
|
|
||||||
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
||||||
from sglang.srt.hf_transformers_utils import get_processor
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
@@ -486,7 +485,7 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
cached_get_processor = lru_cache(get_processor)
|
cached_get_processor = lru_cache(get_processor)
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
class Qwen2VLForConditionalGeneration(nn.Module):
|
||||||
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
|
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
|
||||||
processor = cached_get_processor(self.config._name_or_path)
|
processor = cached_get_processor(self.config._name_or_path)
|
||||||
grid_t, grid_h, grid_w = image_grid_thw
|
grid_t, grid_h, grid_w = image_grid_thw
|
||||||
@@ -536,15 +535,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Qwen2VLConfig,
|
config: Qwen2VLConfig,
|
||||||
multimodal_config: MultiModalConfig,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.multimodal_config = multimodal_config
|
|
||||||
|
|
||||||
self.visual = Qwen2VisionTransformer(
|
self.visual = Qwen2VisionTransformer(
|
||||||
config.vision_config,
|
config.vision_config,
|
||||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||||
|
|||||||
@@ -204,56 +204,6 @@ def is_port_available(port):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_multimodal_model(model_architectures):
|
|
||||||
if (
|
|
||||||
"LlavaLlamaForCausalLM" in model_architectures
|
|
||||||
or "LlavaQwenForCausalLM" in model_architectures
|
|
||||||
or "LlavaMistralForCausalLM" in model_architectures
|
|
||||||
or "LlavaVidForCausalLM" in model_architectures
|
|
||||||
or "MllamaForConditionalGeneration" in model_architectures
|
|
||||||
or "Qwen2VLForConditionalGeneration" in model_architectures
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_attention_free_model(model_architectures):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def model_has_inner_state(model_architectures):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_embedding_model(model_architectures):
|
|
||||||
if (
|
|
||||||
"LlamaEmbeddingModel" in model_architectures
|
|
||||||
or "MistralModel" in model_architectures
|
|
||||||
or "LlamaForSequenceClassification" in model_architectures
|
|
||||||
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_generation_model(model_architectures, is_embedding: bool = False):
|
|
||||||
# We have two ways to determine whether a model is a generative model.
|
|
||||||
# 1. Check the model architectue
|
|
||||||
# 2. check the `is_embedding` server args
|
|
||||||
|
|
||||||
if (
|
|
||||||
"LlamaEmbeddingModel" in model_architectures
|
|
||||||
or "MistralModel" in model_architectures
|
|
||||||
or "LlamaForSequenceClassification" in model_architectures
|
|
||||||
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return not is_embedding
|
|
||||||
|
|
||||||
|
|
||||||
def decode_video_base64(video_base64):
|
def decode_video_base64(video_base64):
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
|||||||
Reference in New Issue
Block a user