Support pinning adapter via server args. (#9249)
This commit is contained in:
@@ -153,7 +153,9 @@ class ServerArgs:
|
||||
enable_lora: Optional[bool] = None
|
||||
max_lora_rank: Optional[int] = None
|
||||
lora_target_modules: Optional[Union[set[str], List[str]]] = None
|
||||
lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
|
||||
lora_paths: Optional[
|
||||
Union[dict[str, str], List[dict[str, str]], List[str], List[LoRARef]]
|
||||
] = None
|
||||
max_loaded_loras: Optional[int] = None
|
||||
max_loras_per_batch: int = 8
|
||||
lora_backend: str = "triton"
|
||||
@@ -1319,7 +1321,7 @@ class ServerArgs:
|
||||
nargs="*",
|
||||
default=None,
|
||||
action=LoRAPathAction,
|
||||
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
|
||||
help='The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool}',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-loras-per-batch",
|
||||
@@ -2086,28 +2088,42 @@ class ServerArgs:
|
||||
)
|
||||
|
||||
if self.enable_lora:
|
||||
# Normalize lora_paths to a dictionary if it is a list.
|
||||
# TODO (lifuhuang): support specifying pinned adapters in server_args.
|
||||
if isinstance(self.lora_paths, list):
|
||||
lora_paths = self.lora_paths
|
||||
self.lora_paths = {}
|
||||
self.lora_paths = []
|
||||
for lora_path in lora_paths:
|
||||
if "=" in lora_path:
|
||||
name, path = lora_path.split("=", 1)
|
||||
self.lora_paths[name] = LoRARef(
|
||||
lora_name=name, lora_path=path, pinned=False
|
||||
if isinstance(lora_path, str):
|
||||
if "=" in lora_path:
|
||||
name, path = lora_path.split("=", 1)
|
||||
lora_ref = LoRARef(
|
||||
lora_name=name, lora_path=path, pinned=False
|
||||
)
|
||||
else:
|
||||
lora_ref = LoRARef(
|
||||
lora_name=lora_path, lora_path=lora_path, pinned=False
|
||||
)
|
||||
elif isinstance(lora_path, dict):
|
||||
assert (
|
||||
"lora_name" in lora_path and "lora_path" in lora_path
|
||||
), f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}"
|
||||
lora_ref = LoRARef(
|
||||
lora_name=lora_path["lora_name"],
|
||||
lora_path=lora_path["lora_path"],
|
||||
pinned=lora_path.get("pinned", False),
|
||||
)
|
||||
else:
|
||||
self.lora_paths[lora_path] = LoRARef(
|
||||
lora_name=lora_path, lora_path=lora_path, pinned=False
|
||||
raise ValueError(
|
||||
f"Invalid type for item in --lora-paths list: {type(lora_path)}. "
|
||||
"Expected a string or a dictionary."
|
||||
)
|
||||
self.lora_paths.append(lora_ref)
|
||||
elif isinstance(self.lora_paths, dict):
|
||||
self.lora_paths = {
|
||||
k: LoRARef(lora_name=k, lora_path=v, pinned=False)
|
||||
self.lora_paths = [
|
||||
LoRARef(lora_name=k, lora_path=v, pinned=False)
|
||||
for k, v in self.lora_paths.items()
|
||||
}
|
||||
]
|
||||
elif self.lora_paths is None:
|
||||
self.lora_paths = {}
|
||||
self.lora_paths = []
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
|
||||
@@ -2134,9 +2150,7 @@ class ServerArgs:
|
||||
"max_loaded_loras should be greater than or equal to max_loras_per_batch. "
|
||||
f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
|
||||
)
|
||||
assert (
|
||||
not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
|
||||
), (
|
||||
assert len(self.lora_paths) <= self.max_loaded_loras, (
|
||||
"The number of LoRA paths should not exceed max_loaded_loras. "
|
||||
f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
|
||||
)
|
||||
@@ -2357,13 +2371,22 @@ class PortArgs:
|
||||
|
||||
class LoRAPathAction(argparse.Action):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
setattr(namespace, self.dest, {})
|
||||
for lora_path in values:
|
||||
if "=" in lora_path:
|
||||
name, path = lora_path.split("=", 1)
|
||||
getattr(namespace, self.dest)[name] = path
|
||||
else:
|
||||
getattr(namespace, self.dest)[lora_path] = lora_path
|
||||
lora_paths = []
|
||||
if values:
|
||||
assert isinstance(values, list), "Expected a list of LoRA paths."
|
||||
for lora_path in values:
|
||||
lora_path = lora_path.strip()
|
||||
if lora_path.startswith("{") and lora_path.endswith("}"):
|
||||
obj = json.loads(lora_path)
|
||||
assert "lora_path" in obj and "lora_name" in obj, (
|
||||
f"{repr(lora_path)} looks like a JSON str, "
|
||||
"but it does not contain 'lora_name' and 'lora_path' keys."
|
||||
)
|
||||
lora_paths.append(obj)
|
||||
else:
|
||||
lora_paths.append(lora_path)
|
||||
|
||||
setattr(namespace, self.dest, lora_paths)
|
||||
|
||||
|
||||
class DeprecatedAction(argparse.Action):
|
||||
|
||||
Reference in New Issue
Block a user