Support dynamic LoRA loading / unloading in engine/server API (#7446)
@@ -82,6 +82,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -99,6 +101,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
@@ -519,6 +523,8 @@ class Scheduler(
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
                 (ExpertDistributionReq, self.expert_distribution_handle),
+                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
+                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
             ]
         )

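The two new (request type, handler) pairs extend the Scheduler's request dispatcher, so LoRA load/unload messages arriving on the scheduler's control channel are routed to the new methods purely by their type. A minimal sketch of this kind of type-based dispatch is shown below, assuming each handler takes the request object and returns a result; it is illustrative only and not sglang's actual TypeBasedDispatcher implementation.

# Minimal sketch of type-based request dispatch; illustrative only,
# not sglang's TypeBasedDispatcher.
from typing import Any, Callable, List, Tuple, Type


class SimpleTypeDispatcher:
    def __init__(self, mapping: List[Tuple[Type, Callable[[Any], Any]]]):
        # Handlers are looked up by the exact class of the incoming request.
        self._mapping = dict(mapping)

    def __call__(self, req: Any) -> Any:
        handler = self._mapping.get(type(req))
        if handler is None:
            raise ValueError(f"Unsupported request type: {type(req).__name__}")
        return handler(req)

With a dispatcher like this, registering (LoadLoRAAdapterReqInput, self.load_lora_adapter) is all the scheduler loop needs to start serving the new request type.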
@@ -2241,6 +2247,36 @@ class Scheduler(
             logger.error(message)
         return UpdateWeightFromDiskReqOutput(success, message, 0)

+    def load_lora_adapter(
+        self, recv_req: LoadLoRAAdapterReqInput
+    ) -> LoadLoRAAdapterReqOutput:
+        """In-place loading a new lora adapter from disk or huggingface."""
+
+        result = self.tp_worker.load_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after loading lora adapter."
+        else:
+            logger.error(result.error_message)
+        return result
+
+    def unload_lora_adapter(
+        self, recv_req: UnloadLoRAAdapterReqInput
+    ) -> UnloadLoRAAdapterReqOutput:
+        """Unload the lora adapter."""
+
+        result = self.tp_worker.unload_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert (
+                flush_cache_success
+            ), "Cache flush failed after unloading LoRA weights"
+        else:
+            logger.error(result.error_message)
+        return result
+
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
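Both handlers mirror update_weights_from_disk: delegate the actual work to the TP worker, then flush the scheduler cache on success so KV entries computed against the previous adapter set cannot be reused. Below is a hedged, hypothetical sketch of how a caller might drive these handlers; the constructor fields used (lora_name, lora_path) are assumptions for illustration and should be checked against the real definitions in sglang.srt.managers.io_struct.

# Hypothetical caller-side sketch. Only the .success / .error_message
# result attributes are taken from the diff above; the request constructor
# fields are assumed for illustration.
from sglang.srt.managers.io_struct import (
    LoadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqInput,
)


def swap_adapter(scheduler, old_name: str, new_name: str, new_path: str) -> None:
    # Retire the old adapter first, then load its replacement in place.
    unload_result = scheduler.unload_lora_adapter(
        UnloadLoRAAdapterReqInput(lora_name=old_name)  # assumed field
    )
    if not unload_result.success:
        raise RuntimeError(unload_result.error_message)

    load_result = scheduler.load_lora_adapter(
        LoadLoRAAdapterReqInput(lora_name=new_name, lora_path=new_path)  # assumed fields
    )
    if not load_result.success:
        raise RuntimeError(load_result.error_message)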