From 92e2d74fd0426afb98621465d6574ad2a823e842 Mon Sep 17 00:00:00 2001 From: Qubitium <417764+Qubitium@users.noreply.github.com> Date: Wed, 13 Mar 2024 13:02:48 +0800 Subject: [PATCH] Fix env (docker) compat due to __file__ usage (#288) --- python/sglang/srt/managers/router/model_runner.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index b63cc6b9d..ac98a85f0 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -4,6 +4,7 @@ import inspect from dataclasses import dataclass from functools import lru_cache from pathlib import Path +import importlib.resources import numpy as np import torch @@ -31,10 +32,12 @@ global_server_args_dict: dict = None @lru_cache() def import_model_classes(): model_arch_name_to_cls = {} - for module_path in (Path(sglang.__file__).parent / "srt" / "models").glob("*.py"): - module = importlib.import_module(f"sglang.srt.models.{module_path.stem}") - if hasattr(module, "EntryClass"): - model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass + for f in importlib.resources.files("sglang.srt.models").iterdir(): + if f.name.endswith(".py"): + module_name = Path(f.name).with_suffix('') + module = importlib.import_module(f"sglang.srt.models.{module_name}") + if hasattr(module, "EntryClass"): + model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass return model_arch_name_to_cls