Fix race condition in async lora unload (#9084)
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
@@ -106,7 +105,6 @@ class LoRARegistry:
|
||||
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
del self._registry[lora_name]
|
||||
del self._counters[lora_ref.lora_id]
|
||||
|
||||
return lora_ref.lora_id
|
||||
|
||||
@@ -117,6 +115,9 @@ class LoRARegistry:
|
||||
"""
|
||||
|
||||
def _lookup(name: str) -> str:
|
||||
if name is None:
|
||||
return None
|
||||
|
||||
lora_ref = self._registry.get(name, None)
|
||||
if lora_ref is None:
|
||||
raise ValueError(
|
||||
@@ -135,7 +136,11 @@ class LoRARegistry:
|
||||
|
||||
# Increment the counters only after all IDs are looked up.
|
||||
await asyncio.gather(
|
||||
*[self._counters[id].increment(notify_all=False) for id in lora_ids]
|
||||
*[
|
||||
self._counters[id].increment(notify_all=False)
|
||||
for id in lora_ids
|
||||
if id is not None
|
||||
]
|
||||
)
|
||||
return lora_ids
|
||||
else:
|
||||
@@ -153,7 +158,11 @@ class LoRARegistry:
|
||||
await self._counters[lora_id].decrement()
|
||||
elif isinstance(lora_id, list):
|
||||
await asyncio.gather(
|
||||
*[self._counters[id].decrement() for id in lora_id]
|
||||
*[
|
||||
self._counters[id].decrement()
|
||||
for id in lora_id
|
||||
if id is not None
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise TypeError("lora_id must be either a string or a list of strings.")
|
||||
@@ -169,11 +178,13 @@ class LoRARegistry:
|
||||
assert (
|
||||
lora_id not in self._registry
|
||||
), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
|
||||
counter = self._counters.get(lora_id)
|
||||
if counter:
|
||||
# Wait until no requests are using this LoRA adapter.
|
||||
await counter.wait_for_zero()
|
||||
del self._counters[lora_id]
|
||||
assert (
|
||||
lora_id in self._counters
|
||||
), "The LoRA ID should still have a counter if it has been registered before."
|
||||
|
||||
# Wait until no requests are using this LoRA adapter.
|
||||
await self._counters[lora_id].wait_for_zero()
|
||||
del self._counters[lora_id]
|
||||
|
||||
def _register_adapter(self, lora_ref: LoRARef):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user