From 31b9f19e5483202a59223d303663640597aec04c Mon Sep 17 00:00:00 2001
From: Zilin Zhu
Date: Sat, 18 Oct 2025 14:26:19 +0800
Subject: [PATCH] [RL] support weight update with DP attention (#11669)

---
 .../managers/tokenizer_communicator_mixin.py  | 29 ++++++++++++-------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index 91179cb53..0a3baf5de 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -146,6 +146,13 @@ class _Communicator(Generic[T]):
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
 
+    @staticmethod
+    def merge_results(results):
+        all_success = all([r.success for r in results])
+        all_message = [r.message for r in results]
+        all_message = " | ".join(all_message)
+        return all_success, all_message
+
 
 class TokenizerCommunicatorMixin:
     """Mixin class for TokenizerManager to handle communication with the scheduler."""
@@ -358,10 +365,11 @@ class TokenizerCommunicatorMixin:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
+
+        results = await self.init_weights_update_group_communicator(obj)
+        return _Communicator.merge_results(results)
 
     async def destroy_weights_update_group(
         self,
@@ -370,10 +378,11 @@ class TokenizerCommunicatorMixin:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for destroy parameter update group"
-        result = (await self.destroy_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for destroy parameter update group"
+
+        results = await self.destroy_weights_update_group_communicator(obj)
+        return _Communicator.merge_results(results)
 
     async def update_weights_from_distributed(
         self: TokenizerManager,
@@ -391,8 +400,8 @@ class TokenizerCommunicatorMixin:
         # This means that weight sync
         # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
+            results = await self.update_weights_from_distributed_communicator(obj)
+            return _Communicator.merge_results(results)
 
     async def init_weights_send_group_for_remote_instance(
         self,
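
With DP attention enabled, each weight-update request now fans out to every DP rank, and _Communicator.merge_results collapses the per-rank results into a single (success, message) pair: the merged call succeeds only if all ranks succeeded, and the per-rank messages are joined with " | " so no rank's error text is lost. The following minimal Python sketch illustrates that merging behavior in isolation; the _RankResult dataclass and the sample messages are hypothetical stand-ins for the communicator's real per-rank output objects.

    from dataclasses import dataclass
    from typing import List, Tuple

    @dataclass
    class _RankResult:
        # Hypothetical stand-in for a per-rank result object, which in
        # the real code exposes `success` and `message` attributes.
        success: bool
        message: str

    def merge_results(results: List[_RankResult]) -> Tuple[bool, str]:
        # Same logic as _Communicator.merge_results: AND the success
        # flags, join the messages with " | ".
        all_success = all([r.success for r in results])
        all_message = " | ".join([r.message for r in results])
        return all_success, all_message

    # Example: one failing DP rank flips the merged status to False
    # while both messages stay visible (messages here are made up).
    results = [
        _RankResult(True, "weights updated"),
        _RankResult(False, "NCCL group init failed"),
    ]
    print(merge_results(results))
    # -> (False, 'weights updated | NCCL group init failed')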