Support overlapped lora updates (#8213)
This commit is contained in:
@@ -14,10 +14,14 @@
|
||||
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from sglang.srt.aio_rwlock import RWLock
|
||||
from sglang.srt.utils import ConcurrentCounter
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LoRARef:
|
||||
@@ -48,10 +52,11 @@ class LoRARef:
|
||||
|
||||
class LoRARegistry:
|
||||
"""
|
||||
The central registry to keep track of available LoRA adapters.
|
||||
The central registry to keep track of available LoRA adapters and ongoing LoRA requests.
|
||||
|
||||
TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
|
||||
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
|
||||
The `LoRARegistry` resides in the tokenizer manager process and acts as the single source of truth for all
|
||||
available LoRA adapters. It supports concurrent inference and dynamic adapter updates through a two-phase
|
||||
update / eventual consistency model between the tokenizer manager process and the scheduler processes.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
||||
@@ -62,8 +67,19 @@ class LoRARegistry:
|
||||
"Please file an issue if you see this error."
|
||||
)
|
||||
|
||||
# A read-write lock to ensure adapters loading / unloading operations are exclusive.
|
||||
# Please note that the counter increment/decrement operations are not synchronized through this
|
||||
# lock, as they are designed to be non-blocking and can be performed concurrently.
|
||||
self._registry_lock = RWLock()
|
||||
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
|
||||
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
|
||||
self._registry: Dict[str, LoRARef] = {}
|
||||
# Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
|
||||
self._counters: Dict[str, ConcurrentCounter] = {}
|
||||
|
||||
# Initialize the registry with provided LoRA paths, if present.
|
||||
if lora_paths:
|
||||
for lora_ref in lora_paths.values():
|
||||
self._register_adapter(lora_ref)
|
||||
|
||||
async def register(self, lora_ref: LoRARef):
|
||||
"""
|
||||
@@ -72,11 +88,8 @@ class LoRARegistry:
|
||||
Args:
|
||||
lora_ref (LoRARef): The LoRARef object to register.
|
||||
"""
|
||||
if lora_ref.lora_name in self._registry:
|
||||
raise ValueError(
|
||||
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
self._registry[lora_ref.lora_name] = lora_ref
|
||||
async with self._registry_lock.writer_lock:
|
||||
self._register_adapter(lora_ref)
|
||||
|
||||
async def unregister(self, lora_name: str) -> str:
|
||||
"""
|
||||
@@ -85,12 +98,14 @@ class LoRARegistry:
|
||||
Args:
|
||||
lora_name (str): The name of the LoRA model to unregister.
|
||||
"""
|
||||
lora_ref = self._registry.get(lora_name, None)
|
||||
if lora_ref is None:
|
||||
raise ValueError(
|
||||
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
del self._registry[lora_name]
|
||||
async with self._registry_lock.writer_lock:
|
||||
lora_ref = self._registry.get(lora_name, None)
|
||||
if lora_ref is None:
|
||||
raise ValueError(
|
||||
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
del self._registry[lora_name]
|
||||
del self._counters[lora_ref.lora_id]
|
||||
|
||||
return lora_ref.lora_id
|
||||
|
||||
@@ -98,27 +113,76 @@ class LoRARegistry:
|
||||
"""
|
||||
Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
|
||||
by incrementing its counter.
|
||||
|
||||
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
|
||||
"""
|
||||
|
||||
async def _acquire_single(name: str) -> str:
|
||||
def _lookup(name: str) -> str:
|
||||
lora_ref = self._registry.get(name, None)
|
||||
if lora_ref is None:
|
||||
raise ValueError(
|
||||
f"The following requested LoRA adapters are not loaded: {name}\n"
|
||||
f"Loaded adapters: {self._registry.keys()}."
|
||||
)
|
||||
# await self._counters[lora_ref.lora_id].increment()
|
||||
return lora_ref.lora_id
|
||||
|
||||
if isinstance(lora_name, str):
|
||||
lora_id = await _acquire_single(lora_name)
|
||||
return lora_id
|
||||
elif isinstance(lora_name, list):
|
||||
lora_ids = await asyncio.gather(
|
||||
*[_acquire_single(name) for name in lora_name]
|
||||
async with self._registry_lock.reader_lock:
|
||||
if isinstance(lora_name, str):
|
||||
lora_id = _lookup(lora_name)
|
||||
await self._counters[lora_id].increment(notify_all=False)
|
||||
return lora_id
|
||||
elif isinstance(lora_name, list):
|
||||
lora_ids = [_lookup(name) for name in lora_name]
|
||||
|
||||
# Increment the counters only after all IDs are looked up.
|
||||
await asyncio.gather(
|
||||
*[self._counters[id].increment(notify_all=False) for id in lora_ids]
|
||||
)
|
||||
return lora_ids
|
||||
else:
|
||||
raise TypeError(
|
||||
"lora_name must be either a string or a list of strings."
|
||||
)
|
||||
|
||||
async def release(self, lora_id: Union[str, List[str]]):
|
||||
"""
|
||||
Decrements the usage counter for a LoRA adapter, indicating that it is no longer in use.
|
||||
"""
|
||||
|
||||
async with self._registry_lock.reader_lock:
|
||||
if isinstance(lora_id, str):
|
||||
await self._counters[lora_id].decrement()
|
||||
elif isinstance(lora_id, list):
|
||||
await asyncio.gather(
|
||||
*[self._counters[id].decrement() for id in lora_id]
|
||||
)
|
||||
else:
|
||||
raise TypeError("lora_id must be either a string or a list of strings.")
|
||||
|
||||
async def wait_for_unload(self, lora_id: str):
|
||||
"""
|
||||
Waits until the usage counter for a LoRA adapter reaches zero, indicating that it is no longer in use.
|
||||
This is useful for ensuring that a LoRA adapter can be safely unloaded.
|
||||
|
||||
This method itself is not synchronized, which is safe because it should only be called during LoRA unloading,
|
||||
which itself is guaranteed to be sequential.
|
||||
"""
|
||||
assert (
|
||||
lora_id not in self._registry
|
||||
), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
|
||||
counter = self._counters.get(lora_id)
|
||||
if counter:
|
||||
# Wait until no requests are using this LoRA adapter.
|
||||
await counter.wait_for_zero()
|
||||
del self._counters[lora_id]
|
||||
|
||||
def _register_adapter(self, lora_ref: LoRARef):
|
||||
"""
|
||||
Internal helper method to register a LoRA adapter.
|
||||
"""
|
||||
|
||||
if lora_ref.lora_name in self._registry:
|
||||
raise ValueError(
|
||||
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
return lora_ids
|
||||
else:
|
||||
raise TypeError("lora_name must be either a string or a list of strings.")
|
||||
self._registry[lora_ref.lora_name] = lora_ref
|
||||
self._counters[lora_ref.lora_id] = ConcurrentCounter()
|
||||
return lora_ref
|
||||
|
||||
@@ -282,6 +282,11 @@ class TokenizerManager:
|
||||
None
|
||||
)
|
||||
|
||||
# Lock to serialize LoRA update operations.
|
||||
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
|
||||
# LoRA updates and inference to overlap.
|
||||
self.lora_update_lock = asyncio.Lock()
|
||||
|
||||
# For pd disaggregtion
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
@@ -537,7 +542,8 @@ class TokenizerManager:
|
||||
mm_inputs = None
|
||||
|
||||
if self.server_args.enable_lora and obj.lora_path:
|
||||
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
|
||||
# Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
|
||||
# `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
|
||||
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
|
||||
|
||||
self._validate_one_request(obj, input_ids)
|
||||
@@ -747,6 +753,10 @@ class TokenizerManager:
|
||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
||||
logger.info(msg)
|
||||
|
||||
# Mark ongoing LoRA request as finished.
|
||||
if self.server_args.enable_lora and obj.lora_path:
|
||||
await self.lora_registry.release(obj.lora_path)
|
||||
|
||||
# Check if this was an abort/error created by scheduler
|
||||
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
||||
finish_reason = out["meta_info"]["finish_reason"]
|
||||
@@ -1053,16 +1063,18 @@ class TokenizerManager:
|
||||
obj.lora_path,
|
||||
)
|
||||
|
||||
async with self.model_update_lock.writer_lock:
|
||||
async with self.lora_update_lock:
|
||||
# Generate new uniquely identifiable LoRARef object.
|
||||
new_adapter = LoRARef(
|
||||
lora_name=obj.lora_name,
|
||||
lora_path=obj.lora_path,
|
||||
)
|
||||
|
||||
# Register the new adapter in the registry.
|
||||
# Trigger the actual loading operation at the backend processes.
|
||||
obj.lora_id = new_adapter.lora_id
|
||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||
|
||||
# Register the LoRA adapter only after loading is successful.
|
||||
if result.success:
|
||||
await self.lora_registry.register(new_adapter)
|
||||
|
||||
@@ -1093,8 +1105,15 @@ class TokenizerManager:
|
||||
obj.lora_name,
|
||||
)
|
||||
|
||||
async with self.model_update_lock.writer_lock:
|
||||
obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
|
||||
async with self.lora_update_lock:
|
||||
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
|
||||
# from being started.
|
||||
lora_id = await self.lora_registry.unregister(obj.lora_name)
|
||||
obj.lora_id = lora_id
|
||||
|
||||
# Initiate the actual unloading operation at the backend processes only after all
|
||||
# ongoing requests using this LoRA adapter are finished.
|
||||
await self.lora_registry.wait_for_unload(lora_id)
|
||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||
|
||||
return result
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import ctypes
|
||||
import dataclasses
|
||||
@@ -2862,3 +2863,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
|
||||
]
|
||||
|
||||
LORA_TARGET_ALL_MODULES = "all"
|
||||
|
||||
|
||||
class ConcurrentCounter:
|
||||
"""
|
||||
An asynchronous counter for managing concurrent tasks that need
|
||||
coordinated increments, decrements, and waiting until the count reaches zero.
|
||||
|
||||
This class is useful for scenarios like tracking the number of in-flight tasks
|
||||
and waiting for them to complete.
|
||||
"""
|
||||
|
||||
def __init__(self, initial: int = 0):
|
||||
"""
|
||||
Initialize the counter with an optional initial value.
|
||||
|
||||
Args:
|
||||
initial (int): The initial value of the counter. Default is 0.
|
||||
"""
|
||||
self._count = initial
|
||||
self._condition = asyncio.Condition()
|
||||
|
||||
def value(self) -> int:
|
||||
"""
|
||||
Return the current value of the counter.
|
||||
|
||||
Note:
|
||||
This method is not synchronized. It may return a stale value
|
||||
if other coroutines are concurrently modifying the counter.
|
||||
|
||||
Returns:
|
||||
int: The current counter value.
|
||||
"""
|
||||
return self._count
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return an informative string representation of the counter."""
|
||||
return f"<ConcurrentCounter value={self.value()}>"
|
||||
|
||||
async def increment(self, n: int = 1, notify_all: bool = True):
|
||||
"""
|
||||
Atomically increment the counter by a given amount and notify all waiters.
|
||||
|
||||
Args:
|
||||
n (int): The amount to increment the counter by. Default is 1.
|
||||
notify_all (bool): Whether to notify all waiters after incrementing. Default is True.
|
||||
"""
|
||||
async with self._condition:
|
||||
self._count += n
|
||||
if notify_all:
|
||||
self._condition.notify_all()
|
||||
|
||||
async def decrement(self, n: int = 1, notify_all: bool = True):
|
||||
"""
|
||||
Atomically decrement the counter by a given amount and notify all waiters.
|
||||
|
||||
Args:
|
||||
n (int): The amount to decrement the counter by. Default is 1.
|
||||
notify_all (bool): Whether to notify all waiters after decrementing. Default is True.
|
||||
"""
|
||||
async with self._condition:
|
||||
self._count -= n
|
||||
if notify_all:
|
||||
self._condition.notify_all()
|
||||
|
||||
async def wait_for(self, condition: Callable[[int], bool]):
|
||||
"""
|
||||
Asynchronously wait until the counter satisfies a given condition.
|
||||
|
||||
This suspends the calling coroutine without blocking the thread, allowing
|
||||
other tasks to run while waiting. When the condition is met, the coroutine resumes.
|
||||
|
||||
Args:
|
||||
condition (Callable[[int], bool]): A function that takes the current counter value
|
||||
and returns True when the condition is satisfied.
|
||||
"""
|
||||
async with self._condition:
|
||||
await self._condition.wait_for(lambda: condition(self._count))
|
||||
|
||||
async def wait_for_zero(self):
|
||||
"""
|
||||
Asynchronously wait until the counter reaches zero.
|
||||
|
||||
This suspends the calling coroutine without blocking the thread, allowing
|
||||
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
|
||||
"""
|
||||
self.wait_for(lambda count: count == 0)
|
||||
|
||||
Reference in New Issue
Block a user