Support dynamic LoRA loading / unloading in engine/server API (#7446)

This commit is contained in:
Lifu Huang
2025-06-27 21:00:27 -07:00
committed by GitHub
parent cfe2edac38
commit 49538d111b
14 changed files with 949 additions and 31 deletions

View File

@@ -20,7 +20,7 @@ import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
from sglang.srt.multimodal.mm_utils import has_valid_data
@@ -1002,3 +1002,27 @@ class RpcReqInput:
class RpcReqOutput:
success: bool
message: str
@dataclass
class LoadLoRAAdapterReqInput:
    """Request to dynamically load a new LoRA adapter into the running engine."""

    # Name under which the newly loaded adapter will be registered.
    lora_name: str
    # Location to load the adapter weights from (disk path or huggingface id).
    lora_path: str
@dataclass
class UnloadLoRAAdapterReqInput:
    """Request to remove a previously loaded LoRA adapter from the engine."""

    # Registered name of the adapter to unload.
    lora_name: str
@dataclass
class LoRAUpdateResult:
    """Outcome of a dynamic LoRA load or unload operation.

    Both request kinds share this payload via the aliases defined below.
    """

    # Whether the load/unload completed successfully.
    success: bool
    # Failure description; None when the operation succeeded.
    error_message: Optional[str] = None
    # Adapters present after the operation; presumably name -> path — verify
    # against the producer before relying on the value format.
    loaded_adapters: Dict[str, str] = field(default_factory=dict)


# Load and unload responses carry the exact same fields, so both output
# types are aliases of the shared result dataclass.
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult

View File

@@ -82,6 +82,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -99,6 +101,8 @@ from sglang.srt.managers.io_struct import (
SlowDownReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
@@ -519,6 +523,8 @@ class Scheduler(
(SetInternalStateReq, self.set_internal_state),
(RpcReqInput, self.handle_rpc_request),
(ExpertDistributionReq, self.expert_distribution_handle),
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
]
)
@@ -2241,6 +2247,36 @@ class Scheduler(
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)
def load_lora_adapter(
    self, recv_req: LoadLoRAAdapterReqInput
) -> LoadLoRAAdapterReqOutput:
    """In-place loading a new lora adapter from disk or huggingface.

    On success, flushes the cache so no entries computed with the previous
    adapter set can be reused; on failure, logs the worker's error message
    and returns the failed result unchanged.

    Raises:
        RuntimeError: if the cache flush fails after a successful load.
    """
    result = self.tp_worker.load_lora_adapter(recv_req)

    if result.success:
        # Raise explicitly instead of `assert`: asserts are stripped under
        # `python -O`, which would silently leave a stale cache in place.
        if not self.flush_cache():
            raise RuntimeError("Cache flush failed after loading lora adapter.")
    else:
        logger.error(result.error_message)
    return result
def unload_lora_adapter(
    self, recv_req: UnloadLoRAAdapterReqInput
) -> UnloadLoRAAdapterReqOutput:
    """Unload the lora adapter.

    On success, flushes the cache so entries computed with the removed
    adapter cannot be served; on failure, logs the worker's error message
    and returns the failed result unchanged.

    Raises:
        RuntimeError: if the cache flush fails after a successful unload.
    """
    result = self.tp_worker.unload_lora_adapter(recv_req)

    if result.success:
        # Raise explicitly instead of `assert`: asserts are stripped under
        # `python -O`, which would silently leave a stale cache in place.
        if not self.flush_cache():
            raise RuntimeError("Cache flush failed after unloading LoRA weights")
    else:
        logger.error(result.error_message)
    return result
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)

View File

@@ -83,6 +83,9 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -99,6 +102,8 @@ from sglang.srt.managers.io_struct import (
SlowDownReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
@@ -311,6 +316,9 @@ class TokenizerManager:
self.expert_distribution_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_lora_adapter_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self._result_dispatcher = TypeBasedDispatcher(
[
@@ -377,6 +385,10 @@ class TokenizerManager:
ExpertDistributionReqOutput,
self.expert_distribution_communicator.handle_recv,
),
(
LoRAUpdateResult,
self.update_lora_adapter_communicator.handle_recv,
),
(HealthCheckOutput, lambda x: None),
]
)
@@ -960,6 +972,49 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
async def load_lora_adapter(
    self,
    obj: LoadLoRAAdapterReqInput,
    _: Optional[fastapi.Request] = None,
) -> LoadLoRAAdapterReqOutput:
    """Dynamically load a LoRA adapter and return the scheduler's result.

    Raises:
        ValueError: when data parallelism is enabled (dp_size > 1), which is
            not yet supported for dynamic LoRA updates.
    """
    self.auto_create_handle_loop()

    # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
    # with dp_size > 1.
    # Raise instead of `assert` so the guard survives `python -O`.
    if self.server_args.dp_size != 1:
        raise ValueError("dp_size must be 1 for dynamic lora loading")

    logger.info(
        "Start load Lora adapter. Lora name=%s, path=%s",
        obj.lora_name,
        obj.lora_path,
    )

    # NOTE(review): the writer lock presumably serializes weight updates
    # against in-flight requests — confirm against model_update_lock's users.
    async with self.model_update_lock.writer_lock:
        result = (await self.update_lora_adapter_communicator(obj))[0]
        return result
async def unload_lora_adapter(
    self,
    obj: UnloadLoRAAdapterReqInput,
    _: Optional[fastapi.Request] = None,
) -> UnloadLoRAAdapterReqOutput:
    """Dynamically unload a LoRA adapter and return the scheduler's result.

    Raises:
        ValueError: when data parallelism is enabled (dp_size > 1), which is
            not yet supported for dynamic LoRA updates.
    """
    self.auto_create_handle_loop()

    # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
    # with dp_size > 1.
    # Raise instead of `assert` so the guard survives `python -O`.
    if self.server_args.dp_size != 1:
        raise ValueError("dp_size must be 1 for dynamic lora loading")

    logger.info(
        "Start unload Lora adapter. Lora name=%s",
        obj.lora_name,
    )

    # NOTE(review): the writer lock presumably serializes weight updates
    # against in-flight requests — confirm against model_update_lock's users.
    async with self.model_update_lock.writer_lock:
        result = (await self.update_lora_adapter_communicator(obj))[0]
        return result
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):

View File

@@ -30,6 +30,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -275,3 +277,13 @@ class TpModelWorker:
recv_req.name, recv_req.truncate_size
)
return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
    """Delegate a dynamic LoRA load to the model runner and return its result."""
    return self.model_runner.load_lora_adapter(recv_req.lora_name, recv_req.lora_path)
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
    """Delegate a dynamic LoRA unload to the model runner and return its result."""
    return self.model_runner.unload_lora_adapter(recv_req.lora_name)

View File

@@ -26,6 +26,8 @@ import torch
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -268,6 +270,12 @@ class TpModelWorkerClient:
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Forward a weight-introspection request to the wrapped synchronous worker."""
    weights = self.worker.get_weights_by_name(recv_req)
    return weights
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
    """Forward a dynamic LoRA load to the wrapped synchronous worker."""
    result = self.worker.load_lora_adapter(recv_req)
    return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
    """Forward a dynamic LoRA unload to the wrapped synchronous worker."""
    result = self.worker.unload_lora_adapter(recv_req)
    return result
# NOTE(review): `__delete__` is the descriptor-protocol hook (called on
# attribute deletion of a descriptor), not an instance finalizer; this was
# most likely intended to be `__del__` — confirm and rename if so, otherwise
# this shutdown path may never run.
def __delete__(self):
    # The None-filled tuples are presumably sentinels telling the background
    # queue-consumer threads to exit — verify against the consumers.
    self.input_queue.put((None, None))
    self.copy_queue.put((None, None, None))