From 93f75778beb9a0387c7f3e0f974b39addeea8e94 Mon Sep 17 00:00:00 2001
From: penguin_wwy <940375606@qq.com>
Date: Fri, 19 Sep 2025 00:31:56 +0800
Subject: [PATCH] [RL] Add destroy process group api (#9979)

Add a `destroy_weights_update_group` API so that a process group
registered in `ModelRunner._model_update_group` can be torn down again.

The request is exposed both as `Engine.destroy_weights_update_group()`
and as the HTTP endpoint `POST /destroy_weights_update_group`, and is
forwarded TokenizerCommunicatorMixin -> Scheduler -> TpModelWorker
(and TpModelWorkerClient) -> ModelRunner. ModelRunner pops the group
named `group_name` and passes it to
`torch.distributed.destroy_process_group`.

Every layer returns `(success, message)`:
  * an unknown `group_name` yields `success == False`;
  * an exception raised while destroying the group is logged and also
    reported as `success == False`;
  * the HTTP endpoint answers 200 on success and 400 otherwise.

As asserted in the tokenizer manager, only `dp_size == 1` is supported.

The test for distributed weight updates now destroys
"test_parameter_update_group" before shutting the engine/server down.
---
 python/sglang/srt/entrypoints/engine.py       | 14 +++++++++++++
 python/sglang/srt/entrypoints/http_server.py  | 15 +++++++++++++
 python/sglang/srt/managers/io_struct.py       | 11 ++++++++++
 python/sglang/srt/managers/scheduler.py       |  2 ++
 .../scheduler_update_weights_mixin.py         |  7 +++++++
 .../managers/tokenizer_communicator_mixin.py  | 21 +++++++++++++++++++
 python/sglang/srt/managers/tp_worker.py       |  7 +++++++
 .../srt/managers/tp_worker_overlap_thread.py  |  5 +++++
 .../sglang/srt/model_executor/model_runner.py | 13 ++++++++++++
 .../test_update_weights_from_distributed.py   | 14 +++++++++++++
 10 files changed, 109 insertions(+)

diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index f86e9a751..012ff5ab7 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -47,6 +47,7 @@ from sglang.srt.managers.data_parallel_controller import (
 )
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
@@ -433,6 +434,19 @@ class Engine(EngineBase):
             self.tokenizer_manager.init_weights_update_group(obj, None)
         )
 
+    def destroy_weights_update_group(
+        self,
+        group_name: str,
+    ):
+        """Destroy parameter update group."""
+        obj = DestroyWeightsUpdateGroupReqInput(
+            group_name=group_name,
+        )
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.destroy_weights_update_group(obj, None)
+        )
+
     def update_weights_from_distributed(
         self,
         names: list[str],
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 3be69159c..ea0d9799b 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -70,6 +70,7 @@ from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
     ConfigureLoggingReq,
+    DestroyWeightsUpdateGroupReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
@@ -729,6 +730,20 @@ async def init_weights_update_group(
         return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
 
 
+@app.post("/destroy_weights_update_group")
+async def destroy_weights_update_group(
+    obj: DestroyWeightsUpdateGroupReqInput, request: Request
+):
+    """Destroy the parameter update group."""
+    success, message = (
+        await _global_state.tokenizer_manager.destroy_weights_update_group(obj, request)
+    )
+    content = {"success": success, "message": message}
+    return ORJSONResponse(
+        content, status_code=200 if success else HTTPStatus.BAD_REQUEST
+    )
+
+
 @app.post("/update_weights_from_tensor")
 async def update_weights_from_tensor(
     obj: UpdateWeightsFromTensorReqInput, request: Request
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index c479f6d54..86cfcf945 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -1094,6 +1094,17 @@ class InitWeightsUpdateGroupReqOutput:
     message: str
 
 
+@dataclass
+class DestroyWeightsUpdateGroupReqInput:
+    group_name: str = "weight_update_group"
+
+
+@dataclass
+class DestroyWeightsUpdateGroupReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class UpdateWeightVersionReqInput:
     # The new weight version
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index a246534cb..a4f47819d 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
     ClearHiCacheReqOutput,
     CloseSessionReqInput,
+    DestroyWeightsUpdateGroupReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
     FlushCacheReqInput,
@@ -566,6 +567,7 @@ class Scheduler(
                 (CloseSessionReqInput, self.close_session),
                 (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                 (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
                 (
                     InitWeightsSendGroupForRemoteInstanceReqInput,
                     self.init_weights_send_group_for_remote_instance,
diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
index fdae2142c..fdb7acd64 100644
--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py
+++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -5,6 +5,8 @@ import torch
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
+    DestroyWeightsUpdateGroupReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsUpdateGroupReqInput,
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
         success, message = self.tp_worker.init_weights_update_group(recv_req)
         return InitWeightsUpdateGroupReqOutput(success, message)
 
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        """Destroy the online model parameter update group."""
+        success, message = self.tp_worker.destroy_weights_update_group(recv_req)
+        return DestroyWeightsUpdateGroupReqOutput(success, message)
+
     def update_weights_from_distributed(
         self,
         recv_req: UpdateWeightsFromDistributedReqInput,
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index 8970d5ad5..1c541914c 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -24,6 +24,8 @@ import zmq
 from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
     ClearHiCacheReqOutput,
+    DestroyWeightsUpdateGroupReqInput,
+    DestroyWeightsUpdateGroupReqOutput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
     FlushCacheReqInput,
@@ -149,6 +151,9 @@ class TokenizerCommunicatorMixin:
         self.init_weights_update_group_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.destroy_weights_update_group_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.update_weights_from_distributed_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -207,6 +212,10 @@ class TokenizerCommunicatorMixin:
                     InitWeightsUpdateGroupReqOutput,
                     self.init_weights_update_group_communicator.handle_recv,
                 ),
+                (
+                    DestroyWeightsUpdateGroupReqOutput,
+                    self.destroy_weights_update_group_communicator.handle_recv,
+                ),
                 (
                     UpdateWeightsFromDistributedReqOutput,
                     self.update_weights_from_distributed_communicator.handle_recv,
@@ -345,6 +354,18 @@ class TokenizerCommunicatorMixin:
         result = (await self.init_weights_update_group_communicator(obj))[0]
         return result.success, result.message
 
+    async def destroy_weights_update_group(
+        self: TokenizerManager,
+        obj: DestroyWeightsUpdateGroupReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for destroy parameter update group"
+        result = (await self.destroy_weights_update_group_communicator(obj))[0]
+        return result.success, result.message
+
     async def update_weights_from_distributed(
         self: TokenizerManager,
         obj: UpdateWeightsFromDistributedReqInput,
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 7bed87592..98bc9a16f 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -29,6 +29,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
@@ -304,6 +305,12 @@ class TpModelWorker:
         )
         return success, message
 
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.destroy_weights_update_group(
+            recv_req.group_name,
+        )
+        return success, message
+
     def init_weights_send_group_for_remote_instance(
         self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
     ):
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index e34399a41..d0b5e586d 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -25,6 +25,7 @@ import psutil
 import torch
 
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
@@ -278,6 +279,10 @@ class TpModelWorkerClient:
         success, message = self.worker.init_weights_update_group(recv_req)
         return success, message
 
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.worker.destroy_weights_update_group(recv_req)
+        return success, message
+
     def init_weights_send_group_for_remote_instance(
         self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
     ):
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 7835c3fa1..8bfea5613 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1025,6 +1025,19 @@ class ModelRunner:
             logger.error(message)
             return False, message
 
+    def destroy_weights_update_group(self, group_name):
+        try:
+            if group_name in self._model_update_group:
+                pg = self._model_update_group.pop(group_name)
+                torch.distributed.destroy_process_group(pg)
+                return True, "Succeeded to destroy custom process group."
+            else:
+                return False, "The group to be destroyed does not exist."
+        except Exception as e:
+            message = f"Failed to destroy custom process group: {e}."
+            logger.error(message)
+            return False, message
+
     def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
diff --git a/test/srt/rl/test_update_weights_from_distributed.py b/test/srt/rl/test_update_weights_from_distributed.py
index a3b938c38..37782c397 100644
--- a/test/srt/rl/test_update_weights_from_distributed.py
+++ b/test/srt/rl/test_update_weights_from_distributed.py
@@ -344,6 +344,20 @@ def init_process_sgl(
             )
     param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
 
+    if backend == "Engine":
+        success, _ = engine.destroy_weights_update_group(
+            group_name="test_parameter_update_group",
+        )
+        assert success is True
+    else:
+        response = requests.post(
+            f"{url}/destroy_weights_update_group",
+            json={
+                "group_name": "test_parameter_update_group",
+            },
+        )
+        assert response.status_code == 200
+
     # Shutdown the engine or terminate the server process.
     if backend == "Engine":
         engine.shutdown()