From 0626f678de6d58e38378f05511f6200cdc67e70c Mon Sep 17 00:00:00 2001
From: Zilin Zhu
Date: Thu, 3 Jul 2025 10:29:11 +0800
Subject: [PATCH] [RL] support update_weights_from_distributed with different
 group and multiple weights (#7292)

---
 python/sglang/srt/entrypoints/engine.py       | 17 ++++++--
 python/sglang/srt/managers/io_struct.py       | 10 +++--
 python/sglang/srt/managers/scheduler.py       |  5 ++-
 python/sglang/srt/managers/tp_worker.py       |  2 +-
 .../sglang/srt/model_executor/model_runner.py | 40 +++++++++++++------
 .../test_update_weights_from_distributed.py   | 37 +++++++++--------
 6 files changed, 73 insertions(+), 38 deletions(-)

diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 4712aea2c..f56c40133 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -418,12 +418,21 @@ class Engine(EngineBase):
             self.tokenizer_manager.init_weights_update_group(obj, None)
         )
 
-    def update_weights_from_distributed(self, name: str, dtype, shape):
+    def update_weights_from_distributed(
+        self,
+        names: list[str],
+        dtypes: list[str],
+        shapes: list[list[int]],
+        group_name: str = "weight_update_group",
+        flush_cache: bool = True,
+    ):
         """Update weights from distributed source."""
         obj = UpdateWeightsFromDistributedReqInput(
-            name=name,
-            dtype=dtype,
-            shape=shape,
+            names=names,
+            dtypes=dtypes,
+            shapes=shapes,
+            group_name=group_name,
+            flush_cache=flush_cache,
         )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 27b14a2ce..6c5b6e196 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -752,9 +752,13 @@ class UpdateWeightFromDiskReqOutput:
 
 @dataclass
 class UpdateWeightsFromDistributedReqInput:
-    name: str
-    dtype: str
-    shape: List[int]
+    names: List[str]
+    dtypes: List[str]
+    shapes: List[List[int]]
+    # The group name
+    group_name: str = "weight_update_group"
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
 
 
 @dataclass
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 630a9dd2d..8a82c1fea 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -2303,8 +2303,9 @@ class Scheduler(
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after updating weights"
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromDistributedReqOutput(success, message)
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index afd9541aa..a0a33741d 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -259,7 +259,7 @@ class TpModelWorker:
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
         success, message = self.model_runner.update_weights_from_distributed(
-            recv_req.name, recv_req.dtype, recv_req.shape
+            recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
         )
         return success, message
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 0bd442d0e..cf22c2e58 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -225,6 +225,7 @@ class ModelRunner:
         self.support_pp = (
             "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
         )
+        self._model_update_group = {}
 
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
@@ -744,7 +745,7 @@ class ModelRunner:
         )
 
         try:
-            self._model_update_group = init_custom_process_group(
+            self._model_update_group[group_name] = init_custom_process_group(
                 backend=backend,
                 init_method=f"tcp://{master_address}:{master_port}",
                 world_size=world_size,
@@ -757,7 +758,7 @@
             logger.error(message)
             return False, message
 
-    def update_weights_from_distributed(self, name, dtype, shape):
+    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
         through `_model_update_group` process group.
@@ -767,19 +768,34 @@
             dtype: the data type of the parameter to be updated.
             shape: the shape of the parameter to be updated.
         """
-        target_dtype = (
-            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+
+        assert group_name in self._model_update_group, (
+            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
+            "Please call `init_weights_update_group` first."
         )
 
-        assert (
-            self._model_update_group is not None
-        ), "model update group must be initialized"
-
         try:
-            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
-            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
-            self.model.load_weights([(name, weights)])
-            return True, f"Succeeded to update parameter {name} online."
+            weights = []
+            handles = []
+            for name, dtype, shape in zip(names, dtypes, shapes):
+                target_dtype = (
+                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+                )
+                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
+                handles.append(
+                    torch.distributed.broadcast(
+                        weight,
+                        src=0,
+                        group=self._model_update_group[group_name],
+                        async_op=True,
+                    )
+                )
+                weights.append((name, weight))
+            for handle in handles:
+                handle.wait()
+
+            self.model.load_weights(weights)
+            return True, f"Succeeded to update parameter online."
 
         except Exception as e:
             error_msg = (
diff --git a/test/srt/test_update_weights_from_distributed.py b/test/srt/test_update_weights_from_distributed.py
index cfcc5d951..a3b938c38 100644
--- a/test/srt/test_update_weights_from_distributed.py
+++ b/test/srt/test_update_weights_from_distributed.py
@@ -294,22 +294,27 @@ def init_process_sgl(
     update_parameters.remove("lm_head.weight")
 
     # Get weights from the training engine and update the inference engine.
-    for parameter_name in update_parameters:
-        if backend == "Engine":
-            engine.update_weights_from_distributed(
-                parameter_name,
-                dtype=torch.bfloat16,
-                shape=state_dict_key_to_shape[parameter_name],
-            )
-        else:
-            requests.post(
-                f"{url}/update_weights_from_distributed",
-                json={
-                    "name": parameter_name,
-                    "dtype": "bfloat16",
-                    "shape": state_dict_key_to_shape[parameter_name],
-                },
-            )
+    names = [parameter_name for parameter_name in update_parameters]
+    dtypes = [torch.bfloat16 if backend == "Engine" else "bfloat16"] * len(names)
+    shapes = [state_dict_key_to_shape[parameter_name] for parameter_name in names]
+
+    if backend == "Engine":
+        engine.update_weights_from_distributed(
+            names,
+            dtypes=dtypes,
+            shapes=shapes,
+            group_name="test_parameter_update_group",
+        )
+    else:
+        requests.post(
+            f"{url}/update_weights_from_distributed",
+            json={
+                "names": names,
+                "dtypes": dtypes,
+                "shapes": shapes,
+                "group_name": "test_parameter_update_group",
+            },
+        )
 
     torch.cuda.synchronize()
     time_end_update = time.perf_counter()
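
For context, here is a minimal usage sketch of the batched API this patch introduces, written from the point of view of a caller that has already created a weight-update group via `init_weights_update_group`. The parameter names, shapes, group name, and the `engine`, `state_dict`, and `weight_update_group` variables below are illustrative assumptions, not values taken from this patch; what the diff itself establishes is the list-based `update_weights_from_distributed(names, dtypes, shapes, group_name, flush_cache)` signature and that the inference side receives each tensor via a rank-0 broadcast on the named group.

import torch
import torch.distributed as dist

# Hypothetical batch of parameters to push; names and shapes are made up for illustration.
names = ["model.layers.0.mlp.gate_proj.weight", "model.layers.0.mlp.up_proj.weight"]
dtypes = ["bfloat16"] * len(names)
shapes = [[11008, 4096], [11008, 4096]]

# Inference-side process: one call receives and loads all listed weights.
# flush_cache is left True here; per the scheduler change, setting it False would
# skip the cache flush, e.g. for intermediate calls when pushing weights in chunks.
engine.update_weights_from_distributed(
    names,
    dtypes=dtypes,
    shapes=shapes,
    group_name="weight_update_group",
    flush_cache=True,
)

# Training-side process (rank 0 of the same custom group): broadcast the tensors in
# the same order, so they line up with the empty buffers the ModelRunner allocates.
# Tensors are assumed to already be bfloat16 and resident on the GPU.
for name in names:
    dist.broadcast(state_dict[name], src=0, group=weight_update_group)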