Support pinning adapters via server args. (#9249)

commit b0980af89f
parent 24eaebeb4b
Author: Lifu Huang
Date: 2025-08-20 16:25:01 -07:00
Committed by: GitHub

8 changed files with 162 additions and 55 deletions
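
The change normalizes server_args.lora_paths to a flat List[LoRARef] instead of a name-keyed Dict[str, LoRARef], so per-adapter metadata such as a pinned flag can travel with each ref. A minimal sketch of the new shape, assuming LoRARef carries lora_name, lora_path, and the pinned field implied by the commit title and the num_pinned_loras counter below (field names are inferred from this diff, not verified against the repo):

    from dataclasses import dataclass, field
    from typing import List, Optional
    import uuid

    @dataclass
    class LoRARef:
        # Hypothetical stand-in for sglang's LoRARef.
        lora_name: str
        lora_path: str
        pinned: bool = False  # assumption: the flag this commit pins adapters with
        lora_id: str = field(default_factory=lambda: uuid.uuid4().hex)

    # New list form: the adapter name lives on the ref, not in a dict key.
    lora_paths: Optional[List[LoRARef]] = [
        LoRARef(lora_name="sql_adapter", lora_path="/adapters/sql", pinned=True),
        LoRARef(lora_name="chat_adapter", lora_path="/adapters/chat"),
    ]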

@@ -55,7 +55,7 @@ class LoRAManager:
         tp_rank: int = 0,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
-        lora_paths: Optional[Dict[str, LoRARef]] = None,
+        lora_paths: Optional[List[LoRARef]] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config
@@ -370,7 +370,7 @@ class LoRAManager:
         self,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
-        lora_paths: Optional[Dict[str, LoRARef]] = None,
+        lora_paths: Optional[List[LoRARef]] = None,
     ):
         """
         Initialize the internal (mutable) state of the LoRAManager.
@@ -392,7 +392,7 @@ class LoRAManager:
         self.init_memory_pool()
         self.update_lora_info()
 
-    def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+    def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
         # Configs of all active LoRA adapters, indexed by LoRA ID.
         self.configs: Dict[str, LoRAConfig] = {}
@@ -406,7 +406,7 @@ class LoRAManager:
         self.num_pinned_loras: int = 0
 
         if lora_paths:
-            for lora_ref in lora_paths.values():
+            for lora_ref in lora_paths:
                 result = self.load_lora_adapter(lora_ref)
                 if not result.success:
                     raise RuntimeError(
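
The num_pinned_loras counter initialized above suggests the load path tallies pinned refs as they come in; a hedged sketch of that tally under the list form (init_adapters_sketch, load_lora_adapter, and the result shape are placeholders mirroring the diff, not the real implementation):

    def init_adapters_sketch(manager, lora_paths: List[LoRARef]) -> None:
        # Iterate refs directly (previously lora_paths.values()).
        manager.num_pinned_loras = 0
        for lora_ref in lora_paths:
            result = manager.load_lora_adapter(lora_ref)
            if not result.success:
                raise RuntimeError(f"Failed to load LoRA adapter {lora_ref.lora_name}")
            if lora_ref.pinned:
                # Assumption: pinned refs are counted so eviction can skip them.
                manager.num_pinned_loras += 1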

@@ -59,9 +59,9 @@ class LoRARegistry:
     update / eventual consistency model between the tokenizer manager process and the scheduler processes.
     """
 
-    def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+    def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
         assert lora_paths is None or all(
-            isinstance(lora, LoRARef) for lora in lora_paths.values()
+            isinstance(lora, LoRARef) for lora in lora_paths
        ), (
             "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
             "Please file an issue if you see this error."
@@ -78,7 +78,7 @@ class LoRARegistry:
 
         # Initialize the registry with provided LoRA paths, if present.
         if lora_paths:
-            for lora_ref in lora_paths.values():
+            for lora_ref in lora_paths:
                 self._register_adapter(lora_ref)
 
     async def register(self, lora_ref: LoRARef):
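
A usage sketch against the new constructor contract, reusing the hypothetical LoRARef stub above: the tightened assert now rejects the old dict form, because iterating a dict yields its string keys rather than LoRARef values.

    registry = LoRARegistry(lora_paths=lora_paths)  # flat list of LoRARef: passes the assert

    # Old dict form would now fail: iteration yields str keys, not LoRARef.
    # LoRARegistry(lora_paths={"sql_adapter": lora_paths[0]})  # AssertionError

Dropping the dict also removes the duplication of the adapter name as both dict key and LoRARef field, leaving the ref itself as the single source of truth.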