Support overlapped lora updates (#8213)

This commit is contained in:
Lifu Huang
2025-07-27 13:00:44 -07:00
committed by GitHub
parent 95217a9b4d
commit df90645525
4 changed files with 204 additions and 35 deletions

View File

@@ -14,10 +14,14 @@
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4
from sglang.srt.aio_rwlock import RWLock
from sglang.srt.utils import ConcurrentCounter
@dataclass(frozen=True)
class LoRARef:
@@ -48,10 +52,11 @@ class LoRARef:
class LoRARegistry:
"""
The central registry to keep track of available LoRA adapters.
The central registry to keep track of available LoRA adapters and ongoing LoRA requests.
TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
The `LoRARegistry` resides in the tokenizer manager process and acts as the single source of truth for all
available LoRA adapters. It supports concurrent inference and dynamic adapter updates through a two-phase
update / eventual consistency model between the tokenizer manager process and the scheduler processes.
"""
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
@@ -62,8 +67,19 @@ class LoRARegistry:
"Please file an issue if you see this error."
)
# A read-write lock to ensure adapters loading / unloading operations are exclusive.
# Please note that the counter increment/decrement operations are not synchronized through this
# lock, as they are designed to be non-blocking and can be performed concurrently.
self._registry_lock = RWLock()
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
self._registry: Dict[str, LoRARef] = {}
# Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
self._counters: Dict[str, ConcurrentCounter] = {}
# Initialize the registry with provided LoRA paths, if present.
if lora_paths:
for lora_ref in lora_paths.values():
self._register_adapter(lora_ref)
async def register(self, lora_ref: LoRARef):
"""
@@ -72,11 +88,8 @@ class LoRARegistry:
Args:
lora_ref (LoRARef): The LoRARef object to register.
"""
if lora_ref.lora_name in self._registry:
raise ValueError(
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
)
self._registry[lora_ref.lora_name] = lora_ref
async with self._registry_lock.writer_lock:
self._register_adapter(lora_ref)
async def unregister(self, lora_name: str) -> str:
"""
@@ -85,12 +98,14 @@ class LoRARegistry:
Args:
lora_name (str): The name of the LoRA model to unregister.
"""
lora_ref = self._registry.get(lora_name, None)
if lora_ref is None:
raise ValueError(
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
)
del self._registry[lora_name]
async with self._registry_lock.writer_lock:
lora_ref = self._registry.get(lora_name, None)
if lora_ref is None:
raise ValueError(
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
)
del self._registry[lora_name]
del self._counters[lora_ref.lora_id]
return lora_ref.lora_id
@@ -98,27 +113,76 @@ class LoRARegistry:
"""
Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
by incrementing its counter.
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
"""
async def _acquire_single(name: str) -> str:
def _lookup(name: str) -> str:
lora_ref = self._registry.get(name, None)
if lora_ref is None:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {name}\n"
f"Loaded adapters: {self._registry.keys()}."
)
# await self._counters[lora_ref.lora_id].increment()
return lora_ref.lora_id
if isinstance(lora_name, str):
lora_id = await _acquire_single(lora_name)
return lora_id
elif isinstance(lora_name, list):
lora_ids = await asyncio.gather(
*[_acquire_single(name) for name in lora_name]
async with self._registry_lock.reader_lock:
if isinstance(lora_name, str):
lora_id = _lookup(lora_name)
await self._counters[lora_id].increment(notify_all=False)
return lora_id
elif isinstance(lora_name, list):
lora_ids = [_lookup(name) for name in lora_name]
# Increment the counters only after all IDs are looked up.
await asyncio.gather(
*[self._counters[id].increment(notify_all=False) for id in lora_ids]
)
return lora_ids
else:
raise TypeError(
"lora_name must be either a string or a list of strings."
)
async def release(self, lora_id: Union[str, List[str]]):
"""
Decrements the usage counter for a LoRA adapter, indicating that it is no longer in use.
"""
async with self._registry_lock.reader_lock:
if isinstance(lora_id, str):
await self._counters[lora_id].decrement()
elif isinstance(lora_id, list):
await asyncio.gather(
*[self._counters[id].decrement() for id in lora_id]
)
else:
raise TypeError("lora_id must be either a string or a list of strings.")
async def wait_for_unload(self, lora_id: str):
"""
Waits until the usage counter for a LoRA adapter reaches zero, indicating that it is no longer in use.
This is useful for ensuring that a LoRA adapter can be safely unloaded.
This method itself is not synchronized, which is safe because it should only be called during LoRA unloading,
which itself is guaranteed to be sequential.
"""
assert (
lora_id not in self._registry
), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
counter = self._counters.get(lora_id)
if counter:
# Wait until no requests are using this LoRA adapter.
await counter.wait_for_zero()
del self._counters[lora_id]
def _register_adapter(self, lora_ref: LoRARef):
"""
Internal helper method to register a LoRA adapter.
"""
if lora_ref.lora_name in self._registry:
raise ValueError(
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
)
return lora_ids
else:
raise TypeError("lora_name must be either a string or a list of strings.")
self._registry[lora_ref.lora_name] = lora_ref
self._counters[lora_ref.lora_id] = ConcurrentCounter()
return lora_ref