# Source: sglang/python/sglang/srt/lora/lora_registry.py
# (snapshot dated 2025-08-03 00:02:23 -07:00; 196 lines, 7.7 KiB, Python)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4
from sglang.srt.aio_rwlock import RWLock
from sglang.srt.utils import ConcurrentCounter
@dataclass(frozen=True)
class LoRARef:
    """
    Immutable reference record describing one LoRA adapter.

    Every instance carries a ``lora_id``; when none is supplied, a fresh random hex UUID is generated.
    The optional ``lora_name`` and ``lora_path`` describe the adapter. Because the ID is independent of
    the name and path, adapters that reuse a name or path can still be told apart, and the ID can serve
    as a deterministic cache key (e.g. for the radix cache).
    """

    lora_id: str = field(default_factory=lambda: uuid4().hex)
    lora_name: Optional[str] = None
    lora_path: Optional[str] = None

    def __post_init__(self):
        # Reject an explicit None: the ID is the only field guaranteed to identify the adapter.
        if self.lora_id is not None:
            return
        raise ValueError("lora_id cannot be None")

    def __str__(self) -> str:
        # Render only the fields that actually hold a value, in declaration order.
        rendered = []
        for fld in fields(self):
            value = getattr(self, fld.name)
            if value is None:
                continue
            rendered.append(f"{fld.name}={value}")
        body = ", ".join(rendered)
        return f"{type(self).__name__}({body})"
class LoRARegistry:
    """
    The central registry to keep track of available LoRA adapters and ongoing LoRA requests.

    The `LoRARegistry` resides in the tokenizer manager process and acts as the single source of truth for all
    available LoRA adapters. It supports concurrent inference and dynamic adapter updates through a two-phase
    update / eventual consistency model between the tokenizer manager process and the scheduler processes.

    Lifecycle of an adapter's usage counter (keyed by LoRA ID):
      * created by ``_register_adapter``;
      * incremented by ``acquire`` and decremented by ``release``;
      * deliberately NOT removed by ``unregister``, so requests that acquired the adapter earlier
        can still ``release`` it;
      * removed by ``wait_for_unload`` once it has drained to zero.
    """

    def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
        """
        Args:
            lora_paths: Optional mapping whose values are LoRARef objects to pre-register.
                Only the values are used; each is registered under its ``lora_name``.
        """
        assert lora_paths is None or all(
            isinstance(lora, LoRARef) for lora in lora_paths.values()
        ), (
            "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
            "Please file an issue if you see this error."
        )

        # A read-write lock to ensure adapters loading / unloading operations are exclusive.
        # Please note that the counter increment/decrement operations are not synchronized through this
        # lock, as they are designed to be non-blocking and can be performed concurrently.
        self._registry_lock = RWLock()
        # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
        self._registry: Dict[str, LoRARef] = {}
        # Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
        self._counters: Dict[str, ConcurrentCounter] = {}

        # Initialize the registry with provided LoRA paths, if present.
        if lora_paths:
            for lora_ref in lora_paths.values():
                self._register_adapter(lora_ref)

    async def register(self, lora_ref: LoRARef):
        """
        Register a new LoRARef object in the registry.

        Args:
            lora_ref (LoRARef): The LoRARef object to register.

        Raises:
            ValueError: If an adapter with the same ``lora_name`` is already registered.
        """
        async with self._registry_lock.writer_lock:
            self._register_adapter(lora_ref)

    async def unregister(self, lora_name: str) -> str:
        """
        Unregister a LoRARef object from the registry and return the removed LoRA ID.

        After this call the adapter can no longer be acquired by name. Its usage counter is
        intentionally kept: requests that acquired the adapter earlier must still be able to
        ``release`` it, and ``wait_for_unload`` relies on that counter to block until they finish
        (it removes the counter itself once it reaches zero).

        Args:
            lora_name (str): The name of the LoRA model to unregister.

        Returns:
            str: The ID of the unregistered adapter.

        Raises:
            ValueError: If no adapter with ``lora_name`` is registered.
        """
        async with self._registry_lock.writer_lock:
            lora_ref = self._registry.get(lora_name, None)
            if lora_ref is None:
                raise ValueError(
                    f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
                )
            del self._registry[lora_name]

        return lora_ref.lora_id

    async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
        """
        Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
        by incrementing its counter.

        Args:
            lora_name: A single adapter name, or a list of names (duplicates are counted once per occurrence).

        Returns:
            The LoRA ID (for a str input) or the list of LoRA IDs in input order (for a list input).

        Raises:
            ValueError: If any requested name is not registered. For a list, no counter is incremented in that case.
            TypeError: If ``lora_name`` is neither a str nor a list.
        """

        def _lookup(name: str) -> str:
            lora_ref = self._registry.get(name, None)
            if lora_ref is None:
                raise ValueError(
                    f"The following requested LoRA adapters are not loaded: {name}\n"
                    f"Loaded adapters: {self._registry.keys()}."
                )
            return lora_ref.lora_id

        async with self._registry_lock.reader_lock:
            if isinstance(lora_name, str):
                lora_id = _lookup(lora_name)
                await self._counters[lora_id].increment(notify_all=False)
                return lora_id
            elif isinstance(lora_name, list):
                lora_ids = [_lookup(name) for name in lora_name]

                # Increment the counters only after all IDs are looked up, so a failed lookup
                # leaves no counter partially incremented.
                await asyncio.gather(
                    *[self._counters[lid].increment(notify_all=False) for lid in lora_ids]
                )
                return lora_ids
            else:
                raise TypeError(
                    "lora_name must be either a string or a list of strings."
                )

    async def release(self, lora_id: Union[str, List[str]]):
        """
        Decrements the usage counter for a LoRA adapter, indicating that it is no longer in use.

        Args:
            lora_id: A single LoRA ID or a list of IDs, as previously returned by ``acquire``.

        Raises:
            TypeError: If ``lora_id`` is neither a str nor a list.
        """
        async with self._registry_lock.reader_lock:
            if isinstance(lora_id, str):
                await self._counters[lora_id].decrement()
            elif isinstance(lora_id, list):
                await asyncio.gather(
                    *[self._counters[lid].decrement() for lid in lora_id]
                )
            else:
                raise TypeError("lora_id must be either a string or a list of strings.")

    async def wait_for_unload(self, lora_id: str):
        """
        Waits until the usage counter for a LoRA adapter reaches zero, indicating that it is no longer in use,
        then drops the counter.

        This is useful for ensuring that a LoRA adapter can be safely unloaded. It must be called after
        ``unregister`` so that no new requests can acquire the adapter while we wait.

        This method itself is not synchronized, which is safe because it should only be called during LoRA unloading,
        which itself is guaranteed to be sequential.
        """
        # The registry is keyed by name, so compare against the IDs of the registered refs.
        assert all(
            ref.lora_id != lora_id for ref in self._registry.values()
        ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "

        counter = self._counters.get(lora_id)
        if counter is not None:
            # Wait until no requests are using this LoRA adapter.
            await counter.wait_for_zero()
            del self._counters[lora_id]

    def _register_adapter(self, lora_ref: LoRARef):
        """
        Internal helper method to register a LoRA adapter.

        Adds ``lora_ref`` under its ``lora_name`` and creates a fresh usage counter for its ``lora_id``.
        Callers are responsible for holding the writer lock when needed.

        Raises:
            ValueError: If an adapter with the same ``lora_name`` is already registered.
        """
        if lora_ref.lora_name in self._registry:
            raise ValueError(
                f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
            )
        self._registry[lora_ref.lora_name] = lora_ref
        self._counters[lora_ref.lora_id] = ConcurrentCounter()
        return lora_ref

    @property
    def num_registered_loras(self) -> int:
        """
        Returns the total number of LoRA adapters currently registered.
        """
        return len(self._registry)