Support pinning adapter via server args. (#9249)
This commit is contained in:
@@ -55,7 +55,7 @@ class LoRAManager:
|
||||
tp_rank: int = 0,
|
||||
max_lora_rank: Optional[int] = None,
|
||||
target_modules: Optional[Iterable[str]] = None,
|
||||
lora_paths: Optional[Dict[str, LoRARef]] = None,
|
||||
lora_paths: Optional[List[LoRARef]] = None,
|
||||
):
|
||||
self.base_model: torch.nn.Module = base_model
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
@@ -370,7 +370,7 @@ class LoRAManager:
|
||||
self,
|
||||
max_lora_rank: Optional[int] = None,
|
||||
target_modules: Optional[Iterable[str]] = None,
|
||||
lora_paths: Optional[Dict[str, LoRARef]] = None,
|
||||
lora_paths: Optional[List[LoRARef]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the internal (mutable) state of the LoRAManager.
|
||||
@@ -392,7 +392,7 @@ class LoRAManager:
|
||||
self.init_memory_pool()
|
||||
self.update_lora_info()
|
||||
|
||||
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
||||
def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
|
||||
# Configs of all active LoRA adapters, indexed by LoRA ID.
|
||||
self.configs: Dict[str, LoRAConfig] = {}
|
||||
|
||||
@@ -406,7 +406,7 @@ class LoRAManager:
|
||||
self.num_pinned_loras: int = 0
|
||||
|
||||
if lora_paths:
|
||||
for lora_ref in lora_paths.values():
|
||||
for lora_ref in lora_paths:
|
||||
result = self.load_lora_adapter(lora_ref)
|
||||
if not result.success:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -59,9 +59,9 @@ class LoRARegistry:
|
||||
update / eventual consistency model between the tokenizer manager process and the scheduler processes.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
||||
def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
|
||||
assert lora_paths is None or all(
|
||||
isinstance(lora, LoRARef) for lora in lora_paths.values()
|
||||
isinstance(lora, LoRARef) for lora in lora_paths
|
||||
), (
|
||||
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
|
||||
"Please file an issue if you see this error."
|
||||
@@ -78,7 +78,7 @@ class LoRARegistry:
|
||||
|
||||
# Initialize the registry with provided LoRA paths, if present.
|
||||
if lora_paths:
|
||||
for lora_ref in lora_paths.values():
|
||||
for lora_ref in lora_paths:
|
||||
self._register_adapter(lora_ref)
|
||||
|
||||
async def register(self, lora_ref: LoRARef):
|
||||
|
||||
@@ -298,7 +298,7 @@ class TokenizerManager:
|
||||
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
|
||||
# serves as the source of truth for available adapters and maps user-friendly LoRA names
|
||||
# to internally used unique LoRA IDs.
|
||||
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
|
||||
self.lora_registry = LoRARegistry(self.server_args.lora_paths)
|
||||
# Lock to serialize LoRA update operations.
|
||||
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
|
||||
# LoRA updates and inference to overlap.
|
||||
|
||||
@@ -153,7 +153,9 @@ class ServerArgs:
|
||||
enable_lora: Optional[bool] = None
|
||||
max_lora_rank: Optional[int] = None
|
||||
lora_target_modules: Optional[Union[set[str], List[str]]] = None
|
||||
lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
|
||||
lora_paths: Optional[
|
||||
Union[dict[str, str], List[dict[str, str]], List[str], List[LoRARef]]
|
||||
] = None
|
||||
max_loaded_loras: Optional[int] = None
|
||||
max_loras_per_batch: int = 8
|
||||
lora_backend: str = "triton"
|
||||
@@ -1319,7 +1321,7 @@ class ServerArgs:
|
||||
nargs="*",
|
||||
default=None,
|
||||
action=LoRAPathAction,
|
||||
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
|
||||
help='The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool}',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-loras-per-batch",
|
||||
@@ -2086,28 +2088,42 @@ class ServerArgs:
|
||||
)
|
||||
|
||||
if self.enable_lora:
|
||||
# Normalize lora_paths to a dictionary if it is a list.
|
||||
# TODO (lifuhuang): support specifying pinned adapters in server_args.
|
||||
if isinstance(self.lora_paths, list):
|
||||
lora_paths = self.lora_paths
|
||||
self.lora_paths = {}
|
||||
self.lora_paths = []
|
||||
for lora_path in lora_paths:
|
||||
if "=" in lora_path:
|
||||
name, path = lora_path.split("=", 1)
|
||||
self.lora_paths[name] = LoRARef(
|
||||
lora_name=name, lora_path=path, pinned=False
|
||||
if isinstance(lora_path, str):
|
||||
if "=" in lora_path:
|
||||
name, path = lora_path.split("=", 1)
|
||||
lora_ref = LoRARef(
|
||||
lora_name=name, lora_path=path, pinned=False
|
||||
)
|
||||
else:
|
||||
lora_ref = LoRARef(
|
||||
lora_name=lora_path, lora_path=lora_path, pinned=False
|
||||
)
|
||||
elif isinstance(lora_path, dict):
|
||||
assert (
|
||||
"lora_name" in lora_path and "lora_path" in lora_path
|
||||
), f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}"
|
||||
lora_ref = LoRARef(
|
||||
lora_name=lora_path["lora_name"],
|
||||
lora_path=lora_path["lora_path"],
|
||||
pinned=lora_path.get("pinned", False),
|
||||
)
|
||||
else:
|
||||
self.lora_paths[lora_path] = LoRARef(
|
||||
lora_name=lora_path, lora_path=lora_path, pinned=False
|
||||
raise ValueError(
|
||||
f"Invalid type for item in --lora-paths list: {type(lora_path)}. "
|
||||
"Expected a string or a dictionary."
|
||||
)
|
||||
self.lora_paths.append(lora_ref)
|
||||
elif isinstance(self.lora_paths, dict):
|
||||
self.lora_paths = {
|
||||
k: LoRARef(lora_name=k, lora_path=v, pinned=False)
|
||||
self.lora_paths = [
|
||||
LoRARef(lora_name=k, lora_path=v, pinned=False)
|
||||
for k, v in self.lora_paths.items()
|
||||
}
|
||||
]
|
||||
elif self.lora_paths is None:
|
||||
self.lora_paths = {}
|
||||
self.lora_paths = []
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
|
||||
@@ -2134,9 +2150,7 @@ class ServerArgs:
|
||||
"max_loaded_loras should be greater than or equal to max_loras_per_batch. "
|
||||
f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
|
||||
)
|
||||
assert (
|
||||
not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
|
||||
), (
|
||||
assert len(self.lora_paths) <= self.max_loaded_loras, (
|
||||
"The number of LoRA paths should not exceed max_loaded_loras. "
|
||||
f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
|
||||
)
|
||||
@@ -2357,13 +2371,22 @@ class PortArgs:
|
||||
|
||||
class LoRAPathAction(argparse.Action):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
setattr(namespace, self.dest, {})
|
||||
for lora_path in values:
|
||||
if "=" in lora_path:
|
||||
name, path = lora_path.split("=", 1)
|
||||
getattr(namespace, self.dest)[name] = path
|
||||
else:
|
||||
getattr(namespace, self.dest)[lora_path] = lora_path
|
||||
lora_paths = []
|
||||
if values:
|
||||
assert isinstance(values, list), "Expected a list of LoRA paths."
|
||||
for lora_path in values:
|
||||
lora_path = lora_path.strip()
|
||||
if lora_path.startswith("{") and lora_path.endswith("}"):
|
||||
obj = json.loads(lora_path)
|
||||
assert "lora_path" in obj and "lora_name" in obj, (
|
||||
f"{repr(lora_path)} looks like a JSON str, "
|
||||
"but it does not contain 'lora_name' and 'lora_path' keys."
|
||||
)
|
||||
lora_paths.append(obj)
|
||||
else:
|
||||
lora_paths.append(lora_path)
|
||||
|
||||
setattr(namespace, self.dest, lora_paths)
|
||||
|
||||
|
||||
class DeprecatedAction(argparse.Action):
|
||||
|
||||
@@ -491,7 +491,7 @@ class SRTRunner:
|
||||
tp_size: int = 1,
|
||||
model_impl: str = "auto",
|
||||
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
||||
lora_paths: List[str] = None,
|
||||
lora_paths: Optional[Union[List[str], List[dict[str, str]]]] = None,
|
||||
max_loras_per_batch: int = 4,
|
||||
attention_backend: Optional[str] = None,
|
||||
prefill_attention_backend: Optional[str] = None,
|
||||
|
||||
Reference in New Issue
Block a user