Support pinning adapter via server args. (#9249)
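This change switches `lora_paths` from `Dict[str, LoRARef]` to `List[LoRARef]` across `LoRAManager` and `LoRARegistry`: the adapter name now travels on the `LoRARef` itself, so the outer dict keyed by name is redundant, and per-adapter metadata such as a pinned flag can ride along. A minimal sketch of the new call shape, assuming `LoRARef` carries `lora_name`, `lora_path`, and `pinned` fields (these field names are assumptions, not confirmed by this diff):

    from dataclasses import dataclass
    from typing import List

    # Hypothetical stand-in for LoRARef; real field names may differ.
    @dataclass
    class LoRARef:
        lora_name: str
        lora_path: str
        pinned: bool = False  # pinned adapters are kept resident

    # Before: {"sql": LoRARef(...)}, keyed by name.
    # After: a flat list; the name lives on the ref itself.
    lora_paths: List[LoRARef] = [
        LoRARef(lora_name="sql", lora_path="/adapters/sql", pinned=True),
        LoRARef(lora_name="chat", lora_path="/adapters/chat"),
    ]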
@@ -55,7 +55,7 @@ class LoRAManager:
         tp_rank: int = 0,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
-        lora_paths: Optional[Dict[str, LoRARef]] = None,
+        lora_paths: Optional[List[LoRARef]] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config
@@ -370,7 +370,7 @@ class LoRAManager:
         self,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
-        lora_paths: Optional[Dict[str, LoRARef]] = None,
+        lora_paths: Optional[List[LoRARef]] = None,
     ):
         """
         Initialize the internal (mutable) state of the LoRAManager.
@@ -392,7 +392,7 @@ class LoRAManager:
         self.init_memory_pool()
         self.update_lora_info()

-    def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+    def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
         # Configs of all active LoRA adapters, indexed by LoRA ID.
         self.configs: Dict[str, LoRAConfig] = {}
@@ -406,7 +406,7 @@ class LoRAManager:
         self.num_pinned_loras: int = 0

         if lora_paths:
-            for lora_ref in lora_paths.values():
+            for lora_ref in lora_paths:
                 result = self.load_lora_adapter(lora_ref)
                 if not result.success:
                     raise RuntimeError(
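The startup loop in `init_lora_adapters` now iterates the list directly instead of `lora_paths.values()`, and the manager tracks `num_pinned_loras` alongside it. One plausible way the pinned count could be tallied during this loop (an assumption — the actual accounting is not shown in this diff; `load_all` is a hypothetical helper, and `pinned` is the assumed flag from the sketch above):

    from typing import List

    def load_all(manager, lora_paths: List[LoRARef]) -> None:
        # Iterate the flat list directly (was: lora_paths.values()).
        for lora_ref in lora_paths:
            result = manager.load_lora_adapter(lora_ref)
            if not result.success:
                raise RuntimeError(
                    f"Failed to load LoRA adapter {lora_ref.lora_name}: {result}"
                )
            if lora_ref.pinned:  # hypothetical flag
                manager.num_pinned_loras += 1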
@@ -59,9 +59,9 @@ class LoRARegistry:
     update / eventual consistency model between the tokenizer manager process and the scheduler processes.
     """

-    def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+    def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
         assert lora_paths is None or all(
-            isinstance(lora, LoRARef) for lora in lora_paths.values()
+            isinstance(lora, LoRARef) for lora in lora_paths
         ), (
             "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
             "Please file an issue if you see this error."
@@ -78,7 +78,7 @@ class LoRARegistry:

         # Initialize the registry with provided LoRA paths, if present.
         if lora_paths:
-            for lora_ref in lora_paths.values():
+            for lora_ref in lora_paths:
                 self._register_adapter(lora_ref)

     async def register(self, lora_ref: LoRARef):
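Dropping the dict input does not lose name-based lookup: the registry can rebuild the index internally from the list as it registers each ref. A sketch of what that might look like, using the hypothetical `LoRARef` from the first sketch and a hypothetical internal `_registry` dict keyed by `lora_name` (the real internal structure is not shown in this diff):

    from typing import Dict, List, Optional

    class LoRARegistrySketch:
        """Illustration only; not the real LoRARegistry."""

        def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
            self._registry: Dict[str, LoRARef] = {}
            if lora_paths:
                for lora_ref in lora_paths:
                    self._register_adapter(lora_ref)

        def _register_adapter(self, lora_ref: LoRARef) -> None:
            # Index refs by name so name -> LoRARef lookup still works
            # even though the constructor input is now a flat list.
            if lora_ref.lora_name in self._registry:
                raise ValueError(f"Duplicate LoRA adapter name: {lora_ref.lora_name}")
            self._registry[lora_ref.lora_name] = lora_ref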