Support overlapped lora updates (#8213)
This commit is contained in:
@@ -14,10 +14,14 @@
|
|||||||
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from sglang.srt.aio_rwlock import RWLock
|
||||||
|
from sglang.srt.utils import ConcurrentCounter
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class LoRARef:
|
class LoRARef:
|
||||||
@@ -48,10 +52,11 @@ class LoRARef:
|
|||||||
|
|
||||||
class LoRARegistry:
|
class LoRARegistry:
|
||||||
"""
|
"""
|
||||||
The central registry to keep track of available LoRA adapters.
|
The central registry to keep track of available LoRA adapters and ongoing LoRA requests.
|
||||||
|
|
||||||
TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
|
The `LoRARegistry` resides in the tokenizer manager process and acts as the single source of truth for all
|
||||||
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
|
available LoRA adapters. It supports concurrent inference and dynamic adapter updates through a two-phase
|
||||||
|
update / eventual consistency model between the tokenizer manager process and the scheduler processes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
||||||
@@ -62,8 +67,19 @@ class LoRARegistry:
|
|||||||
"Please file an issue if you see this error."
|
"Please file an issue if you see this error."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A read-write lock to ensure adapters loading / unloading operations are exclusive.
|
||||||
|
# Please note that the counter increment/decrement operations are not synchronized through this
|
||||||
|
# lock, as they are designed to be non-blocking and can be performed concurrently.
|
||||||
|
self._registry_lock = RWLock()
|
||||||
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
|
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
|
||||||
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
|
self._registry: Dict[str, LoRARef] = {}
|
||||||
|
# Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
|
||||||
|
self._counters: Dict[str, ConcurrentCounter] = {}
|
||||||
|
|
||||||
|
# Initialize the registry with provided LoRA paths, if present.
|
||||||
|
if lora_paths:
|
||||||
|
for lora_ref in lora_paths.values():
|
||||||
|
self._register_adapter(lora_ref)
|
||||||
|
|
||||||
async def register(self, lora_ref: LoRARef):
|
async def register(self, lora_ref: LoRARef):
|
||||||
"""
|
"""
|
||||||
@@ -72,11 +88,8 @@ class LoRARegistry:
|
|||||||
Args:
|
Args:
|
||||||
lora_ref (LoRARef): The LoRARef object to register.
|
lora_ref (LoRARef): The LoRARef object to register.
|
||||||
"""
|
"""
|
||||||
if lora_ref.lora_name in self._registry:
|
async with self._registry_lock.writer_lock:
|
||||||
raise ValueError(
|
self._register_adapter(lora_ref)
|
||||||
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
|
|
||||||
)
|
|
||||||
self._registry[lora_ref.lora_name] = lora_ref
|
|
||||||
|
|
||||||
async def unregister(self, lora_name: str) -> str:
|
async def unregister(self, lora_name: str) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -85,12 +98,14 @@ class LoRARegistry:
|
|||||||
Args:
|
Args:
|
||||||
lora_name (str): The name of the LoRA model to unregister.
|
lora_name (str): The name of the LoRA model to unregister.
|
||||||
"""
|
"""
|
||||||
lora_ref = self._registry.get(lora_name, None)
|
async with self._registry_lock.writer_lock:
|
||||||
if lora_ref is None:
|
lora_ref = self._registry.get(lora_name, None)
|
||||||
raise ValueError(
|
if lora_ref is None:
|
||||||
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
|
raise ValueError(
|
||||||
)
|
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
|
||||||
del self._registry[lora_name]
|
)
|
||||||
|
del self._registry[lora_name]
|
||||||
|
del self._counters[lora_ref.lora_id]
|
||||||
|
|
||||||
return lora_ref.lora_id
|
return lora_ref.lora_id
|
||||||
|
|
||||||
@@ -98,27 +113,76 @@ class LoRARegistry:
|
|||||||
"""
|
"""
|
||||||
Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
|
Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
|
||||||
by incrementing its counter.
|
by incrementing its counter.
|
||||||
|
|
||||||
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _acquire_single(name: str) -> str:
|
def _lookup(name: str) -> str:
|
||||||
lora_ref = self._registry.get(name, None)
|
lora_ref = self._registry.get(name, None)
|
||||||
if lora_ref is None:
|
if lora_ref is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The following requested LoRA adapters are not loaded: {name}\n"
|
f"The following requested LoRA adapters are not loaded: {name}\n"
|
||||||
f"Loaded adapters: {self._registry.keys()}."
|
f"Loaded adapters: {self._registry.keys()}."
|
||||||
)
|
)
|
||||||
# await self._counters[lora_ref.lora_id].increment()
|
|
||||||
return lora_ref.lora_id
|
return lora_ref.lora_id
|
||||||
|
|
||||||
if isinstance(lora_name, str):
|
async with self._registry_lock.reader_lock:
|
||||||
lora_id = await _acquire_single(lora_name)
|
if isinstance(lora_name, str):
|
||||||
return lora_id
|
lora_id = _lookup(lora_name)
|
||||||
elif isinstance(lora_name, list):
|
await self._counters[lora_id].increment(notify_all=False)
|
||||||
lora_ids = await asyncio.gather(
|
return lora_id
|
||||||
*[_acquire_single(name) for name in lora_name]
|
elif isinstance(lora_name, list):
|
||||||
|
lora_ids = [_lookup(name) for name in lora_name]
|
||||||
|
|
||||||
|
# Increment the counters only after all IDs are looked up.
|
||||||
|
await asyncio.gather(
|
||||||
|
*[self._counters[id].increment(notify_all=False) for id in lora_ids]
|
||||||
|
)
|
||||||
|
return lora_ids
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"lora_name must be either a string or a list of strings."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def release(self, lora_id: Union[str, List[str]]):
|
||||||
|
"""
|
||||||
|
Decrements the usage counter for a LoRA adapter, indicating that it is no longer in use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with self._registry_lock.reader_lock:
|
||||||
|
if isinstance(lora_id, str):
|
||||||
|
await self._counters[lora_id].decrement()
|
||||||
|
elif isinstance(lora_id, list):
|
||||||
|
await asyncio.gather(
|
||||||
|
*[self._counters[id].decrement() for id in lora_id]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise TypeError("lora_id must be either a string or a list of strings.")
|
||||||
|
|
||||||
|
async def wait_for_unload(self, lora_id: str):
|
||||||
|
"""
|
||||||
|
Waits until the usage counter for a LoRA adapter reaches zero, indicating that it is no longer in use.
|
||||||
|
This is useful for ensuring that a LoRA adapter can be safely unloaded.
|
||||||
|
|
||||||
|
This method itself is not synchronized, which is safe because it should only be called during LoRA unloading,
|
||||||
|
which itself is guaranteed to be sequential.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
lora_id not in self._registry
|
||||||
|
), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
|
||||||
|
counter = self._counters.get(lora_id)
|
||||||
|
if counter:
|
||||||
|
# Wait until no requests are using this LoRA adapter.
|
||||||
|
await counter.wait_for_zero()
|
||||||
|
del self._counters[lora_id]
|
||||||
|
|
||||||
|
def _register_adapter(self, lora_ref: LoRARef):
|
||||||
|
"""
|
||||||
|
Internal helper method to register a LoRA adapter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if lora_ref.lora_name in self._registry:
|
||||||
|
raise ValueError(
|
||||||
|
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
|
||||||
)
|
)
|
||||||
return lora_ids
|
self._registry[lora_ref.lora_name] = lora_ref
|
||||||
else:
|
self._counters[lora_ref.lora_id] = ConcurrentCounter()
|
||||||
raise TypeError("lora_name must be either a string or a list of strings.")
|
return lora_ref
|
||||||
|
|||||||
@@ -282,6 +282,11 @@ class TokenizerManager:
|
|||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Lock to serialize LoRA update operations.
|
||||||
|
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
|
||||||
|
# LoRA updates and inference to overlap.
|
||||||
|
self.lora_update_lock = asyncio.Lock()
|
||||||
|
|
||||||
# For pd disaggregtion
|
# For pd disaggregtion
|
||||||
self.disaggregation_mode = DisaggregationMode(
|
self.disaggregation_mode = DisaggregationMode(
|
||||||
self.server_args.disaggregation_mode
|
self.server_args.disaggregation_mode
|
||||||
@@ -537,7 +542,8 @@ class TokenizerManager:
|
|||||||
mm_inputs = None
|
mm_inputs = None
|
||||||
|
|
||||||
if self.server_args.enable_lora and obj.lora_path:
|
if self.server_args.enable_lora and obj.lora_path:
|
||||||
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
|
# Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
|
||||||
|
# `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
|
||||||
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
|
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
|
||||||
|
|
||||||
self._validate_one_request(obj, input_ids)
|
self._validate_one_request(obj, input_ids)
|
||||||
@@ -747,6 +753,10 @@ class TokenizerManager:
|
|||||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
|
|
||||||
|
# Mark ongoing LoRA request as finished.
|
||||||
|
if self.server_args.enable_lora and obj.lora_path:
|
||||||
|
await self.lora_registry.release(obj.lora_path)
|
||||||
|
|
||||||
# Check if this was an abort/error created by scheduler
|
# Check if this was an abort/error created by scheduler
|
||||||
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
||||||
finish_reason = out["meta_info"]["finish_reason"]
|
finish_reason = out["meta_info"]["finish_reason"]
|
||||||
@@ -1053,16 +1063,18 @@ class TokenizerManager:
|
|||||||
obj.lora_path,
|
obj.lora_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self.model_update_lock.writer_lock:
|
async with self.lora_update_lock:
|
||||||
# Generate new uniquely identifiable LoRARef object.
|
# Generate new uniquely identifiable LoRARef object.
|
||||||
new_adapter = LoRARef(
|
new_adapter = LoRARef(
|
||||||
lora_name=obj.lora_name,
|
lora_name=obj.lora_name,
|
||||||
lora_path=obj.lora_path,
|
lora_path=obj.lora_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register the new adapter in the registry.
|
# Trigger the actual loading operation at the backend processes.
|
||||||
obj.lora_id = new_adapter.lora_id
|
obj.lora_id = new_adapter.lora_id
|
||||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||||
|
|
||||||
|
# Register the LoRA adapter only after loading is successful.
|
||||||
if result.success:
|
if result.success:
|
||||||
await self.lora_registry.register(new_adapter)
|
await self.lora_registry.register(new_adapter)
|
||||||
|
|
||||||
@@ -1093,8 +1105,15 @@ class TokenizerManager:
|
|||||||
obj.lora_name,
|
obj.lora_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self.model_update_lock.writer_lock:
|
async with self.lora_update_lock:
|
||||||
obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
|
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
|
||||||
|
# from being started.
|
||||||
|
lora_id = await self.lora_registry.unregister(obj.lora_name)
|
||||||
|
obj.lora_id = lora_id
|
||||||
|
|
||||||
|
# Initiate the actual unloading operation at the backend processes only after all
|
||||||
|
# ongoing requests using this LoRA adapter are finished.
|
||||||
|
await self.lora_registry.wait_for_unload(lora_id)
|
||||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import builtins
|
import builtins
|
||||||
import ctypes
|
import ctypes
|
||||||
import dataclasses
|
import dataclasses
|
||||||
@@ -2862,3 +2863,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
LORA_TARGET_ALL_MODULES = "all"
|
LORA_TARGET_ALL_MODULES = "all"
|
||||||
|
|
||||||
|
|
||||||
|
class ConcurrentCounter:
|
||||||
|
"""
|
||||||
|
An asynchronous counter for managing concurrent tasks that need
|
||||||
|
coordinated increments, decrements, and waiting until the count reaches zero.
|
||||||
|
|
||||||
|
This class is useful for scenarios like tracking the number of in-flight tasks
|
||||||
|
and waiting for them to complete.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, initial: int = 0):
|
||||||
|
"""
|
||||||
|
Initialize the counter with an optional initial value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial (int): The initial value of the counter. Default is 0.
|
||||||
|
"""
|
||||||
|
self._count = initial
|
||||||
|
self._condition = asyncio.Condition()
|
||||||
|
|
||||||
|
def value(self) -> int:
|
||||||
|
"""
|
||||||
|
Return the current value of the counter.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method is not synchronized. It may return a stale value
|
||||||
|
if other coroutines are concurrently modifying the counter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The current counter value.
|
||||||
|
"""
|
||||||
|
return self._count
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Return an informative string representation of the counter."""
|
||||||
|
return f"<ConcurrentCounter value={self.value()}>"
|
||||||
|
|
||||||
|
async def increment(self, n: int = 1, notify_all: bool = True):
|
||||||
|
"""
|
||||||
|
Atomically increment the counter by a given amount and notify all waiters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The amount to increment the counter by. Default is 1.
|
||||||
|
notify_all (bool): Whether to notify all waiters after incrementing. Default is True.
|
||||||
|
"""
|
||||||
|
async with self._condition:
|
||||||
|
self._count += n
|
||||||
|
if notify_all:
|
||||||
|
self._condition.notify_all()
|
||||||
|
|
||||||
|
async def decrement(self, n: int = 1, notify_all: bool = True):
|
||||||
|
"""
|
||||||
|
Atomically decrement the counter by a given amount and notify all waiters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The amount to decrement the counter by. Default is 1.
|
||||||
|
notify_all (bool): Whether to notify all waiters after decrementing. Default is True.
|
||||||
|
"""
|
||||||
|
async with self._condition:
|
||||||
|
self._count -= n
|
||||||
|
if notify_all:
|
||||||
|
self._condition.notify_all()
|
||||||
|
|
||||||
|
async def wait_for(self, condition: Callable[[int], bool]):
|
||||||
|
"""
|
||||||
|
Asynchronously wait until the counter satisfies a given condition.
|
||||||
|
|
||||||
|
This suspends the calling coroutine without blocking the thread, allowing
|
||||||
|
other tasks to run while waiting. When the condition is met, the coroutine resumes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
condition (Callable[[int], bool]): A function that takes the current counter value
|
||||||
|
and returns True when the condition is satisfied.
|
||||||
|
"""
|
||||||
|
async with self._condition:
|
||||||
|
await self._condition.wait_for(lambda: condition(self._count))
|
||||||
|
|
||||||
|
async def wait_for_zero(self):
|
||||||
|
"""
|
||||||
|
Asynchronously wait until the counter reaches zero.
|
||||||
|
|
||||||
|
This suspends the calling coroutine without blocking the thread, allowing
|
||||||
|
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
|
||||||
|
"""
|
||||||
|
self.wait_for(lambda count: count == 0)
|
||||||
|
|||||||
@@ -231,8 +231,7 @@ class TestBenchServing(CustomTestCase):
|
|||||||
f"median_ttft_ms: {res['median_ttft_ms']:.2f} ms\n"
|
f"median_ttft_ms: {res['median_ttft_ms']:.2f} ms\n"
|
||||||
)
|
)
|
||||||
self.assertLess(res["median_e2e_latency_ms"], 4000)
|
self.assertLess(res["median_e2e_latency_ms"], 4000)
|
||||||
# TODO (lifuhuang): This will be fixed by the overlapped LoRA update in a separate PR.
|
self.assertLess(res["median_ttft_ms"], 80)
|
||||||
self.assertLess(res["median_ttft_ms"], 1600)
|
|
||||||
|
|
||||||
def _run_lora_latency_test(self, enable_background_task: bool):
|
def _run_lora_latency_test(self, enable_background_task: bool):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user