diff --git a/python/sglang/srt/lora/backend/__init__.py b/python/sglang/srt/lora/backend/__init__.py index 07fe11d23..7b76f90e5 100644 --- a/python/sglang/srt/lora/backend/__init__.py +++ b/python/sglang/srt/lora/backend/__init__.py @@ -1,23 +1,20 @@ -from .base_backend import BaseLoRABackend -from .flashinfer_backend import FlashInferLoRABackend -from .triton_backend import TritonLoRABackend +from sglang.srt.lora.backend.base_backend import BaseLoRABackend def get_backend_from_name(name: str) -> BaseLoRABackend: """ Get corresponding backend class from backend's name """ - backend_mapping = { - "triton": TritonLoRABackend, - "flashinfer": FlashInferLoRABackend, - } + if name == "triton": + from sglang.srt.lora.backend.triton_backend import TritonLoRABackend - if name in backend_mapping: - return backend_mapping[name] + return TritonLoRABackend + elif name == "flashinfer": + from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend - raise Exception( - f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}" - ) + return FlashInferLoRABackend + else: + raise ValueError(f"Invalid backend: {name}") __all__ = [