diff --git a/python/pyproject.toml b/python/pyproject.toml index 48648656e..4f2035ad1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ [project.optional-dependencies] srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", "interegular", "lark", "numba", - "pydantic", "diskcache", "cloudpickle"] + "pydantic", "diskcache", "cloudpickle", "pillow"] openai = ["openai>=1.0", "numpy"] anthropic = ["anthropic", "numpy"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"] diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index 4914ea2ec..c85ec534d 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -1,10 +1,13 @@ +import importlib import logging from dataclasses import dataclass -from enum import Enum, auto +from functools import lru_cache +from pathlib import Path from typing import List import numpy as np import torch +import sglang from sglang.srt.managers.router.infer_batch import Batch, ForwardMode from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.utils import is_multimodal_model @@ -20,6 +23,32 @@ logger = logging.getLogger("model_runner") global_model_mode: List[str] = [] +@lru_cache() +def import_model_classes(): + model_arch_name_to_cls = {} + for module_path in (Path(sglang.__file__).parent / "srt" / "models").glob("*.py"): + module = importlib.import_module(f"sglang.srt.models.{module_path.stem}") + if hasattr(module, "EntryClass"): + model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass + return model_arch_name_to_cls + + +def get_model_cls_by_arch_name(model_arch_names): + model_arch_name_to_cls = import_model_classes() + + model_class = None + for arch in model_arch_names: + if arch in model_arch_name_to_cls: + model_class = model_arch_name_to_cls[arch] + break + else: + raise ValueError( + f"Unsupported architectures: {arch}. " + f"Supported list: {list(model_arch_name_to_cls.keys())}" + ) + return model_class + + @dataclass class InputMetadata: model_runner: "ModelRunner" @@ -237,34 +266,9 @@ class ModelRunner: def load_model(self): """See also vllm/model_executor/model_loader.py::get_model""" - from sglang.srt.models.llama2 import LlamaForCausalLM - from sglang.srt.models.llava import LlavaLlamaForCausalLM - from sglang.srt.models.mixtral import MixtralForCausalLM - from sglang.srt.models.qwen import QWenLMHeadModel - # Select model class architectures = getattr(self.model_config.hf_config, "architectures", []) - - model_class = None - for arch in architectures: - if arch == "LlamaForCausalLM": - model_class = LlamaForCausalLM - break - if arch == "MistralForCausalLM": - model_class = LlamaForCausalLM - break - if arch == "LlavaLlamaForCausalLM": - model_class = LlavaLlamaForCausalLM - break - if arch == "MixtralForCausalLM": - model_class = MixtralForCausalLM - break - if arch == "QWenLMHeadModel": - model_class = QWenLMHeadModel - break - if model_class is None: - raise ValueError(f"Unsupported architectures: {architectures}") - + model_class = get_model_cls_by_arch_name(architectures) logger.info(f"Rank {self.tp_rank}: load weight begin.") # Load weights diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index c5690838b..b4ee11d5b 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + +EntryClass = LlamaForCausalLM diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index b35e902c0..cd3e93cbd 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward(): "forward", clip_vision_embed_forward, ) + +EntryClass = LlavaLlamaForCausalLM diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 0a82d3dd8..2f376983c 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + +EntryClass = MixtralForCausalLM diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index ba59d5bb6..acd9af464 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + +EntryClass = QWenLMHeadModel