Dynamic model class loading (#101)
This commit is contained in:
@@ -20,7 +20,7 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
|
||||
"zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
|
||||
"pydantic", "diskcache", "cloudpickle"]
|
||||
"pydantic", "diskcache", "cloudpickle", "pillow"]
|
||||
openai = ["openai>=1.0", "numpy"]
|
||||
anthropic = ["anthropic", "numpy"]
|
||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import importlib
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import sglang
|
||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.utils import is_multimodal_model
|
||||
@@ -20,6 +23,32 @@ logger = logging.getLogger("model_runner")
|
||||
global_model_mode: List[str] = []
|
||||
|
||||
|
||||
@lru_cache()
def import_model_classes():
    """Discover all model implementations shipped with sglang.

    Imports every ``*.py`` module under ``sglang/srt/models`` and collects
    each module-level ``EntryClass`` into a mapping keyed by the class
    name (which doubles as the HF architecture name).  Cached so the
    directory scan and imports happen only once per process.

    Returns:
        dict[str, type]: architecture name -> model class.
    """
    arch_to_cls = {}
    models_dir = Path(sglang.__file__).parent / "srt" / "models"
    for module_file in models_dir.glob("*.py"):
        module = importlib.import_module(f"sglang.srt.models.{module_file.stem}")
        # Only modules that opt in by exporting EntryClass are registered.
        if hasattr(module, "EntryClass"):
            arch_to_cls[module.EntryClass.__name__] = module.EntryClass
    return arch_to_cls
|
||||
|
||||
|
||||
def get_model_cls_by_arch_name(model_arch_names):
    """Resolve the first supported architecture name to its model class.

    Args:
        model_arch_names: iterable of HF architecture names (e.g. the
            ``architectures`` field of a model's config), tried in order.

    Returns:
        The model class registered for the first recognized name.

    Raises:
        ValueError: if none of the names (or an empty list) is supported.
    """
    model_arch_name_to_cls = import_model_classes()

    for arch in model_arch_names:
        if arch in model_arch_name_to_cls:
            return model_arch_name_to_cls[arch]

    # Bug fix: the original for/else raise referenced the loop variable
    # `arch`, which is unbound when `model_arch_names` is empty, turning the
    # intended ValueError into a NameError.  It also reported only the last
    # name tried; report the full requested list instead.
    raise ValueError(
        f"Unsupported architectures: {model_arch_names}. "
        f"Supported list: {list(model_arch_name_to_cls.keys())}"
    )
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputMetadata:
|
||||
model_runner: "ModelRunner"
|
||||
@@ -237,34 +266,9 @@ class ModelRunner:
|
||||
|
||||
def load_model(self):
|
||||
"""See also vllm/model_executor/model_loader.py::get_model"""
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
from sglang.srt.models.llava import LlavaLlamaForCausalLM
|
||||
from sglang.srt.models.mixtral import MixtralForCausalLM
|
||||
from sglang.srt.models.qwen import QWenLMHeadModel
|
||||
|
||||
# Select model class
|
||||
architectures = getattr(self.model_config.hf_config, "architectures", [])
|
||||
|
||||
model_class = None
|
||||
for arch in architectures:
|
||||
if arch == "LlamaForCausalLM":
|
||||
model_class = LlamaForCausalLM
|
||||
break
|
||||
if arch == "MistralForCausalLM":
|
||||
model_class = LlamaForCausalLM
|
||||
break
|
||||
if arch == "LlavaLlamaForCausalLM":
|
||||
model_class = LlavaLlamaForCausalLM
|
||||
break
|
||||
if arch == "MixtralForCausalLM":
|
||||
model_class = MixtralForCausalLM
|
||||
break
|
||||
if arch == "QWenLMHeadModel":
|
||||
model_class = QWenLMHeadModel
|
||||
break
|
||||
if model_class is None:
|
||||
raise ValueError(f"Unsupported architectures: {architectures}")
|
||||
|
||||
model_class = get_model_cls_by_arch_name(architectures)
|
||||
logger.info(f"Rank {self.tp_rank}: load weight begin.")
|
||||
|
||||
# Load weights
|
||||
|
||||
@@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
EntryClass = LlamaForCausalLM
|
||||
|
||||
@@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward():
|
||||
"forward",
|
||||
clip_vision_embed_forward,
|
||||
)
|
||||
|
||||
EntryClass = LlavaLlamaForCausalLM
|
||||
|
||||
@@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
EntryClass = MixtralForCausalLM
|
||||
|
||||
@@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
EntryClass = QWenLMHeadModel
|
||||
|
||||
Reference in New Issue
Block a user