Fix race condition in async lora unload (#9084)

This commit is contained in:
Lifu Huang
2025-08-11 22:59:29 -07:00
committed by GitHub
parent 4093d460ce
commit 5ded39cab2
4 changed files with 35 additions and 19 deletions

View File

@@ -485,6 +485,10 @@ class TokenizerManager:
await self.is_pause_cond.wait_for(lambda: not self.is_pause)
async with self.model_update_lock.reader_lock:
if self.server_args.enable_lora and obj.lora_path:
# Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
if obj.is_single:
tokenized_obj = await self._tokenize_one_request(obj)
state = self._send_one_request(obj, tokenized_obj, created_time)
@@ -552,11 +556,6 @@ class TokenizerManager:
else:
mm_inputs = None
if self.server_args.enable_lora and obj.lora_path:
# Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
# `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
self._validate_one_request(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -774,10 +773,6 @@ class TokenizerManager:
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
logger.info(msg)
# Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and obj.lora_path:
await self.lora_registry.release(obj.lora_id)
# Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict):
finish_reason = out["meta_info"]["finish_reason"]
@@ -796,6 +791,11 @@ class TokenizerManager:
# Delete the key to prevent resending abort request to the scheduler and
# to ensure aborted request state is cleaned up.
del self.rid_to_state[state.obj.rid]
# Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and state.obj.lora_path:
await self.lora_registry.release(state.obj.lora_id)
raise fastapi.HTTPException(
status_code=finish_reason["status_code"],
detail=finish_reason["message"],
@@ -1599,6 +1599,10 @@ class TokenizerManager:
meta_info["e2e_latency"] = state.finished_time - state.created_time
del self.rid_to_state[rid]
# Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and state.obj.lora_path:
asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
state.out_list.append(out_dict)
state.event.set()