Fix race condition in async LoRA unload (#9084)

This commit is contained in:
Lifu Huang
2025-08-11 22:59:29 -07:00
committed by GitHub
parent 4093d460ce
commit 5ded39cab2
4 changed files with 35 additions and 19 deletions

View File

@@ -14,7 +14,6 @@
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4
@@ -106,7 +105,6 @@ class LoRARegistry:
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
)
del self._registry[lora_name]
del self._counters[lora_ref.lora_id]
return lora_ref.lora_id
@@ -117,6 +115,9 @@ class LoRARegistry:
"""
def _lookup(name: str) -> str:
if name is None:
return None
lora_ref = self._registry.get(name, None)
if lora_ref is None:
raise ValueError(
@@ -135,7 +136,11 @@ class LoRARegistry:
# Increment the counters only after all IDs are looked up.
await asyncio.gather(
*[self._counters[id].increment(notify_all=False) for id in lora_ids]
*[
self._counters[id].increment(notify_all=False)
for id in lora_ids
if id is not None
]
)
return lora_ids
else:
@@ -153,7 +158,11 @@ class LoRARegistry:
await self._counters[lora_id].decrement()
elif isinstance(lora_id, list):
await asyncio.gather(
*[self._counters[id].decrement() for id in lora_id]
*[
self._counters[id].decrement()
for id in lora_id
if id is not None
]
)
else:
raise TypeError("lora_id must be either a string or a list of strings.")
@@ -169,11 +178,13 @@ class LoRARegistry:
assert (
lora_id not in self._registry
), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
counter = self._counters.get(lora_id)
if counter:
# Wait until no requests are using this LoRA adapter.
await counter.wait_for_zero()
del self._counters[lora_id]
assert (
lora_id in self._counters
), "The LoRA ID should still have a counter if it has been registered before."
# Wait until no requests are using this LoRA adapter.
await self._counters[lora_id].wait_for_zero()
del self._counters[lora_id]
def _register_adapter(self, lora_ref: LoRARef):
"""