Unify the model type checking (#1905)
@@ -129,9 +129,9 @@ def load_model(server_args, port_args, tp_rank):
     model_config = ModelConfig(
         server_args.model_path,
-        server_args.trust_remote_code,
+        trust_remote_code=server_args.trust_remote_code,
         context_length=server_args.context_length,
-        model_override_args=json.loads(server_args.json_model_override_args),
+        model_override_args=server_args.json_model_override_args,
     )
     model_runner = ModelRunner(
         model_config=model_config,

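Note: this pattern repeats throughout the commit: call sites now pass keyword arguments and hand ModelConfig the raw --json-model-override-args string instead of json.loads-ing it themselves. A minimal sketch of the unified call shape; the argument values below are invented for illustration:

    from sglang.srt.configs.model_config import ModelConfig

    model_config = ModelConfig(
        "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model path
        trust_remote_code=False,
        context_length=None,                 # None -> derived from the HF config
        model_override_args="{}",            # raw JSON string, parsed inside ModelConfig
        is_embedding=False,
    )
    # The type checks are computed once and read back as attributes:
    print(model_config.is_generation, model_config.is_multimodal, model_config.is_encoder_decoder)
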
@@ -116,6 +116,23 @@ register_chat_template(
     )
 )
 
+register_chat_template(
+    ChatTemplate(
+        name="chatml-llava",
+        default_system_prompt="You are a helpful assistant.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
+        image_token="<image>\n",
+    )
+)
+
 # There is default system prompt for qwen
 # reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
 # The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
@@ -149,22 +166,6 @@ register_chat_template(
     )
 )
 
-register_chat_template(
-    ChatTemplate(
-        name="chatml-llava",
-        default_system_prompt="You are a helpful assistant.",
-        role_prefix_and_suffix={
-            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
-            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
-            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
-        },
-        style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>",),
-        image_token="<image>\n",
-    )
-)
-
 # Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
 register_chat_template(
     ChatTemplate(
@@ -182,21 +183,6 @@ register_chat_template(
     )
 )
 
-# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
-register_chat_template(
-    ChatTemplate(
-        name="yi-1.5",
-        default_system_prompt=None,
-        role_prefix_and_suffix={
-            "system": ("", ""),
-            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
-            "assistant": ("", "<|im_end|>\n"),
-        },
-        style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>",),
-    )
-)
-
 register_chat_template(
     ChatTemplate(
         name="llama-2-chat",
@@ -233,6 +219,45 @@ register_chat_template(
     )
 )
 
+# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
+register_chat_template(
+    ChatTemplate(
+        name="llama-3-instruct-llava",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|start_header_id|>system<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "user": (
+                "<|start_header_id|>user<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "assistant": (
+                "<|start_header_id|>assistant<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+        },
+        stop_str=("<|eot_id|>",),
+        image_token="<image>\n",
+    )
+)
+
+# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
+register_chat_template(
+    ChatTemplate(
+        name="yi-1.5",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", ""),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
+            "assistant": ("", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
+    )
+)
+
 # Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
 register_chat_template(
     ChatTemplate(

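Note: the chat-template hunks above are mostly a relocation (the chatml-llava and yi-1.5 registrations move to a later position in the file); the only genuinely new entry is llama-3-instruct-llava, which differs from llama-3-instruct only in its image_token. For readers unfamiliar with ChatTemplateStyle.PLAIN, here is a rough standalone sketch of how role prefix/suffix pairs turn messages into a prompt; this is an illustration, not the library's actual renderer:

    ROLE_PREFIX_AND_SUFFIX = {
        "system": ("<|im_start|>system\n", "<|im_end|>\n"),
        "user": ("<|im_start|>user\n", "<|im_end|>\n"),
        "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
    }

    def render_plain(messages):
        # Wrap each message in its role's prefix/suffix, then open an assistant turn.
        parts = []
        for role, content in messages:
            prefix, suffix = ROLE_PREFIX_AND_SUFFIX[role]
            parts.append(prefix + content + suffix)
        parts.append(ROLE_PREFIX_AND_SUFFIX["assistant"][0])
        return "".join(parts)

    print(render_plain([("system", "You are a helpful assistant."), ("user", "Hi!")]))
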
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import json
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Optional
+from typing import List, Optional
 
 from transformers import PretrainedConfig
@@ -38,18 +39,24 @@ class ModelConfig:
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
+        is_embedding: Optional[bool] = None
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.model_override_args = model_override_args
+        # Parse args
+        self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-            self.path,
-            trust_remote_code,
-            revision,
-            model_override_args=model_override_args,
+            path,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            model_override_args=self.model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
 
+        # Check model type
+        self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
+        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
+        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+
         # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
         allow_long_context = os.environ.get(
             "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
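Note: model_override_args now arrives as the raw --json-model-override-args string and is parsed exactly once here, which is why the call sites in this commit drop their own json.loads (and, where that was its only use, their import json). The parsing step itself, with an invented override value:

    import json

    print(json.loads("{}"))  # {} -- the empty override dict
    print(json.loads('{"max_position_embeddings": 65536}'))  # {'max_position_embeddings': 65536}
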
@@ -81,7 +88,7 @@ class ModelConfig:
             self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
         )
 
-        # FIXME: temporary special judge for deepseek v2 MLA architecture
+        # FIXME: temporary special judge for MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
@@ -112,8 +119,6 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
-        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
-
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -163,7 +168,6 @@ class ModelConfig:
         # equal to the number of attention heads.
         return self.hf_text_config.num_attention_heads
 
-    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
     def get_num_kv_heads(self, tensor_parallel_size) -> int:
         """Returns the number of KV heads per GPU."""
         total_num_kv_heads = self.get_total_num_kv_heads()
@@ -192,3 +196,37 @@ def get_hf_text_config(config: PretrainedConfig):
         return config.text_config
     else:
         return config
+
+
+def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architecture.
+    # 2. Check the `is_embedding` server arg.
+
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
+    ):
+        return False
+    else:
+        return not is_embedding
+
+
+def is_multimodal_model(model_architectures: List[str]):
+    if (
+        "LlavaLlamaForCausalLM" in model_architectures
+        or "LlavaQwenForCausalLM" in model_architectures
+        or "LlavaMistralForCausalLM" in model_architectures
+        or "LlavaVidForCausalLM" in model_architectures
+        or "MllamaForConditionalGeneration" in model_architectures
+        or "Qwen2VLForConditionalGeneration" in model_architectures
+    ):
+        return True
+    else:
+        return False
+
+
+def is_encoder_decoder_model(model_architectures: List[str]):
+    return "MllamaForConditionalGeneration" in model_architectures

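Note: these three predicates become the single source of truth for model-type checks; the rest of the commit rewires callers to them through the ModelConfig attributes set above. A few illustrative calls against the helpers just added; "LlamaForCausalLM" is an invented example of a plain generative architecture:

    print(is_generation_model(["LlamaForCausalLM"]))                     # True
    print(is_generation_model(["LlamaForCausalLM"], is_embedding=True))  # False
    print(is_generation_model(["LlamaEmbeddingModel"]))                  # False
    print(is_multimodal_model(["Qwen2VLForConditionalGeneration"]))      # True
    print(is_encoder_decoder_model(["MllamaForConditionalGeneration"]))  # True
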
@@ -180,7 +180,7 @@ class LlavaImageProcessor(BaseImageProcessor):
             "pixel_values": pixel_values,
             "image_hashes": image_hashes,
             "image_sizes": image_sizes,
-            "modalities": request_obj.modalities,
+            "modalities": request_obj.modalities or ["image"],
         }

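Note: the `or ["image"]` change gives requests that never set modalities a default. It relies on Python's `or` returning its right operand when the left one is falsy (None or an empty list):

    print(None or ["image"])       # ['image']
    print([] or ["image"])         # ['image']
    print(["video"] or ["image"])  # ['video'] -- an explicit value wins
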
@@ -15,7 +15,6 @@ limitations under the License.
 
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import json
 import logging
 import os
 import threading
@@ -23,7 +22,7 @@ import time
 import warnings
 from collections import deque
 from types import SimpleNamespace
-from typing import List, Optional, Union
+from typing import List, Optional
 
 import torch
 import zmq
@@ -68,8 +67,6 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     get_zmq_socket,
-    is_generation_model,
-    is_multimodal_model,
     kill_parent_process,
     set_random_seed,
     suppress_other_loggers,
@@ -133,15 +130,17 @@ class Scheduler:
         # Init tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
-            server_args.trust_remote_code,
+            trust_remote_code=server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=json.loads(server_args.json_model_override_args),
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
+        self.is_generation = self.model_config.is_generation
 
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.model_config.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -154,9 +153,6 @@ class Scheduler:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
             )
-        self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures, self.server_args.is_embedding
-        )
 
         # Launch a tensor parallel worker
         if self.enable_overlap:

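Note: the Scheduler no longer derives is_generation itself; it reads the flag computed once inside ModelConfig. A toy stand-in for that "check once, read everywhere" pattern (not the sglang implementation; the names merely mirror the diff):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class ToyModelConfig:
        architectures: List[str]
        is_embedding: bool = False

        @property
        def is_generation(self) -> bool:
            # Simplified: the real helper also checks several other
            # embedding/classification architectures.
            return "LlamaEmbeddingModel" not in self.architectures and not self.is_embedding

    cfg = ToyModelConfig(["LlamaForCausalLM"])
    print(cfg.is_generation)  # True; Scheduler, TokenizerManager, and ModelRunner all read this one value
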
@@ -18,7 +18,6 @@ limitations under the License.
 import asyncio
 import copy
 import dataclasses
-import json
 import logging
 import os
 import signal
@@ -31,12 +30,8 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks
 
-from sglang.srt.hf_transformers_utils import (
-    get_config,
-    get_context_length,
-    get_processor,
-    get_tokenizer,
-)
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
     get_dummy_image_processor,
     get_image_processor,
@@ -59,12 +54,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    get_zmq_socket,
-    is_generation_model,
-    is_multimodal_model,
-    kill_child_process,
-)
+from sglang.srt.utils import get_zmq_socket, kill_child_process
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -103,18 +93,17 @@ class TokenizerManager:
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.hf_config = get_config(
-            self.model_path,
+        self.model_config = ModelConfig(
+            server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_override_args=json.loads(server_args.json_model_override_args),
-        )
-        self.is_generation = is_generation_model(
-            self.hf_config.architectures, self.server_args.is_embedding
-        )
-        self.context_len = server_args.context_length or get_context_length(
-            self.hf_config
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
 
+        self.is_generation = self.model_config.is_generation
+        self.context_len = self.model_config.context_len
+
         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()

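Note: context_len resolution also moves behind ModelConfig: the removed lines computed server_args.context_length or get_context_length(self.hf_config) inline, while the new code just reads model_config.context_len. The fallback itself is the usual "explicit override, else derived" pattern:

    def resolve_context_len(server_override, derived_context_len):
        # Toy version of the removed expression: a user-supplied --context-length
        # wins; otherwise fall back to the value derived from the HF config.
        return server_override if server_override is not None else derived_context_len

    print(resolve_context_len(None, 8192))  # 8192 (derived)
    print(resolve_context_len(4096, 8192))  # 4096 (explicit override)
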
@@ -122,7 +111,7 @@ class TokenizerManager:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -133,7 +122,7 @@ class TokenizerManager:
 
                 # We want to parallelize the image pre-processing so we create an executor for it
                 self.image_processor = get_image_processor(
-                    self.hf_config, server_args, self.processor
+                    self.model_config.hf_config, server_args, self.processor
                 )
             else:
                 self.tokenizer = get_tokenizer(

@@ -15,7 +15,6 @@ limitations under the License.
 
 """A tensor parallel worker."""
 
-import json
 import logging
 from typing import Optional
 
@@ -26,7 +25,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
+from sglang.srt.utils import broadcast_pyobj, set_random_seed
 
 logger = logging.getLogger(__name__)
@@ -48,9 +47,10 @@ class TpModelWorker:
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
-            server_args.trust_remote_code,
+            trust_remote_code=server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=json.loads(server_args.json_model_override_args),
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
@@ -64,7 +64,7 @@ class TpModelWorker:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.model_config.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,

@@ -59,11 +59,6 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
-    is_attention_free_model,
-    is_embedding_model,
-    is_generation_model,
-    is_multimodal_model,
-    model_has_inner_state,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -93,9 +88,8 @@ class ModelRunner:
         self.tp_size = tp_size
         self.dist_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(
-            self.model_config.hf_config.architectures
-        )
+        self.is_generation = model_config.is_generation
+        self.is_multimodal = model_config.is_multimodal
 
         # Model-specific adjustment
         if (
@@ -119,7 +113,7 @@ class ModelRunner:
             self.server_args.ds_heavy_channel_type
         )
 
-        if self.is_multimodal_model:
+        if self.is_multimodal:
             logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
@@ -270,9 +264,6 @@
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures, self.server_args.is_embedding
-        )
 
         logger.info(
             f"Load weight end. "
@@ -679,7 +670,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
 
 # Monkey patch model loader
 setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
-setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
-setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
-setattr(ModelRegistry, "is_embedding_model", is_embedding_model)
+setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
+setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
+setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
+setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)

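Note: with the checks unified in ModelConfig, the hooks monkey-patched onto vLLM's ModelRegistry no longer need real logic, so they become constant-False lambdas: sglang answers those questions itself before vLLM's loader ever asks. The stubbing pattern in isolation (ToyRegistry is invented for this sketch):

    class ToyRegistry:
        pass

    # Replace a hook with a stub that always declines, regardless of input.
    setattr(ToyRegistry, "is_multimodal_model", lambda model_architectures: False)
    print(ToyRegistry.is_multimodal_model(["Qwen2VLForConditionalGeneration"]))  # False
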
@@ -409,11 +409,13 @@ class LlamaForCausalLM(nn.Module):
         if (
             hasattr(self.config, "tie_word_embeddings")
             and self.config.tie_word_embeddings
             and "lm_head.weight" in params_dict
         ):
             # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
             param = self.lm_head.weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, self.model.embed_tokens.weight)
 
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))

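Note: for context on the tie_word_embeddings branch above: weight tying points the output projection at the input embedding matrix, so a checkpoint that omits lm_head.weight still loads completely. A generic PyTorch sketch of the idea, with invented dimensions:

    import torch.nn as nn

    embed_tokens = nn.Embedding(32000, 4096)      # input embeddings
    lm_head = nn.Linear(4096, 32000, bias=False)  # output projection

    # Both modules now share one parameter tensor, so initializing the
    # embedding from a checkpoint also initializes the lm_head.
    lm_head.weight = embed_tokens.weight
    print(lm_head.weight is embed_tokens.weight)  # True
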
@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import lru_cache, partial
-from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
+from typing import Iterable, List, Optional, Tuple, Type, TypedDict
 
 import numpy as np
 import torch
@@ -36,7 +36,6 @@ from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsMultiModal
 
 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
@@ -486,7 +485,7 @@ class Qwen2VisionTransformer(nn.Module):
 cached_get_processor = lru_cache(get_processor)
 
 
-class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Qwen2VLForConditionalGeneration(nn.Module):
     def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
         processor = cached_get_processor(self.config._name_or_path)
         grid_t, grid_h, grid_w = image_grid_thw
@@ -536,15 +535,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
     def __init__(
         self,
         config: Qwen2VLConfig,
-        multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
 
         self.config = config
-        self.multimodal_config = multimodal_config
 
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),

@@ -204,56 +204,6 @@ def is_port_available(port):
     return False
 
 
-def is_multimodal_model(model_architectures):
-    if (
-        "LlavaLlamaForCausalLM" in model_architectures
-        or "LlavaQwenForCausalLM" in model_architectures
-        or "LlavaMistralForCausalLM" in model_architectures
-        or "LlavaVidForCausalLM" in model_architectures
-        or "MllamaForConditionalGeneration" in model_architectures
-        or "Qwen2VLForConditionalGeneration" in model_architectures
-    ):
-        return True
-    else:
-        return False
-
-
-def is_attention_free_model(model_architectures):
-    return False
-
-
-def model_has_inner_state(model_architectures):
-    return False
-
-
-def is_embedding_model(model_architectures):
-    if (
-        "LlamaEmbeddingModel" in model_architectures
-        or "MistralModel" in model_architectures
-        or "LlamaForSequenceClassification" in model_architectures
-        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
-    ):
-        return True
-    else:
-        return False
-
-
-def is_generation_model(model_architectures, is_embedding: bool = False):
-    # We have two ways to determine whether a model is a generative model.
-    # 1. Check the model architectue
-    # 2. check the `is_embedding` server args
-
-    if (
-        "LlamaEmbeddingModel" in model_architectures
-        or "MistralModel" in model_architectures
-        or "LlamaForSequenceClassification" in model_architectures
-        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
-    ):
-        return False
-    else:
-        return not is_embedding
-
-
 def decode_video_base64(video_base64):
     from PIL import Image
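Note: this final hunk deletes the old copies of the helpers from sglang.srt.utils; together with the hunks above, the model-type predicates now live only in sglang.srt.configs.model_config. Any external code that imported them must switch to the new location:

    # Old (removed by this commit):
    # from sglang.srt.utils import is_generation_model, is_multimodal_model

    # New location, per the model_config.py hunk in this commit:
    from sglang.srt.configs.model_config import (
        is_encoder_decoder_model,
        is_generation_model,
        is_multimodal_model,
    )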