Support overlapped lora updates (#8213)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import ctypes
|
||||
import dataclasses
|
||||
@@ -2862,3 +2863,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
|
||||
]
|
||||
|
||||
LORA_TARGET_ALL_MODULES = "all"
|
||||
|
||||
|
||||
class ConcurrentCounter:
|
||||
"""
|
||||
An asynchronous counter for managing concurrent tasks that need
|
||||
coordinated increments, decrements, and waiting until the count reaches zero.
|
||||
|
||||
This class is useful for scenarios like tracking the number of in-flight tasks
|
||||
and waiting for them to complete.
|
||||
"""
|
||||
|
||||
def __init__(self, initial: int = 0):
|
||||
"""
|
||||
Initialize the counter with an optional initial value.
|
||||
|
||||
Args:
|
||||
initial (int): The initial value of the counter. Default is 0.
|
||||
"""
|
||||
self._count = initial
|
||||
self._condition = asyncio.Condition()
|
||||
|
||||
def value(self) -> int:
|
||||
"""
|
||||
Return the current value of the counter.
|
||||
|
||||
Note:
|
||||
This method is not synchronized. It may return a stale value
|
||||
if other coroutines are concurrently modifying the counter.
|
||||
|
||||
Returns:
|
||||
int: The current counter value.
|
||||
"""
|
||||
return self._count
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return an informative string representation of the counter."""
|
||||
return f"<ConcurrentCounter value={self.value()}>"
|
||||
|
||||
async def increment(self, n: int = 1, notify_all: bool = True):
|
||||
"""
|
||||
Atomically increment the counter by a given amount and notify all waiters.
|
||||
|
||||
Args:
|
||||
n (int): The amount to increment the counter by. Default is 1.
|
||||
notify_all (bool): Whether to notify all waiters after incrementing. Default is True.
|
||||
"""
|
||||
async with self._condition:
|
||||
self._count += n
|
||||
if notify_all:
|
||||
self._condition.notify_all()
|
||||
|
||||
async def decrement(self, n: int = 1, notify_all: bool = True):
|
||||
"""
|
||||
Atomically decrement the counter by a given amount and notify all waiters.
|
||||
|
||||
Args:
|
||||
n (int): The amount to decrement the counter by. Default is 1.
|
||||
notify_all (bool): Whether to notify all waiters after decrementing. Default is True.
|
||||
"""
|
||||
async with self._condition:
|
||||
self._count -= n
|
||||
if notify_all:
|
||||
self._condition.notify_all()
|
||||
|
||||
async def wait_for(self, condition: Callable[[int], bool]):
|
||||
"""
|
||||
Asynchronously wait until the counter satisfies a given condition.
|
||||
|
||||
This suspends the calling coroutine without blocking the thread, allowing
|
||||
other tasks to run while waiting. When the condition is met, the coroutine resumes.
|
||||
|
||||
Args:
|
||||
condition (Callable[[int], bool]): A function that takes the current counter value
|
||||
and returns True when the condition is satisfied.
|
||||
"""
|
||||
async with self._condition:
|
||||
await self._condition.wait_for(lambda: condition(self._count))
|
||||
|
||||
async def wait_for_zero(self):
|
||||
"""
|
||||
Asynchronously wait until the counter reaches zero.
|
||||
|
||||
This suspends the calling coroutine without blocking the thread, allowing
|
||||
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
|
||||
"""
|
||||
self.wait_for(lambda count: count == 0)
|
||||
|
||||
Reference in New Issue
Block a user