Support dynamic LoRA loading / unloading in engine/server API (#7446)
This commit is contained in:
@@ -48,6 +48,14 @@ class EngineBase(ABC):
|
||||
"""Update model weights with in-memory tensor data."""
|
||||
pass
|
||||
|
||||
def load_lora_adapter(self, lora_name: str, lora_path: str):
|
||||
"""Load a new LoRA adapter without re-launching the engine."""
|
||||
pass
|
||||
|
||||
def unload_lora_adapter(self, lora_name: str):
|
||||
"""Unload a LoRA adapter without re-launching the engine."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def release_memory_occupation(self):
|
||||
"""Release GPU memory occupation temporarily."""
|
||||
|
||||
@@ -48,10 +48,12 @@ from sglang.srt.managers.io_struct import (
|
||||
GetWeightsByNameReqInput,
|
||||
ImageDataItem,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqInput,
|
||||
RpcReqInput,
|
||||
RpcReqOutput,
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
@@ -478,6 +480,29 @@ class Engine(EngineBase):
|
||||
self.tokenizer_manager.get_weights_by_name(obj, None)
|
||||
)
|
||||
|
||||
def load_lora_adapter(self, lora_name: str, lora_path: str):
|
||||
"""Load a new LoRA adapter without re-launching the engine."""
|
||||
|
||||
obj = LoadLoRAAdapterReqInput(
|
||||
lora_name=lora_name,
|
||||
lora_path=lora_path,
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
self.tokenizer_manager.load_lora_adapter(obj, None)
|
||||
)
|
||||
|
||||
def unload_lora_adapter(self, lora_name: str):
|
||||
"""Unload a LoRA adapter without re-launching the engine."""
|
||||
|
||||
obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
self.tokenizer_manager.unload_lora_adapter(obj, None)
|
||||
)
|
||||
|
||||
def release_memory_occupation(self, tags: Optional[List[str]] = None):
|
||||
obj = ReleaseMemoryOccupationReqInput(tags=tags)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
OpenSessionReqInput,
|
||||
ParseFunctionCallReq,
|
||||
ProfileReqInput,
|
||||
@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
|
||||
SeparateReasoningReqInput,
|
||||
SetInternalStateReq,
|
||||
SlowDownReqInput,
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
@@ -595,6 +597,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/load_lora_adapter", methods=["POST"])
|
||||
async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
|
||||
"""Load a new LoRA adapter without re-launching the server."""
|
||||
result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)
|
||||
|
||||
if result.success:
|
||||
return ORJSONResponse(
|
||||
result,
|
||||
status_code=HTTPStatus.OK,
|
||||
)
|
||||
else:
|
||||
return ORJSONResponse(
|
||||
result,
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/unload_lora_adapter", methods=["POST"])
|
||||
async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
|
||||
"""Load a new LoRA adapter without re-launching the server."""
|
||||
result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)
|
||||
|
||||
if result.success:
|
||||
return ORJSONResponse(
|
||||
result,
|
||||
status_code=HTTPStatus.OK,
|
||||
)
|
||||
else:
|
||||
return ORJSONResponse(
|
||||
result,
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||
async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||
"""Open a session, and return its unique session id."""
|
||||
|
||||
Reference in New Issue
Block a user