Support loading weights from remote instance (#8215)
Signed-off-by: Anqi Shen <amy.saq@antgroup.com> Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
This commit is contained in:
@@ -73,6 +73,7 @@ from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
OpenSessionReqInput,
|
||||
@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
|
||||
ProfileReqInput,
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqInput,
|
||||
SendWeightsToRemoteInstanceReqInput,
|
||||
SeparateReasoningReqInput,
|
||||
SetInternalStateReq,
|
||||
SlowDownReqInput,
|
||||
@@ -670,6 +672,38 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
|
||||
)
|
||||
|
||||
|
||||
@app.post("/init_weights_send_group_for_remote_instance")
async def init_weights_send_group_for_remote_instance(
    obj: InitWeightsSendGroupForRemoteInstanceReqInput, request: Request
):
    """Forward the request to the tokenizer manager and report the outcome.

    Awaits ``tokenizer_manager.init_weights_send_group_for_remote_instance``
    with the parsed request body and the raw request, which yields a
    ``(success, message)`` pair. Both values are echoed back as JSON; the
    HTTP status is 200 when ``success`` is truthy and 400 (BAD_REQUEST)
    otherwise.
    """
    manager = _global_state.tokenizer_manager
    ok, detail = await manager.init_weights_send_group_for_remote_instance(
        obj, request
    )
    # The status code is the only thing that differs between the two outcomes.
    status = 200 if ok else HTTPStatus.BAD_REQUEST
    return ORJSONResponse({"success": ok, "message": detail}, status_code=status)
|
||||
|
||||
|
||||
@app.post("/send_weights_to_remote_instance")
async def send_weights_to_remote_instance(
    obj: SendWeightsToRemoteInstanceReqInput, request: Request
):
    """Forward the request to the tokenizer manager and report the outcome.

    Awaits ``tokenizer_manager.send_weights_to_remote_instance`` with the
    parsed request body and the raw request, which yields a
    ``(success, message)`` pair. Both values are echoed back as JSON; the
    HTTP status is 200 when ``success`` is truthy and 400 (BAD_REQUEST)
    otherwise.
    """
    manager = _global_state.tokenizer_manager
    ok, detail = await manager.send_weights_to_remote_instance(obj, request)
    # The status code is the only thing that differs between the two outcomes.
    status = 200 if ok else HTTPStatus.BAD_REQUEST
    return ORJSONResponse({"success": ok, "message": detail}, status_code=status)
|
||||
|
||||
|
||||
@app.post("/init_weights_update_group")
|
||||
async def init_weights_update_group(
|
||||
obj: InitWeightsUpdateGroupReqInput, request: Request
|
||||
|
||||
Reference in New Issue
Block a user