Support overlapped lora updates (#8213)

This commit is contained in:
Lifu Huang
2025-07-27 13:00:44 -07:00
committed by GitHub
parent 95217a9b4d
commit df90645525
4 changed files with 204 additions and 35 deletions

View File

@@ -15,6 +15,7 @@
from __future__ import annotations
import asyncio
import builtins
import ctypes
import dataclasses
@@ -2862,3 +2863,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
]
LORA_TARGET_ALL_MODULES = "all"
class ConcurrentCounter:
"""
An asynchronous counter for managing concurrent tasks that need
coordinated increments, decrements, and waiting until the count reaches zero.
This class is useful for scenarios like tracking the number of in-flight tasks
and waiting for them to complete.
"""
def __init__(self, initial: int = 0):
"""
Initialize the counter with an optional initial value.
Args:
initial (int): The initial value of the counter. Default is 0.
"""
self._count = initial
self._condition = asyncio.Condition()
def value(self) -> int:
"""
Return the current value of the counter.
Note:
This method is not synchronized. It may return a stale value
if other coroutines are concurrently modifying the counter.
Returns:
int: The current counter value.
"""
return self._count
def __repr__(self) -> str:
"""Return an informative string representation of the counter."""
return f"<ConcurrentCounter value={self.value()}>"
async def increment(self, n: int = 1, notify_all: bool = True):
"""
Atomically increment the counter by a given amount and notify all waiters.
Args:
n (int): The amount to increment the counter by. Default is 1.
notify_all (bool): Whether to notify all waiters after incrementing. Default is True.
"""
async with self._condition:
self._count += n
if notify_all:
self._condition.notify_all()
async def decrement(self, n: int = 1, notify_all: bool = True):
"""
Atomically decrement the counter by a given amount and notify all waiters.
Args:
n (int): The amount to decrement the counter by. Default is 1.
notify_all (bool): Whether to notify all waiters after decrementing. Default is True.
"""
async with self._condition:
self._count -= n
if notify_all:
self._condition.notify_all()
async def wait_for(self, condition: Callable[[int], bool]):
"""
Asynchronously wait until the counter satisfies a given condition.
This suspends the calling coroutine without blocking the thread, allowing
other tasks to run while waiting. When the condition is met, the coroutine resumes.
Args:
condition (Callable[[int], bool]): A function that takes the current counter value
and returns True when the condition is satisfied.
"""
async with self._condition:
await self._condition.wait_for(lambda: condition(self._count))
async def wait_for_zero(self):
"""
Asynchronously wait until the counter reaches zero.
This suspends the calling coroutine without blocking the thread, allowing
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
"""
self.wait_for(lambda count: count == 0)