[RL] Add destroy process group api (#9979)

This commit is contained in:
penguin_wwy
2025-09-19 00:31:56 +08:00
committed by GitHub
parent 4039c626e2
commit 93f75778be
10 changed files with 109 additions and 0 deletions

View File

@@ -47,6 +47,7 @@ from sglang.srt.managers.data_parallel_controller import (
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
DestroyWeightsUpdateGroupReqInput,
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
@@ -433,6 +434,19 @@ class Engine(EngineBase):
self.tokenizer_manager.init_weights_update_group(obj, None)
)
def destroy_weights_update_group(
    self,
    group_name: str,
):
    """Destroy the parameter update group named ``group_name``.

    Wraps the name in a ``DestroyWeightsUpdateGroupReqInput`` and blocks on
    the tokenizer manager's ``destroy_weights_update_group`` coroutine,
    passing ``None`` as its second argument.

    Args:
        group_name: Name of the weights update group to destroy.

    Returns:
        Whatever the tokenizer manager call resolves to.
    """
    request_input = DestroyWeightsUpdateGroupReqInput(group_name=group_name)
    # Synchronous entry point: drive the async call to completion on the
    # current event loop.
    event_loop = asyncio.get_event_loop()
    pending = self.tokenizer_manager.destroy_weights_update_group(
        request_input, None
    )
    return event_loop.run_until_complete(pending)
def update_weights_from_distributed(
self,
names: list[str],

View File

@@ -70,6 +70,7 @@ from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
ConfigureLoggingReq,
DestroyWeightsUpdateGroupReqInput,
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
@@ -729,6 +730,20 @@ async def init_weights_update_group(
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
@app.post("/destroy_weights_update_group")
async def destroy_weights_update_group(
    obj: DestroyWeightsUpdateGroupReqInput, request: Request
):
    """Destroy the parameter update group.

    Forwards ``obj`` and ``request`` to the tokenizer manager and returns its
    ``(success, message)`` result as a JSON body. The HTTP status is 200 when
    ``success`` is truthy and 400 (Bad Request) otherwise.
    """
    tokenizer_manager = _global_state.tokenizer_manager
    success, message = await tokenizer_manager.destroy_weights_update_group(
        obj, request
    )
    content = {"success": success, "message": message}
    if success:
        return ORJSONResponse(content, status_code=200)
    return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
@app.post("/update_weights_from_tensor")
async def update_weights_from_tensor(
obj: UpdateWeightsFromTensorReqInput, request: Request