Support dynamic LoRA loading / unloading in engine/server API (#7446)

This commit is contained in:
Lifu Huang
2025-06-27 21:00:27 -07:00
committed by GitHub
parent cfe2edac38
commit 49538d111b
14 changed files with 949 additions and 31 deletions

View File

@@ -82,6 +82,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -99,6 +101,8 @@ from sglang.srt.managers.io_struct import (
SlowDownReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
@@ -519,6 +523,8 @@ class Scheduler(
(SetInternalStateReq, self.set_internal_state),
(RpcReqInput, self.handle_rpc_request),
(ExpertDistributionReq, self.expert_distribution_handle),
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
]
)
@@ -2241,6 +2247,36 @@ class Scheduler(
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)
def load_lora_adapter(
self, recv_req: LoadLoRAAdapterReqInput
) -> LoadLoRAAdapterReqOutput:
"""In-place loading a new lora adapter from disk or huggingface."""
result = self.tp_worker.load_lora_adapter(recv_req)
if result.success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after loading lora adapter."
else:
logger.error(result.error_message)
return result
def unload_lora_adapter(
self, recv_req: UnloadLoRAAdapterReqInput
) -> UnloadLoRAAdapterReqOutput:
"""Unload the lora adapter."""
result = self.tp_worker.unload_lora_adapter(recv_req)
if result.success:
flush_cache_success = self.flush_cache()
assert (
flush_cache_success
), "Cache flush failed after unloading LoRA weights"
else:
logger.error(result.error_message)
return result
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)