Support dynamic LoRA loading / unloading in engine/server API (#7446)
This commit is contained in:
@@ -20,7 +20,7 @@ import copy
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from sglang.srt.multimodal.mm_utils import has_valid_data
|
||||
|
||||
@@ -1002,3 +1002,27 @@ class RpcReqInput:
|
||||
class RpcReqOutput:
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
class LoadLoRAAdapterReqInput:
    """Request to dynamically load a new LoRA adapter into the running engine."""

    # Name under which the newly loaded LoRA module will be registered.
    lora_name: str
    # Where to load the adapter from (local path or huggingface, per the
    # scheduler's load_lora_adapter contract).
    lora_path: str
|
||||
|
||||
|
||||
@dataclass
class UnloadLoRAAdapterReqInput:
    """Request to unload a previously loaded LoRA adapter."""

    # Registered name of the LoRA module to unload.
    lora_name: str
|
||||
|
||||
|
||||
@dataclass
class LoRAUpdateResult:
    """Outcome of a LoRA adapter load/unload operation."""

    # Whether the requested update was applied.
    success: bool
    # Failure reason when success is False; None otherwise.
    error_message: Optional[str] = None
    # Presumably maps adapter name -> adapter path for the adapters currently
    # loaded after this operation — confirm against the model runner.
    loaded_adapters: Dict[str, str] = field(default_factory=dict)


# Load and unload responses share the same payload shape.
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
||||
|
||||
@@ -82,6 +82,8 @@ from sglang.srt.managers.io_struct import (
|
||||
HealthCheckOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
LoadLoRAAdapterReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -99,6 +101,8 @@ from sglang.srt.managers.io_struct import (
|
||||
SlowDownReqOutput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UnloadLoRAAdapterReqOutput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
@@ -519,6 +523,8 @@ class Scheduler(
|
||||
(SetInternalStateReq, self.set_internal_state),
|
||||
(RpcReqInput, self.handle_rpc_request),
|
||||
(ExpertDistributionReq, self.expert_distribution_handle),
|
||||
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
|
||||
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -2241,6 +2247,36 @@ class Scheduler(
|
||||
logger.error(message)
|
||||
return UpdateWeightFromDiskReqOutput(success, message, 0)
|
||||
|
||||
def load_lora_adapter(
    self, recv_req: LoadLoRAAdapterReqInput
) -> LoadLoRAAdapterReqOutput:
    """In-place loading a new lora adapter from disk or huggingface."""
    result = self.tp_worker.load_lora_adapter(recv_req)

    # Guard clause: surface the failure and hand the result back unchanged.
    if not result.success:
        logger.error(result.error_message)
        return result

    # The adapter set changed, so flush the scheduler cache to keep state
    # consistent with the newly loaded weights.
    flushed = self.flush_cache()
    assert flushed, "Cache flush failed after loading lora adapter."
    return result
|
||||
|
||||
def unload_lora_adapter(
    self, recv_req: UnloadLoRAAdapterReqInput
) -> UnloadLoRAAdapterReqOutput:
    """Unload the lora adapter."""
    result = self.tp_worker.unload_lora_adapter(recv_req)

    # Guard clause: log the failure and return it to the caller as-is.
    if not result.success:
        logger.error(result.error_message)
        return result

    # Mirror load_lora_adapter: flush the cache after the adapter set changes.
    flushed = self.flush_cache()
    assert flushed, "Cache flush failed after unloading LoRA weights"
    return result
|
||||
|
||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||
"""Initialize the online model parameter update group."""
|
||||
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
||||
|
||||
@@ -83,6 +83,9 @@ from sglang.srt.managers.io_struct import (
|
||||
HealthCheckOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
LoadLoRAAdapterReqOutput,
|
||||
LoRAUpdateResult,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -99,6 +102,8 @@ from sglang.srt.managers.io_struct import (
|
||||
SlowDownReqOutput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UnloadLoRAAdapterReqOutput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
@@ -311,6 +316,9 @@ class TokenizerManager:
|
||||
self.expert_distribution_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.update_lora_adapter_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
|
||||
self._result_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
@@ -377,6 +385,10 @@ class TokenizerManager:
|
||||
ExpertDistributionReqOutput,
|
||||
self.expert_distribution_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
LoRAUpdateResult,
|
||||
self.update_lora_adapter_communicator.handle_recv,
|
||||
),
|
||||
(HealthCheckOutput, lambda x: None),
|
||||
]
|
||||
)
|
||||
@@ -960,6 +972,49 @@ class TokenizerManager:
|
||||
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
async def load_lora_adapter(
    self,
    obj: LoadLoRAAdapterReqInput,
    _: Optional[fastapi.Request] = None,
) -> LoadLoRAAdapterReqOutput:
    """Dispatch a LoRA load request to the scheduler(s) and await the result."""
    self.auto_create_handle_loop()

    # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
    # with dp_size > 1.
    assert self.server_args.dp_size == 1, "dp_size must be 1 for dynamic lora loading"

    logger.info(
        "Start load Lora adapter. Lora name=%s, path=%s", obj.lora_name, obj.lora_path
    )

    # Writer lock serializes adapter updates against other model-weight updates.
    async with self.model_update_lock.writer_lock:
        return (await self.update_lora_adapter_communicator(obj))[0]
|
||||
|
||||
async def unload_lora_adapter(
    self,
    obj: UnloadLoRAAdapterReqInput,
    _: Optional[fastapi.Request] = None,
) -> UnloadLoRAAdapterReqOutput:
    """Dispatch a LoRA unload request to the scheduler(s) and await the result."""
    self.auto_create_handle_loop()

    # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
    # with dp_size > 1.
    assert self.server_args.dp_size == 1, "dp_size must be 1 for dynamic lora loading"

    logger.info("Start unload Lora adapter. Lora name=%s", obj.lora_name)

    # Writer lock serializes adapter updates against other model-weight updates.
    async with self.model_update_lock.writer_lock:
        return (await self.update_lora_adapter_communicator(obj))[0]
|
||||
|
||||
async def get_weights_by_name(
|
||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
|
||||
@@ -30,6 +30,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
@@ -275,3 +277,13 @@ class TpModelWorker:
|
||||
recv_req.name, recv_req.truncate_size
|
||||
)
|
||||
return parameter
|
||||
|
||||
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
|
||||
result = self.model_runner.load_lora_adapter(
|
||||
recv_req.lora_name, recv_req.lora_path
|
||||
)
|
||||
return result
|
||||
|
||||
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
|
||||
result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
|
||||
return result
|
||||
|
||||
@@ -26,6 +26,8 @@ import torch
|
||||
from sglang.srt.managers.io_struct import (
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
@@ -268,6 +270,12 @@ class TpModelWorkerClient:
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
return self.worker.get_weights_by_name(recv_req)
|
||||
|
||||
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
|
||||
return self.worker.load_lora_adapter(recv_req)
|
||||
|
||||
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
|
||||
return self.worker.unload_lora_adapter(recv_req)
|
||||
|
||||
def __delete__(self):
|
||||
self.input_queue.put((None, None))
|
||||
self.copy_queue.put((None, None, None))
|
||||
|
||||
Reference in New Issue
Block a user