Support pinning adapter via server args. (#9249)

Lifu Huang
2025-08-20 16:25:01 -07:00
committed by GitHub
parent 24eaebeb4b
commit b0980af89f
8 changed files with 162 additions and 55 deletions

@@ -153,7 +153,9 @@ class ServerArgs:
     enable_lora: Optional[bool] = None
     max_lora_rank: Optional[int] = None
     lora_target_modules: Optional[Union[set[str], List[str]]] = None
-    lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
+    lora_paths: Optional[
+        Union[dict[str, str], List[dict[str, str]], List[str], List[LoRARef]]
+    ] = None
     max_loaded_loras: Optional[int] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
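
For reference, the value shapes admitted by the widened annotation, written as plain Python literals. Adapter names and paths below are illustrative only; every shape is normalized into a flat list of LoRARef by the hunk further down.

# dict[str, str]: name -> path (this shape cannot request pinning)
lora_paths_as_dict = {"writer": "/models/adapter_a"}

# List[str]: plain paths or <NAME>=<PATH> entries
lora_paths_as_strings = ["/models/adapter_a", "writer=/models/adapter_b"]

# List[dict[str, str]]: dicts may carry an optional "pinned" flag
lora_paths_as_dicts = [
    {"lora_name": "pinned_adapter", "lora_path": "/models/adapter_c", "pinned": True},
]
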
@@ -1319,7 +1321,7 @@ class ServerArgs:
nargs="*",
default=None,
action=LoRAPathAction,
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
help='The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool}',
)
parser.add_argument(
"--max-loras-per-batch",
@@ -2086,28 +2088,42 @@ class ServerArgs:
             )
         if self.enable_lora:
             # Normalize lora_paths to a dictionary if it is a list.
-            # TODO (lifuhuang): support specifying pinned adapters in server_args.
             if isinstance(self.lora_paths, list):
                 lora_paths = self.lora_paths
-                self.lora_paths = {}
+                self.lora_paths = []
                 for lora_path in lora_paths:
-                    if "=" in lora_path:
-                        name, path = lora_path.split("=", 1)
-                        self.lora_paths[name] = LoRARef(
-                            lora_name=name, lora_path=path, pinned=False
+                    if isinstance(lora_path, str):
+                        if "=" in lora_path:
+                            name, path = lora_path.split("=", 1)
+                            lora_ref = LoRARef(
+                                lora_name=name, lora_path=path, pinned=False
+                            )
+                        else:
+                            lora_ref = LoRARef(
+                                lora_name=lora_path, lora_path=lora_path, pinned=False
+                            )
+                    elif isinstance(lora_path, dict):
+                        assert (
+                            "lora_name" in lora_path and "lora_path" in lora_path
+                        ), f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}"
+                        lora_ref = LoRARef(
+                            lora_name=lora_path["lora_name"],
+                            lora_path=lora_path["lora_path"],
+                            pinned=lora_path.get("pinned", False),
                         )
                     else:
-                        self.lora_paths[lora_path] = LoRARef(
-                            lora_name=lora_path, lora_path=lora_path, pinned=False
+                        raise ValueError(
+                            f"Invalid type for item in --lora-paths list: {type(lora_path)}. "
+                            "Expected a string or a dictionary."
                         )
+                    self.lora_paths.append(lora_ref)
             elif isinstance(self.lora_paths, dict):
-                self.lora_paths = {
-                    k: LoRARef(lora_name=k, lora_path=v, pinned=False)
+                self.lora_paths = [
+                    LoRARef(lora_name=k, lora_path=v, pinned=False)
                     for k, v in self.lora_paths.items()
-                }
+                ]
             elif self.lora_paths is None:
-                self.lora_paths = {}
+                self.lora_paths = []
             else:
                 raise ValueError(
                     f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
@@ -2134,9 +2150,7 @@ class ServerArgs:
"max_loaded_loras should be greater than or equal to max_loras_per_batch. "
f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
)
assert (
not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
), (
assert len(self.lora_paths) <= self.max_loaded_loras, (
"The number of LoRA paths should not exceed max_loaded_loras. "
f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
)
@@ -2357,13 +2371,22 @@ class PortArgs:
 class LoRAPathAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
-        setattr(namespace, self.dest, {})
-        for lora_path in values:
-            if "=" in lora_path:
-                name, path = lora_path.split("=", 1)
-                getattr(namespace, self.dest)[name] = path
-            else:
-                getattr(namespace, self.dest)[lora_path] = lora_path
+        lora_paths = []
+        if values:
+            assert isinstance(values, list), "Expected a list of LoRA paths."
+            for lora_path in values:
+                lora_path = lora_path.strip()
+                if lora_path.startswith("{") and lora_path.endswith("}"):
+                    obj = json.loads(lora_path)
+                    assert "lora_path" in obj and "lora_name" in obj, (
+                        f"{repr(lora_path)} looks like a JSON str, "
+                        "but it does not contain 'lora_name' and 'lora_path' keys."
+                    )
+                    lora_paths.append(obj)
+                else:
+                    lora_paths.append(lora_path)
+        setattr(namespace, self.dest, lora_paths)


 class DeprecatedAction(argparse.Action):
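
End-to-end, the reworked LoRAPathAction leaves string items untouched and decodes JSON items into dicts, which ServerArgs then normalizes into LoRARef objects. A minimal sketch, assuming LoRAPathAction can be imported from sglang.srt.server_args (the module this diff edits) and using made-up adapter paths:

import argparse

from sglang.srt.server_args import LoRAPathAction  # assumed import path

parser = argparse.ArgumentParser()
parser.add_argument("--lora-paths", nargs="*", default=None, action=LoRAPathAction)

args = parser.parse_args([
    "--lora-paths",
    "/models/adapter_a",
    "writer=/models/adapter_b",
    '{"lora_name": "pinned_adapter", "lora_path": "/models/adapter_c", "pinned": true}',
])

# Strings pass through as-is; the JSON item is decoded to a dict with pinned=True.
print(args.lora_paths)
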