Unify the model type checking (#1905)

Author: Lianmin Zheng
Date: 2024-11-03 12:25:39 -08:00
Committed by: GitHub
Parent: c17c578108
Commit: 0abbf289a8

13 changed files with 146 additions and 160 deletions
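Before this commit, TokenizerManager fetched the raw HF config via get_config() and derived the model-type flags itself with the free functions is_generation_model() and is_multimodal_model(); after it, a single ModelConfig object computes those flags once and every caller reads them as attributes. Below is a minimal sketch of that pattern, not the actual sglang ModelConfig: the architecture tables and the context-length fallback are made-up stand-ins.

import json
from typing import List, Optional

# Hypothetical stand-ins for the architecture tables the real helpers consult.
MULTIMODAL_ARCHS = {"LlavaLlamaForCausalLM", "Qwen2VLForConditionalGeneration"}
EMBEDDING_ARCHS = {"LlamaEmbeddingModel", "MistralModel"}


class ModelConfig:
    """Owns the model config and answers every model-type question in one place."""

    def __init__(
        self,
        model_path: str,
        architectures: Optional[List[str]] = None,
        trust_remote_code: bool = True,  # kept for interface parity; the real
        # class forwards it to the HF config loader
        context_length: Optional[int] = None,
        model_override_args: str = "{}",
        is_embedding: Optional[bool] = None,
    ):
        self.model_path = model_path
        self.architectures = architectures or []
        # Parse the override JSON once, here, instead of in every caller
        # (which is why the diff below can drop `import json` upstream).
        overrides = json.loads(model_override_args)
        self.context_len = context_length or overrides.get(
            "max_position_embeddings", 2048  # made-up fallback value
        )
        # Unified checks replacing is_generation_model()/is_multimodal_model().
        if is_embedding is None:
            is_embedding = any(a in EMBEDDING_ARCHS for a in self.architectures)
        self.is_generation = not is_embedding
        self.is_multimodal = any(a in MULTIMODAL_ARCHS for a in self.architectures)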


@@ -18,7 +18,6 @@ limitations under the License.
 import asyncio
 import copy
 import dataclasses
-import json
 import logging
 import os
 import signal
@@ -31,12 +30,8 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks
-from sglang.srt.hf_transformers_utils import (
-    get_config,
-    get_context_length,
-    get_processor,
-    get_tokenizer,
-)
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
     get_dummy_image_processor,
     get_image_processor,
@@ -59,12 +54,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    get_zmq_socket,
-    is_generation_model,
-    is_multimodal_model,
-    kill_child_process,
-)
+from sglang.srt.utils import get_zmq_socket, kill_child_process
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -103,18 +93,17 @@ class TokenizerManager:
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.hf_config = get_config(
-            self.model_path,
+        self.model_config = ModelConfig(
+            server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_override_args=json.loads(server_args.json_model_override_args),
-        )
-        self.is_generation = is_generation_model(
-            self.hf_config.architectures, self.server_args.is_embedding
-        )
-        self.context_len = server_args.context_length or get_context_length(
-            self.hf_config
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
+        self.is_generation = self.model_config.is_generation
+        self.context_len = self.model_config.context_len
 
         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
@@ -122,7 +111,7 @@ class TokenizerManager:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -133,7 +122,7 @@ class TokenizerManager:
                 # We want to parallelize the image pre-processing so we create an executor for it
                 self.image_processor = get_image_processor(
-                    self.hf_config, server_args, self.processor
+                    self.model_config.hf_config, server_args, self.processor
                 )
             else:
                 self.tokenizer = get_tokenizer(
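
With the refactor, every consumer reads the same flags from one object instead of recomputing them. A usage sketch against the hypothetical ModelConfig above (the model path and architecture name are illustrative, not from this commit):

cfg = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model path
    architectures=["LlamaForCausalLM"],
)
assert cfg.is_generation and not cfg.is_multimodal
print(cfg.context_len)  # -> 2048, the sketch's made-up fallback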