diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 6f52db83d..afbe03abe 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
+import importlib
+import pkgutil
+
 import sglang
 
 QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
@@ -32,10 +35,11 @@ global_server_args_dict: dict = None
 @lru_cache()
 def import_model_classes():
     model_arch_name_to_cls = {}
-    for f in importlib.resources.files("sglang.srt.models").iterdir():
-        if f.name.endswith(".py"):
-            module_name = Path(f.name).with_suffix('')
-            module = importlib.import_module(f"sglang.srt.models.{module_name}")
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
+        if not ispkg:
+            module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
                 model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
     return model_arch_name_to_cls