Support overlapped lora updates (#8213)

This commit is contained in:
Lifu Huang
2025-07-27 13:00:44 -07:00
committed by GitHub
parent 95217a9b4d
commit df90645525
4 changed files with 204 additions and 35 deletions

View File

@@ -282,6 +282,11 @@ class TokenizerManager:
None
)
# Lock to serialize LoRA update operations.
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
# LoRA updates and inference to overlap.
self.lora_update_lock = asyncio.Lock()
# For pd disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
@@ -537,7 +542,8 @@ class TokenizerManager:
mm_inputs = None
if self.server_args.enable_lora and obj.lora_path:
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
# Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
# `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
self._validate_one_request(obj, input_ids)
@@ -747,6 +753,10 @@ class TokenizerManager:
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
logger.info(msg)
# Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and obj.lora_path:
await self.lora_registry.release(obj.lora_path)
# Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict):
finish_reason = out["meta_info"]["finish_reason"]
@@ -1053,16 +1063,18 @@ class TokenizerManager:
obj.lora_path,
)
async with self.model_update_lock.writer_lock:
async with self.lora_update_lock:
# Generate new uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
)
# Register the new adapter in the registry.
# Trigger the actual loading operation at the backend processes.
obj.lora_id = new_adapter.lora_id
result = (await self.update_lora_adapter_communicator(obj))[0]
# Register the LoRA adapter only after loading is successful.
if result.success:
await self.lora_registry.register(new_adapter)
@@ -1093,8 +1105,15 @@ class TokenizerManager:
obj.lora_name,
)
async with self.model_update_lock.writer_lock:
obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
async with self.lora_update_lock:
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
# from being started.
lora_id = await self.lora_registry.unregister(obj.lora_name)
obj.lora_id = lora_id
# Initiate the actual unloading operation at the backend processes only after all
# ongoing requests using this LoRA adapter are finished.
await self.lora_registry.wait_for_unload(lora_id)
result = (await self.update_lora_adapter_communicator(obj))[0]
return result