Fix race condition in async lora unload (#9084)

This commit is contained in:
Lifu Huang
2025-08-11 22:59:29 -07:00
committed by GitHub
parent 4093d460ce
commit 5ded39cab2
4 changed files with 35 additions and 19 deletions

View File

@@ -14,7 +14,6 @@
import asyncio import asyncio
from collections import defaultdict
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from uuid import uuid4 from uuid import uuid4
@@ -106,7 +105,6 @@ class LoRARegistry:
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}" f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
) )
del self._registry[lora_name] del self._registry[lora_name]
del self._counters[lora_ref.lora_id]
return lora_ref.lora_id return lora_ref.lora_id
@@ -117,6 +115,9 @@ class LoRARegistry:
""" """
def _lookup(name: str) -> str: def _lookup(name: str) -> str:
if name is None:
return None
lora_ref = self._registry.get(name, None) lora_ref = self._registry.get(name, None)
if lora_ref is None: if lora_ref is None:
raise ValueError( raise ValueError(
@@ -135,7 +136,11 @@ class LoRARegistry:
# Increment the counters only after all IDs are looked up. # Increment the counters only after all IDs are looked up.
await asyncio.gather( await asyncio.gather(
*[self._counters[id].increment(notify_all=False) for id in lora_ids] *[
self._counters[id].increment(notify_all=False)
for id in lora_ids
if id is not None
]
) )
return lora_ids return lora_ids
else: else:
@@ -153,7 +158,11 @@ class LoRARegistry:
await self._counters[lora_id].decrement() await self._counters[lora_id].decrement()
elif isinstance(lora_id, list): elif isinstance(lora_id, list):
await asyncio.gather( await asyncio.gather(
*[self._counters[id].decrement() for id in lora_id] *[
self._counters[id].decrement()
for id in lora_id
if id is not None
]
) )
else: else:
raise TypeError("lora_id must be either a string or a list of strings.") raise TypeError("lora_id must be either a string or a list of strings.")
@@ -169,10 +178,12 @@ class LoRARegistry:
assert ( assert (
lora_id not in self._registry lora_id not in self._registry
), "wait_for_unload should only be called after the LoRA adapter has been unregistered. " ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
counter = self._counters.get(lora_id) assert (
if counter: lora_id in self._counters
), "The LoRA ID should still have a counter if it has been registered before."
# Wait until no requests are using this LoRA adapter. # Wait until no requests are using this LoRA adapter.
await counter.wait_for_zero() await self._counters[lora_id].wait_for_zero()
del self._counters[lora_id] del self._counters[lora_id]
def _register_adapter(self, lora_ref: LoRARef): def _register_adapter(self, lora_ref: LoRARef):

View File

@@ -455,6 +455,7 @@ class GenerateReqInput:
log_metrics=self.log_metrics, log_metrics=self.log_metrics,
modalities=self.modalities[i] if self.modalities else None, modalities=self.modalities[i] if self.modalities else None,
lora_path=self.lora_path[i] if self.lora_path is not None else None, lora_path=self.lora_path[i] if self.lora_path is not None else None,
lora_id=self.lora_id[i] if self.lora_id is not None else None,
custom_logit_processor=( custom_logit_processor=(
self.custom_logit_processor[i] self.custom_logit_processor[i]
if self.custom_logit_processor is not None if self.custom_logit_processor is not None

View File

@@ -485,6 +485,10 @@ class TokenizerManager:
await self.is_pause_cond.wait_for(lambda: not self.is_pause) await self.is_pause_cond.wait_for(lambda: not self.is_pause)
async with self.model_update_lock.reader_lock: async with self.model_update_lock.reader_lock:
if self.server_args.enable_lora and obj.lora_path:
# Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
if obj.is_single: if obj.is_single:
tokenized_obj = await self._tokenize_one_request(obj) tokenized_obj = await self._tokenize_one_request(obj)
state = self._send_one_request(obj, tokenized_obj, created_time) state = self._send_one_request(obj, tokenized_obj, created_time)
@@ -552,11 +556,6 @@ class TokenizerManager:
else: else:
mm_inputs = None mm_inputs = None
if self.server_args.enable_lora and obj.lora_path:
# Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
# `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
self._validate_one_request(obj, input_ids) self._validate_one_request(obj, input_ids)
return self._create_tokenized_object( return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -774,10 +773,6 @@ class TokenizerManager:
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}" msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
logger.info(msg) logger.info(msg)
# Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and obj.lora_path:
await self.lora_registry.release(obj.lora_id)
# Check if this was an abort/error created by scheduler # Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict): if isinstance(out["meta_info"].get("finish_reason"), dict):
finish_reason = out["meta_info"]["finish_reason"] finish_reason = out["meta_info"]["finish_reason"]
@@ -796,6 +791,11 @@ class TokenizerManager:
# Delete the key to prevent resending abort request to the scheduler and # Delete the key to prevent resending abort request to the scheduler and
# to ensure aborted request state is cleaned up. # to ensure aborted request state is cleaned up.
del self.rid_to_state[state.obj.rid] del self.rid_to_state[state.obj.rid]
# Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and state.obj.lora_path:
await self.lora_registry.release(state.obj.lora_id)
raise fastapi.HTTPException( raise fastapi.HTTPException(
status_code=finish_reason["status_code"], status_code=finish_reason["status_code"],
detail=finish_reason["message"], detail=finish_reason["message"],
@@ -1599,6 +1599,10 @@ class TokenizerManager:
meta_info["e2e_latency"] = state.finished_time - state.created_time meta_info["e2e_latency"] = state.finished_time - state.created_time
del self.rid_to_state[rid] del self.rid_to_state[rid]
# Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and state.obj.lora_path:
asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
state.out_list.append(out_dict) state.out_list.append(out_dict)
state.event.set() state.event.set()

View File

@@ -2960,7 +2960,7 @@ class ConcurrentCounter:
This suspends the calling coroutine without blocking the thread, allowing This suspends the calling coroutine without blocking the thread, allowing
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes. other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
""" """
self.wait_for(lambda count: count == 0) await self.wait_for(lambda count: count == 0)
@lru_cache(maxsize=1) @lru_cache(maxsize=1)