Support overlapped lora updates (#8213)
This commit is contained in:
@@ -282,6 +282,11 @@ class TokenizerManager:
|
||||
None
|
||||
)
|
||||
|
||||
# Lock to serialize LoRA update operations.
|
||||
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
|
||||
# LoRA updates and inference to overlap.
|
||||
self.lora_update_lock = asyncio.Lock()
|
||||
|
||||
# For pd disaggregation
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
@@ -537,7 +542,8 @@ class TokenizerManager:
|
||||
mm_inputs = None
|
||||
|
||||
if self.server_args.enable_lora and obj.lora_path:
|
||||
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
|
||||
# Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
|
||||
# `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
|
||||
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
|
||||
|
||||
self._validate_one_request(obj, input_ids)
|
||||
@@ -747,6 +753,10 @@ class TokenizerManager:
|
||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
||||
logger.info(msg)
|
||||
|
||||
# Mark ongoing LoRA request as finished.
|
||||
if self.server_args.enable_lora and obj.lora_path:
|
||||
await self.lora_registry.release(obj.lora_path)
|
||||
|
||||
# Check if this was an abort/error created by scheduler
|
||||
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
||||
finish_reason = out["meta_info"]["finish_reason"]
|
||||
@@ -1053,16 +1063,18 @@ class TokenizerManager:
|
||||
obj.lora_path,
|
||||
)
|
||||
|
||||
async with self.model_update_lock.writer_lock:
|
||||
async with self.lora_update_lock:
|
||||
# Generate new uniquely identifiable LoRARef object.
|
||||
new_adapter = LoRARef(
|
||||
lora_name=obj.lora_name,
|
||||
lora_path=obj.lora_path,
|
||||
)
|
||||
|
||||
# Register the new adapter in the registry.
|
||||
# Trigger the actual loading operation at the backend processes.
|
||||
obj.lora_id = new_adapter.lora_id
|
||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||
|
||||
# Register the LoRA adapter only after loading is successful.
|
||||
if result.success:
|
||||
await self.lora_registry.register(new_adapter)
|
||||
|
||||
@@ -1093,8 +1105,15 @@ class TokenizerManager:
|
||||
obj.lora_name,
|
||||
)
|
||||
|
||||
async with self.model_update_lock.writer_lock:
|
||||
obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
|
||||
async with self.lora_update_lock:
|
||||
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
|
||||
# from being started.
|
||||
lora_id = await self.lora_registry.unregister(obj.lora_name)
|
||||
obj.lora_id = lora_id
|
||||
|
||||
# Initiate the actual unloading operation at the backend processes only after all
|
||||
# ongoing requests using this LoRA adapter are finished.
|
||||
await self.lora_registry.wait_for_unload(lora_id)
|
||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user