From 8ecf6b9d2480c3f600826c7d8fef6a16ed603c3f Mon Sep 17 00:00:00 2001 From: Stefan He Date: Sun, 10 Aug 2025 16:08:59 -0700 Subject: [PATCH] Support Flatten Tensor Update Weights to speed up MOE Update Weights by 20% (#8079) --- python/sglang/srt/entrypoints/engine.py | 11 +- .../sglang/srt/model_executor/model_runner.py | 42 +++++++ .../sglang/srt/weight_sync/tensor_bucket.py | 106 ++++++++++++++++++ .../srt/rl/test_update_weights_from_tensor.py | 54 +++++++++ 4 files changed, 210 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/weight_sync/tensor_bucket.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index bde60ddfc..e40e156e2 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -451,15 +451,20 @@ class Engine(EngineBase): ): """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false to avoid duplicated cache cleaning operation.""" - obj = UpdateWeightsFromTensorReqInput( - serialized_named_tensors=[ + if load_format == "flattened_bucket": + serialized_named_tensors = named_tensors + else: + serialized_named_tensors = [ MultiprocessingSerializer.serialize(named_tensors) for _ in range(self.server_args.tp_size) - ], + ] + obj = UpdateWeightsFromTensorReqInput( + serialized_named_tensors=serialized_named_tensors, load_format=load_format, flush_cache=flush_cache, ) loop = asyncio.get_event_loop() + return loop.run_until_complete( self.tokenizer_manager.update_weights_from_tensor(obj, None) ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 2bb2676a8..ee83c2d9c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -121,6 +121,10 @@ from sglang.srt.utils import ( set_cpu_offload_max_bytes, set_cuda_arch, ) +from sglang.srt.weight_sync.tensor_bucket import ( + FlattenedTensorBucket, + FlattenedTensorMetadata, +) _is_hip = is_hip() _is_npu = is_npu() @@ -896,6 +900,12 @@ class ModelRunner: load_format: Optional[str] = None, ): monkey_patch_torch_reductions() + if load_format == "flattened_bucket": + # Handle flattened bucket format + return self._update_weights_from_flattened_bucket( + flattened_tensor_bucket_dict=named_tensors + ) + # We need to get device after patch otherwise the device would be wrong infered_device = torch.cuda.current_device() @@ -914,6 +924,38 @@ class ModelRunner: raise NotImplementedError(f"Unknown load_format={load_format}") return True, "Success" + def _update_weights_from_flattened_bucket( + self, + flattened_tensor_bucket_dict, + ): + """Handle flattened bucket format for weight updates""" + flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"] + metadata = flattened_tensor_bucket_dict["metadata"] + + # Convert metadata dict to our format + converted_metadata = [] + for meta in metadata: + converted_meta = FlattenedTensorMetadata( + name=meta.name, + shape=meta.shape, + dtype=meta.dtype, + start_idx=meta.start_idx, + end_idx=meta.end_idx, + numel=meta.numel, + ) + converted_metadata.append(converted_meta) + + # Create bucket and reconstruct tensors + bucket = FlattenedTensorBucket( + flattened_tensor=flattened_tensor, metadata=converted_metadata + ) + reconstructed_tensors = bucket.reconstruct_tensors() + + # Load the reconstructed tensors using the standard method + self.model.load_weights(reconstructed_tensors) + + return True, "Success" + def get_weights_by_name( self, name: str, truncate_size: int = 100 ) -> Optional[torch.Tensor]: diff --git a/python/sglang/srt/weight_sync/tensor_bucket.py b/python/sglang/srt/weight_sync/tensor_bucket.py new file mode 100644 index 000000000..44273713f --- /dev/null +++ b/python/sglang/srt/weight_sync/tensor_bucket.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +from typing import List, Tuple + +import torch + + +@dataclass +class FlattenedTensorMetadata: + """Metadata for a tensor in a flattened bucket""" + + name: str + shape: torch.Size + dtype: torch.dtype + start_idx: int + end_idx: int + numel: int + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single tensor for efficient processing + while preserving all metadata needed for reconstruction. + """ + + def __init__( + self, + named_tensors: List[Tuple[str, torch.Tensor]] = None, + flattened_tensor: torch.Tensor = None, + metadata: List[FlattenedTensorMetadata] = None, + ): + """ + Initialize a tensor bucket from a list of named tensors OR from pre-flattened data. + Args: + named_tensors: List of (name, tensor) tuples (for creating new bucket) + flattened_tensor: Pre-flattened tensor (for reconstruction) + metadata: Pre-computed metadata (for reconstruction) + """ + if named_tensors is not None: + # Create bucket from named tensors + self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors) + self.flattened_tensor: torch.Tensor = None + + if not named_tensors: + raise ValueError("Cannot create empty tensor bucket") + + # Collect metadata and flatten tensors + current_idx = 0 + flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors) + + for i, (name, tensor) in enumerate(named_tensors): + flattened = tensor.flatten() + flattened_tensors[i] = flattened + + # Store metadata + + numel = flattened.numel() + metadata_obj = FlattenedTensorMetadata( + name=name, + shape=tensor.shape, + dtype=tensor.dtype, + start_idx=current_idx, + end_idx=current_idx + numel, + numel=numel, + ) + self.metadata[i] = metadata_obj + current_idx += numel + + # Concatenate all flattened tensors + self.flattened_tensor = torch.cat(flattened_tensors, dim=0) + else: + # Initialize from pre-flattened data + if flattened_tensor is None or metadata is None: + raise ValueError( + "Must provide either named_tensors or both flattened_tensor and metadata" + ) + self.flattened_tensor = flattened_tensor + self.metadata = metadata + + def get_flattened_tensor(self) -> torch.Tensor: + """Get the flattened tensor containing all bucket tensors""" + return self.flattened_tensor + + def get_metadata(self) -> List[FlattenedTensorMetadata]: + """Get metadata for all tensors in the bucket""" + return self.metadata + + def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: + """ + Reconstruct original tensors from flattened tensor with optimized performance. + Uses memory-efficient operations to minimize allocations and copies. + """ + # preallocate the result list + reconstructed = [None] * len(self.metadata) + + for i, meta in enumerate(self.metadata): + tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].reshape( + meta.shape + ) + + # batch dtype conversion (if needed) + if tensor.dtype != meta.dtype: + tensor = tensor.to(meta.dtype) + + reconstructed[i] = (meta.name, tensor) + + return reconstructed diff --git a/test/srt/rl/test_update_weights_from_tensor.py b/test/srt/rl/test_update_weights_from_tensor.py index a1ca7f4b0..0dc947b60 100644 --- a/test/srt/rl/test_update_weights_from_tensor.py +++ b/test/srt/rl/test_update_weights_from_tensor.py @@ -5,6 +5,7 @@ import unittest import torch import sglang as sgl +from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase @@ -112,6 +113,59 @@ class TestUpdateWeightsFromTensor(CustomTestCase): engine.shutdown() + def test_update_weights_from_tensor_load_format_flattened_bucket(self): + """Test updating weights using flattened_bucket format""" + engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST) + + # Create a small set of parameters for testing + param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 10)] + + # Check original values + _check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110]) + + # Create new tensors with different values + new_tensors = [] + for _, name in enumerate(param_names): + # Create tensors with different values for each parameter + value = 2.0 # Different value for each parameter + new_tensor = torch.full((16384, 2048), value, device="cuda") + new_tensors.append((name, new_tensor)) + + # Create a flattened bucket + flattened_bucket = FlattenedTensorBucket(named_tensors=new_tensors) + + # Extract the flattened tensor and metadata in the format expected by model_runner + flattened_tensor = flattened_bucket.get_flattened_tensor() + metadata = flattened_bucket.get_metadata() + + # Create the dict format expected by _update_weights_from_flattened_bucket + bucket_dict = {"flattened_tensor": flattened_tensor, "metadata": metadata} + + # Serialize the bucket data + from sglang.srt.utils import MultiprocessingSerializer + + serialized_bucket = MultiprocessingSerializer.serialize( + bucket_dict, output_str=True + ) + + # Create a list where each rank contains the same serialized data + # This simulates the distributed environment where each rank has the same data + serialized_bucket_list = [serialized_bucket] + + # Update weights using flattened_bucket format + time_start = time.perf_counter() + engine.update_weights_from_tensor( + named_tensors=serialized_bucket_list, load_format="flattened_bucket" + ) + update_time = time.perf_counter() - time_start + print(f"Flattened bucket update time: {update_time:.03f}") + + # Verify the weights were updated correctly + for i, param_name in enumerate(param_names): + _check_param(engine, param_name, [2.0] * 5) + + engine.shutdown() + def _check_param(engine, param_name, expect_values): actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]