Support dynamic LoRA loading / unloading in engine/server API (#7446)

This commit is contained in:
Lifu Huang
2025-06-27 21:00:27 -07:00
committed by GitHub
parent cfe2edac38
commit 49538d111b
14 changed files with 949 additions and 31 deletions

View File

@@ -48,6 +48,14 @@ class EngineBase(ABC):
"""Update model weights with in-memory tensor data."""
pass
def load_lora_adapter(self, lora_name: str, lora_path: str):
"""Load a new LoRA adapter without re-launching the engine."""
pass
def unload_lora_adapter(self, lora_name: str):
"""Unload a LoRA adapter without re-launching the engine."""
pass
@abstractmethod
def release_memory_occupation(self):
"""Release GPU memory occupation temporarily."""

View File

@@ -48,10 +48,12 @@ from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
ImageDataItem,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
RpcReqInput,
RpcReqOutput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -478,6 +480,29 @@ class Engine(EngineBase):
self.tokenizer_manager.get_weights_by_name(obj, None)
)
def load_lora_adapter(self, lora_name: str, lora_path: str):
"""Load a new LoRA adapter without re-launching the engine."""
obj = LoadLoRAAdapterReqInput(
lora_name=lora_name,
lora_path=lora_path,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.load_lora_adapter(obj, None)
)
def unload_lora_adapter(self, lora_name: str):
"""Unload a LoRA adapter without re-launching the engine."""
obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.unload_lora_adapter(obj, None)
)
def release_memory_occupation(self, tags: Optional[List[str]] = None):
obj = ReleaseMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()

View File

@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput,
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
OpenSessionReqInput,
ParseFunctionCallReq,
ProfileReqInput,
@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
SeparateReasoningReqInput,
SetInternalStateReq,
SlowDownReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -595,6 +597,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
return _create_error_response(e)
@app.api_route("/load_lora_adapter", methods=["POST"])
async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
"""Load a new LoRA adapter without re-launching the server."""
result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)
if result.success:
return ORJSONResponse(
result,
status_code=HTTPStatus.OK,
)
else:
return ORJSONResponse(
result,
status_code=HTTPStatus.BAD_REQUEST,
)
@app.api_route("/unload_lora_adapter", methods=["POST"])
async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
"""Load a new LoRA adapter without re-launching the server."""
result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)
if result.success:
return ORJSONResponse(
result,
status_code=HTTPStatus.OK,
)
else:
return ORJSONResponse(
result,
status_code=HTTPStatus.BAD_REQUEST,
)
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""