[Fix/Potential Bugs] Can not correctly import models in python/sglang/srt/models (#311)
This commit is contained in:
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
|||||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import pkgutil
|
||||||
|
|
||||||
import sglang
|
import sglang
|
||||||
|
|
||||||
QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
|
QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
|
||||||
@@ -32,10 +35,11 @@ global_server_args_dict: dict = None
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def import_model_classes():
|
def import_model_classes():
|
||||||
model_arch_name_to_cls = {}
|
model_arch_name_to_cls = {}
|
||||||
for f in importlib.resources.files("sglang.srt.models").iterdir():
|
package_name = "sglang.srt.models"
|
||||||
if f.name.endswith(".py"):
|
package = importlib.import_module(package_name)
|
||||||
module_name = Path(f.name).with_suffix('')
|
for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
|
||||||
module = importlib.import_module(f"sglang.srt.models.{module_name}")
|
if not ispkg:
|
||||||
|
module = importlib.import_module(name)
|
||||||
if hasattr(module, "EntryClass"):
|
if hasattr(module, "EntryClass"):
|
||||||
model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
|
model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
|
||||||
return model_arch_name_to_cls
|
return model_arch_name_to_cls
|
||||||
|
|||||||
Reference in New Issue
Block a user