Lazily import lora backends (#4225)
This commit is contained in:
@@ -1,23 +1,20 @@
|
|||||||
from .base_backend import BaseLoRABackend
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||||
from .flashinfer_backend import FlashInferLoRABackend
|
|
||||||
from .triton_backend import TritonLoRABackend
|
|
||||||
|
|
||||||
|
|
||||||
def get_backend_from_name(name: str) -> BaseLoRABackend:
|
def get_backend_from_name(name: str) -> BaseLoRABackend:
|
||||||
"""
|
"""
|
||||||
Get corresponding backend class from backend's name
|
Get corresponding backend class from backend's name
|
||||||
"""
|
"""
|
||||||
backend_mapping = {
|
if name == "triton":
|
||||||
"triton": TritonLoRABackend,
|
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
||||||
"flashinfer": FlashInferLoRABackend,
|
|
||||||
}
|
|
||||||
|
|
||||||
if name in backend_mapping:
|
return TritonLoRABackend
|
||||||
return backend_mapping[name]
|
elif name == "flashinfer":
|
||||||
|
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
|
||||||
|
|
||||||
raise Exception(
|
return FlashInferLoRABackend
|
||||||
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
|
else:
|
||||||
)
|
raise ValueError(f"Invalid backend: {name}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
Reference in New Issue
Block a user