Support overlapped lora updates (#8213)
This commit is contained in:
@@ -14,10 +14,14 @@
|
||||
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from sglang.srt.aio_rwlock import RWLock
|
||||
from sglang.srt.utils import ConcurrentCounter
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LoRARef:
|
||||
@@ -48,10 +52,11 @@ class LoRARef:
|
||||
|
||||
class LoRARegistry:
|
||||
"""
|
||||
The central registry to keep track of available LoRA adapters.
|
||||
The central registry to keep track of available LoRA adapters and ongoing LoRA requests.
|
||||
|
||||
TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
|
||||
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
|
||||
The `LoRARegistry` resides in the tokenizer manager process and acts as the single source of truth for all
|
||||
available LoRA adapters. It supports concurrent inference and dynamic adapter updates through a two-phase
|
||||
update / eventual consistency model between the tokenizer manager process and the scheduler processes.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
||||
@@ -62,8 +67,19 @@ class LoRARegistry:
|
||||
"Please file an issue if you see this error."
|
||||
)
|
||||
|
||||
# A read-write lock to ensure adapters loading / unloading operations are exclusive.
|
||||
# Please note that the counter increment/decrement operations are not synchronized through this
|
||||
# lock, as they are designed to be non-blocking and can be performed concurrently.
|
||||
self._registry_lock = RWLock()
|
||||
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
|
||||
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
|
||||
self._registry: Dict[str, LoRARef] = {}
|
||||
# Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
|
||||
self._counters: Dict[str, ConcurrentCounter] = {}
|
||||
|
||||
# Initialize the registry with provided LoRA paths, if present.
|
||||
if lora_paths:
|
||||
for lora_ref in lora_paths.values():
|
||||
self._register_adapter(lora_ref)
|
||||
|
||||
async def register(self, lora_ref: LoRARef):
|
||||
"""
|
||||
@@ -72,11 +88,8 @@ class LoRARegistry:
|
||||
Args:
|
||||
lora_ref (LoRARef): The LoRARef object to register.
|
||||
"""
|
||||
if lora_ref.lora_name in self._registry:
|
||||
raise ValueError(
|
||||
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
self._registry[lora_ref.lora_name] = lora_ref
|
||||
async with self._registry_lock.writer_lock:
|
||||
self._register_adapter(lora_ref)
|
||||
|
||||
async def unregister(self, lora_name: str) -> str:
|
||||
"""
|
||||
@@ -85,12 +98,14 @@ class LoRARegistry:
|
||||
Args:
|
||||
lora_name (str): The name of the LoRA model to unregister.
|
||||
"""
|
||||
lora_ref = self._registry.get(lora_name, None)
|
||||
if lora_ref is None:
|
||||
raise ValueError(
|
||||
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
del self._registry[lora_name]
|
||||
async with self._registry_lock.writer_lock:
|
||||
lora_ref = self._registry.get(lora_name, None)
|
||||
if lora_ref is None:
|
||||
raise ValueError(
|
||||
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
del self._registry[lora_name]
|
||||
del self._counters[lora_ref.lora_id]
|
||||
|
||||
return lora_ref.lora_id
|
||||
|
||||
@@ -98,27 +113,76 @@ class LoRARegistry:
|
||||
"""
|
||||
Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
|
||||
by incrementing its counter.
|
||||
|
||||
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
|
||||
"""
|
||||
|
||||
async def _acquire_single(name: str) -> str:
|
||||
def _lookup(name: str) -> str:
|
||||
lora_ref = self._registry.get(name, None)
|
||||
if lora_ref is None:
|
||||
raise ValueError(
|
||||
f"The following requested LoRA adapters are not loaded: {name}\n"
|
||||
f"Loaded adapters: {self._registry.keys()}."
|
||||
)
|
||||
# await self._counters[lora_ref.lora_id].increment()
|
||||
return lora_ref.lora_id
|
||||
|
||||
if isinstance(lora_name, str):
|
||||
lora_id = await _acquire_single(lora_name)
|
||||
return lora_id
|
||||
elif isinstance(lora_name, list):
|
||||
lora_ids = await asyncio.gather(
|
||||
*[_acquire_single(name) for name in lora_name]
|
||||
async with self._registry_lock.reader_lock:
|
||||
if isinstance(lora_name, str):
|
||||
lora_id = _lookup(lora_name)
|
||||
await self._counters[lora_id].increment(notify_all=False)
|
||||
return lora_id
|
||||
elif isinstance(lora_name, list):
|
||||
lora_ids = [_lookup(name) for name in lora_name]
|
||||
|
||||
# Increment the counters only after all IDs are looked up.
|
||||
await asyncio.gather(
|
||||
*[self._counters[id].increment(notify_all=False) for id in lora_ids]
|
||||
)
|
||||
return lora_ids
|
||||
else:
|
||||
raise TypeError(
|
||||
"lora_name must be either a string or a list of strings."
|
||||
)
|
||||
|
||||
async def release(self, lora_id: Union[str, List[str]]):
|
||||
"""
|
||||
Decrements the usage counter for a LoRA adapter, indicating that it is no longer in use.
|
||||
"""
|
||||
|
||||
async with self._registry_lock.reader_lock:
|
||||
if isinstance(lora_id, str):
|
||||
await self._counters[lora_id].decrement()
|
||||
elif isinstance(lora_id, list):
|
||||
await asyncio.gather(
|
||||
*[self._counters[id].decrement() for id in lora_id]
|
||||
)
|
||||
else:
|
||||
raise TypeError("lora_id must be either a string or a list of strings.")
|
||||
|
||||
async def wait_for_unload(self, lora_id: str):
|
||||
"""
|
||||
Waits until the usage counter for a LoRA adapter reaches zero, indicating that it is no longer in use.
|
||||
This is useful for ensuring that a LoRA adapter can be safely unloaded.
|
||||
|
||||
This method itself is not synchronized, which is safe because it should only be called during LoRA unloading,
|
||||
which itself is guaranteed to be sequential.
|
||||
"""
|
||||
assert (
|
||||
lora_id not in self._registry
|
||||
), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
|
||||
counter = self._counters.get(lora_id)
|
||||
if counter:
|
||||
# Wait until no requests are using this LoRA adapter.
|
||||
await counter.wait_for_zero()
|
||||
del self._counters[lora_id]
|
||||
|
||||
def _register_adapter(self, lora_ref: LoRARef):
|
||||
"""
|
||||
Internal helper method to register a LoRA adapter.
|
||||
"""
|
||||
|
||||
if lora_ref.lora_name in self._registry:
|
||||
raise ValueError(
|
||||
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
return lora_ids
|
||||
else:
|
||||
raise TypeError("lora_name must be either a string or a list of strings.")
|
||||
self._registry[lora_ref.lora_name] = lora_ref
|
||||
self._counters[lora_ref.lora_id] = ConcurrentCounter()
|
||||
return lora_ref
|
||||
|
||||
Reference in New Issue
Block a user