[RL] Add destroy process group api (#9979)
This commit is contained in:
@@ -47,6 +47,7 @@ from sglang.srt.managers.data_parallel_controller import (
|
||||
)
|
||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import (
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
@@ -433,6 +434,19 @@ class Engine(EngineBase):
|
||||
self.tokenizer_manager.init_weights_update_group(obj, None)
|
||||
)
|
||||
|
||||
def destroy_weights_update_group(
|
||||
self,
|
||||
group_name: str,
|
||||
):
|
||||
"""Destroy parameter update group."""
|
||||
obj = DestroyWeightsUpdateGroupReqInput(
|
||||
group_name=group_name,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
self.tokenizer_manager.destroy_weights_update_group(obj, None)
|
||||
)
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self,
|
||||
names: list[str],
|
||||
|
||||
@@ -70,6 +70,7 @@ from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
CloseSessionReqInput,
|
||||
ConfigureLoggingReq,
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
@@ -729,6 +730,20 @@ async def init_weights_update_group(
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.post("/destroy_weights_update_group")
|
||||
async def destroy_weights_update_group(
|
||||
obj: DestroyWeightsUpdateGroupReqInput, request: Request
|
||||
):
|
||||
"""Destroy the parameter update group."""
|
||||
success, message = (
|
||||
await _global_state.tokenizer_manager.destroy_weights_update_group(obj, request)
|
||||
)
|
||||
content = {"success": success, "message": message}
|
||||
return ORJSONResponse(
|
||||
content, status_code=200 if success else HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.post("/update_weights_from_tensor")
|
||||
async def update_weights_from_tensor(
|
||||
obj: UpdateWeightsFromTensorReqInput, request: Request
|
||||
|
||||
Reference in New Issue
Block a user