Dynamic model class loading (#101)
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
        "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
-       "pydantic", "diskcache", "cloudpickle"]
+       "pydantic", "diskcache", "cloudpickle", "pillow"]
 openai = ["openai>=1.0", "numpy"]
 anthropic = ["anthropic", "numpy"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -1,10 +1,13 @@
+import importlib
 import logging
 from dataclasses import dataclass
-from enum import Enum, auto
+from functools import lru_cache
+from pathlib import Path
 from typing import List
 
 import numpy as np
 import torch
+import sglang
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
@@ -20,6 +23,32 @@ logger = logging.getLogger("model_runner")
 global_model_mode: List[str] = []
 
 
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    for module_path in (Path(sglang.__file__).parent / "srt" / "models").glob("*.py"):
+        module = importlib.import_module(f"sglang.srt.models.{module_path.stem}")
+        if hasattr(module, "EntryClass"):
+            model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
+    return model_arch_name_to_cls
+
+
+def get_model_cls_by_arch_name(model_arch_names):
+    model_arch_name_to_cls = import_model_classes()
+
+    model_class = None
+    for arch in model_arch_names:
+        if arch in model_arch_name_to_cls:
+            model_class = model_arch_name_to_cls[arch]
+            break
+    else:
+        raise ValueError(
+            f"Unsupported architectures: {arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_class
+
+
 @dataclass
 class InputMetadata:
     model_runner: "ModelRunner"
@@ -237,34 +266,9 @@ class ModelRunner:
 
     def load_model(self):
         """See also vllm/model_executor/model_loader.py::get_model"""
-        from sglang.srt.models.llama2 import LlamaForCausalLM
-        from sglang.srt.models.llava import LlavaLlamaForCausalLM
-        from sglang.srt.models.mixtral import MixtralForCausalLM
-        from sglang.srt.models.qwen import QWenLMHeadModel
-
         # Select model class
         architectures = getattr(self.model_config.hf_config, "architectures", [])
+        model_class = get_model_cls_by_arch_name(architectures)
-        model_class = None
-        for arch in architectures:
-            if arch == "LlamaForCausalLM":
-                model_class = LlamaForCausalLM
-                break
-            if arch == "MistralForCausalLM":
-                model_class = LlamaForCausalLM
-                break
-            if arch == "LlavaLlamaForCausalLM":
-                model_class = LlavaLlamaForCausalLM
-                break
-            if arch == "MixtralForCausalLM":
-                model_class = MixtralForCausalLM
-                break
-            if arch == "QWenLMHeadModel":
-                model_class = QWenLMHeadModel
-                break
-        if model_class is None:
-            raise ValueError(f"Unsupported architectures: {architectures}")
 
         logger.info(f"Rank {self.tp_rank}: load weight begin.")
 
         # Load weights
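With these hunks, ModelRunner.load_model no longer hard-codes the architecture-to-class mapping: import_model_classes() globs sglang/srt/models/*.py, imports each module, and registers any EntryClass it exports under the class's own __name__ (cached with lru_cache so the scan runs once per process); get_model_cls_by_arch_name() then resolves hf_config.architectures against that registry. Adding a model therefore only requires a file in sglang/srt/models/ that exports EntryClass, which is exactly what the per-model hunks below do. A minimal sketch of such a file follows; the module name, class, and method signatures are hypothetical placeholders, not part of this commit.

# Hypothetical new file: python/sglang/srt/models/my_model.py (illustrative only).
# import_model_classes() discovers it because the module defines EntryClass;
# the class name must match an entry in the model's hf_config.architectures.
from torch import nn


class MyModelForCausalLM(nn.Module):
    def __init__(self, config, linear_method=None):
        super().__init__()
        self.config = config
        # ... build layers here ...

    def forward(self, input_ids, positions, input_metadata):
        # ... return logits for the batch ...
        raise NotImplementedError

    def load_weights(self, model_name_or_path, **kwargs):
        # ... load HuggingFace weights into the module ...
        raise NotImplementedError


EntryClass = MyModelForCausalLM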
--- a/python/sglang/srt/models/llama2.py
+++ b/python/sglang/srt/models/llama2.py
@@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+EntryClass = LlamaForCausalLM
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward():
         "forward",
         clip_vision_embed_forward,
     )
+
+EntryClass = LlavaLlamaForCausalLM
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+EntryClass = MixtralForCausalLM
--- a/python/sglang/srt/models/qwen.py
+++ b/python/sglang/srt/models/qwen.py
@@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+EntryClass = QWenLMHeadModel
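Assuming the runner hunks above live in sglang/srt/managers/router/model_runner.py (a path inferred from the imports and logger name, not stated in the diff), the new lookup can be exercised on its own; this is a rough sketch, not part of the commit, and requires sglang installed with the srt extras.

# Lists every architecture discovered via the EntryClass scan, then resolves one of them
# exactly as ModelRunner.load_model now does.
from sglang.srt.managers.router.model_runner import (
    get_model_cls_by_arch_name,
    import_model_classes,
)

print(sorted(import_model_classes().keys()))
# e.g. ['LlamaForCausalLM', 'LlavaLlamaForCausalLM', 'MixtralForCausalLM', 'QWenLMHeadModel']

model_class = get_model_cls_by_arch_name(["LlamaForCausalLM"])
print(model_class)  # sglang.srt.models.llama2.LlamaForCausalLM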