Fix race condition in async lora unload (#9084)
This commit is contained in:
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
@@ -106,7 +105,6 @@ class LoRARegistry:
|
|||||||
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
|
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
|
||||||
)
|
)
|
||||||
del self._registry[lora_name]
|
del self._registry[lora_name]
|
||||||
del self._counters[lora_ref.lora_id]
|
|
||||||
|
|
||||||
return lora_ref.lora_id
|
return lora_ref.lora_id
|
||||||
|
|
||||||
@@ -117,6 +115,9 @@ class LoRARegistry:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _lookup(name: str) -> str:
|
def _lookup(name: str) -> str:
|
||||||
|
if name is None:
|
||||||
|
return None
|
||||||
|
|
||||||
lora_ref = self._registry.get(name, None)
|
lora_ref = self._registry.get(name, None)
|
||||||
if lora_ref is None:
|
if lora_ref is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -135,7 +136,11 @@ class LoRARegistry:
|
|||||||
|
|
||||||
# Increment the counters only after all IDs are looked up.
|
# Increment the counters only after all IDs are looked up.
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[self._counters[id].increment(notify_all=False) for id in lora_ids]
|
*[
|
||||||
|
self._counters[id].increment(notify_all=False)
|
||||||
|
for id in lora_ids
|
||||||
|
if id is not None
|
||||||
|
]
|
||||||
)
|
)
|
||||||
return lora_ids
|
return lora_ids
|
||||||
else:
|
else:
|
||||||
@@ -153,7 +158,11 @@ class LoRARegistry:
|
|||||||
await self._counters[lora_id].decrement()
|
await self._counters[lora_id].decrement()
|
||||||
elif isinstance(lora_id, list):
|
elif isinstance(lora_id, list):
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[self._counters[id].decrement() for id in lora_id]
|
*[
|
||||||
|
self._counters[id].decrement()
|
||||||
|
for id in lora_id
|
||||||
|
if id is not None
|
||||||
|
]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise TypeError("lora_id must be either a string or a list of strings.")
|
raise TypeError("lora_id must be either a string or a list of strings.")
|
||||||
@@ -169,10 +178,12 @@ class LoRARegistry:
|
|||||||
assert (
|
assert (
|
||||||
lora_id not in self._registry
|
lora_id not in self._registry
|
||||||
), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
|
), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
|
||||||
counter = self._counters.get(lora_id)
|
assert (
|
||||||
if counter:
|
lora_id in self._counters
|
||||||
|
), "The LoRA ID should still have a counter if it has been registered before."
|
||||||
|
|
||||||
# Wait until no requests are using this LoRA adapter.
|
# Wait until no requests are using this LoRA adapter.
|
||||||
await counter.wait_for_zero()
|
await self._counters[lora_id].wait_for_zero()
|
||||||
del self._counters[lora_id]
|
del self._counters[lora_id]
|
||||||
|
|
||||||
def _register_adapter(self, lora_ref: LoRARef):
|
def _register_adapter(self, lora_ref: LoRARef):
|
||||||
|
|||||||
@@ -455,6 +455,7 @@ class GenerateReqInput:
|
|||||||
log_metrics=self.log_metrics,
|
log_metrics=self.log_metrics,
|
||||||
modalities=self.modalities[i] if self.modalities else None,
|
modalities=self.modalities[i] if self.modalities else None,
|
||||||
lora_path=self.lora_path[i] if self.lora_path is not None else None,
|
lora_path=self.lora_path[i] if self.lora_path is not None else None,
|
||||||
|
lora_id=self.lora_id[i] if self.lora_id is not None else None,
|
||||||
custom_logit_processor=(
|
custom_logit_processor=(
|
||||||
self.custom_logit_processor[i]
|
self.custom_logit_processor[i]
|
||||||
if self.custom_logit_processor is not None
|
if self.custom_logit_processor is not None
|
||||||
|
|||||||
@@ -485,6 +485,10 @@ class TokenizerManager:
|
|||||||
await self.is_pause_cond.wait_for(lambda: not self.is_pause)
|
await self.is_pause_cond.wait_for(lambda: not self.is_pause)
|
||||||
|
|
||||||
async with self.model_update_lock.reader_lock:
|
async with self.model_update_lock.reader_lock:
|
||||||
|
if self.server_args.enable_lora and obj.lora_path:
|
||||||
|
# Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
|
||||||
|
obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
|
||||||
|
|
||||||
if obj.is_single:
|
if obj.is_single:
|
||||||
tokenized_obj = await self._tokenize_one_request(obj)
|
tokenized_obj = await self._tokenize_one_request(obj)
|
||||||
state = self._send_one_request(obj, tokenized_obj, created_time)
|
state = self._send_one_request(obj, tokenized_obj, created_time)
|
||||||
@@ -552,11 +556,6 @@ class TokenizerManager:
|
|||||||
else:
|
else:
|
||||||
mm_inputs = None
|
mm_inputs = None
|
||||||
|
|
||||||
if self.server_args.enable_lora and obj.lora_path:
|
|
||||||
# Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
|
|
||||||
# `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
|
|
||||||
obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
|
|
||||||
|
|
||||||
self._validate_one_request(obj, input_ids)
|
self._validate_one_request(obj, input_ids)
|
||||||
return self._create_tokenized_object(
|
return self._create_tokenized_object(
|
||||||
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
|
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
|
||||||
@@ -774,10 +773,6 @@ class TokenizerManager:
|
|||||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
|
|
||||||
# Mark ongoing LoRA request as finished.
|
|
||||||
if self.server_args.enable_lora and obj.lora_path:
|
|
||||||
await self.lora_registry.release(obj.lora_id)
|
|
||||||
|
|
||||||
# Check if this was an abort/error created by scheduler
|
# Check if this was an abort/error created by scheduler
|
||||||
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
||||||
finish_reason = out["meta_info"]["finish_reason"]
|
finish_reason = out["meta_info"]["finish_reason"]
|
||||||
@@ -796,6 +791,11 @@ class TokenizerManager:
|
|||||||
# Delete the key to prevent resending abort request to the scheduler and
|
# Delete the key to prevent resending abort request to the scheduler and
|
||||||
# to ensure aborted request state is cleaned up.
|
# to ensure aborted request state is cleaned up.
|
||||||
del self.rid_to_state[state.obj.rid]
|
del self.rid_to_state[state.obj.rid]
|
||||||
|
|
||||||
|
# Mark ongoing LoRA request as finished.
|
||||||
|
if self.server_args.enable_lora and state.obj.lora_path:
|
||||||
|
await self.lora_registry.release(state.obj.lora_id)
|
||||||
|
|
||||||
raise fastapi.HTTPException(
|
raise fastapi.HTTPException(
|
||||||
status_code=finish_reason["status_code"],
|
status_code=finish_reason["status_code"],
|
||||||
detail=finish_reason["message"],
|
detail=finish_reason["message"],
|
||||||
@@ -1599,6 +1599,10 @@ class TokenizerManager:
|
|||||||
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[rid]
|
||||||
|
|
||||||
|
# Mark ongoing LoRA request as finished.
|
||||||
|
if self.server_args.enable_lora and state.obj.lora_path:
|
||||||
|
asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
|
||||||
|
|
||||||
state.out_list.append(out_dict)
|
state.out_list.append(out_dict)
|
||||||
state.event.set()
|
state.event.set()
|
||||||
|
|
||||||
|
|||||||
@@ -2960,7 +2960,7 @@ class ConcurrentCounter:
|
|||||||
This suspends the calling coroutine without blocking the thread, allowing
|
This suspends the calling coroutine without blocking the thread, allowing
|
||||||
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
|
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
|
||||||
"""
|
"""
|
||||||
self.wait_for(lambda count: count == 0)
|
await self.wait_for(lambda count: count == 0)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user