Unify the model type checking (#1905)
This commit is contained in:
@@ -18,7 +18,6 @@ limitations under the License.
|
||||
import asyncio
|
||||
import copy
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
@@ -31,12 +30,8 @@ import zmq
|
||||
import zmq.asyncio
|
||||
from fastapi import BackgroundTasks
|
||||
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_config,
|
||||
get_context_length,
|
||||
get_processor,
|
||||
get_tokenizer,
|
||||
)
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.image_processor import (
|
||||
get_dummy_image_processor,
|
||||
get_image_processor,
|
||||
@@ -59,12 +54,7 @@ from sglang.srt.managers.io_struct import (
|
||||
)
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_zmq_socket,
|
||||
is_generation_model,
|
||||
is_multimodal_model,
|
||||
kill_child_process,
|
||||
)
|
||||
from sglang.srt.utils import get_zmq_socket, kill_child_process
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
@@ -103,18 +93,17 @@ class TokenizerManager:
|
||||
# Read model args
|
||||
self.model_path = server_args.model_path
|
||||
self.served_model_name = server_args.served_model_name
|
||||
self.hf_config = get_config(
|
||||
self.model_path,
|
||||
self.model_config = ModelConfig(
|
||||
server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
model_override_args=json.loads(server_args.json_model_override_args),
|
||||
)
|
||||
self.is_generation = is_generation_model(
|
||||
self.hf_config.architectures, self.server_args.is_embedding
|
||||
)
|
||||
self.context_len = server_args.context_length or get_context_length(
|
||||
self.hf_config
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
)
|
||||
|
||||
self.is_generation = self.model_config.is_generation
|
||||
self.context_len = self.model_config.context_len
|
||||
|
||||
# Create image processor placeholder
|
||||
self.image_processor = get_dummy_image_processor()
|
||||
|
||||
@@ -122,7 +111,7 @@ class TokenizerManager:
|
||||
if server_args.skip_tokenizer_init:
|
||||
self.tokenizer = self.processor = None
|
||||
else:
|
||||
if is_multimodal_model(self.hf_config.architectures):
|
||||
if self.model_config.is_multimodal:
|
||||
self.processor = get_processor(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
@@ -133,7 +122,7 @@ class TokenizerManager:
|
||||
|
||||
# We want to parallelize the image pre-processing so we create an executor for it
|
||||
self.image_processor = get_image_processor(
|
||||
self.hf_config, server_args, self.processor
|
||||
self.model_config.hf_config, server_args, self.processor
|
||||
)
|
||||
else:
|
||||
self.tokenizer = get_tokenizer(
|
||||
|
||||
Reference in New Issue
Block a user