Support Flatten Tensor Update Weights to speed up MOE Update Weights by 20% (#8079)
This commit is contained in:
@@ -451,15 +451,20 @@ class Engine(EngineBase):
|
||||
):
|
||||
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
|
||||
to avoid duplicated cache cleaning operation."""
|
||||
obj = UpdateWeightsFromTensorReqInput(
|
||||
serialized_named_tensors=[
|
||||
if load_format == "flattened_bucket":
|
||||
serialized_named_tensors = named_tensors
|
||||
else:
|
||||
serialized_named_tensors = [
|
||||
MultiprocessingSerializer.serialize(named_tensors)
|
||||
for _ in range(self.server_args.tp_size)
|
||||
],
|
||||
]
|
||||
obj = UpdateWeightsFromTensorReqInput(
|
||||
serialized_named_tensors=serialized_named_tensors,
|
||||
load_format=load_format,
|
||||
flush_cache=flush_cache,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
return loop.run_until_complete(
|
||||
self.tokenizer_manager.update_weights_from_tensor(obj, None)
|
||||
)
|
||||
|
||||
@@ -121,6 +121,10 @@ from sglang.srt.utils import (
|
||||
set_cpu_offload_max_bytes,
|
||||
set_cuda_arch,
|
||||
)
|
||||
from sglang.srt.weight_sync.tensor_bucket import (
|
||||
FlattenedTensorBucket,
|
||||
FlattenedTensorMetadata,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_npu = is_npu()
|
||||
@@ -896,6 +900,12 @@ class ModelRunner:
|
||||
load_format: Optional[str] = None,
|
||||
):
|
||||
monkey_patch_torch_reductions()
|
||||
if load_format == "flattened_bucket":
|
||||
# Handle flattened bucket format
|
||||
return self._update_weights_from_flattened_bucket(
|
||||
flattened_tensor_bucket_dict=named_tensors
|
||||
)
|
||||
|
||||
# We need to get device after patch otherwise the device would be wrong
|
||||
infered_device = torch.cuda.current_device()
|
||||
|
||||
@@ -914,6 +924,38 @@ class ModelRunner:
|
||||
raise NotImplementedError(f"Unknown load_format={load_format}")
|
||||
return True, "Success"
|
||||
|
||||
def _update_weights_from_flattened_bucket(
|
||||
self,
|
||||
flattened_tensor_bucket_dict,
|
||||
):
|
||||
"""Handle flattened bucket format for weight updates"""
|
||||
flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
|
||||
metadata = flattened_tensor_bucket_dict["metadata"]
|
||||
|
||||
# Convert metadata dict to our format
|
||||
converted_metadata = []
|
||||
for meta in metadata:
|
||||
converted_meta = FlattenedTensorMetadata(
|
||||
name=meta.name,
|
||||
shape=meta.shape,
|
||||
dtype=meta.dtype,
|
||||
start_idx=meta.start_idx,
|
||||
end_idx=meta.end_idx,
|
||||
numel=meta.numel,
|
||||
)
|
||||
converted_metadata.append(converted_meta)
|
||||
|
||||
# Create bucket and reconstruct tensors
|
||||
bucket = FlattenedTensorBucket(
|
||||
flattened_tensor=flattened_tensor, metadata=converted_metadata
|
||||
)
|
||||
reconstructed_tensors = bucket.reconstruct_tensors()
|
||||
|
||||
# Load the reconstructed tensors using the standard method
|
||||
self.model.load_weights(reconstructed_tensors)
|
||||
|
||||
return True, "Success"
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
106
python/sglang/srt/weight_sync/tensor_bucket.py
Normal file
106
python/sglang/srt/weight_sync/tensor_bucket.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlattenedTensorMetadata:
|
||||
"""Metadata for a tensor in a flattened bucket"""
|
||||
|
||||
name: str
|
||||
shape: torch.Size
|
||||
dtype: torch.dtype
|
||||
start_idx: int
|
||||
end_idx: int
|
||||
numel: int
|
||||
|
||||
|
||||
class FlattenedTensorBucket:
|
||||
"""
|
||||
A bucket that flattens multiple tensors into a single tensor for efficient processing
|
||||
while preserving all metadata needed for reconstruction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
named_tensors: List[Tuple[str, torch.Tensor]] = None,
|
||||
flattened_tensor: torch.Tensor = None,
|
||||
metadata: List[FlattenedTensorMetadata] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.
|
||||
Args:
|
||||
named_tensors: List of (name, tensor) tuples (for creating new bucket)
|
||||
flattened_tensor: Pre-flattened tensor (for reconstruction)
|
||||
metadata: Pre-computed metadata (for reconstruction)
|
||||
"""
|
||||
if named_tensors is not None:
|
||||
# Create bucket from named tensors
|
||||
self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)
|
||||
self.flattened_tensor: torch.Tensor = None
|
||||
|
||||
if not named_tensors:
|
||||
raise ValueError("Cannot create empty tensor bucket")
|
||||
|
||||
# Collect metadata and flatten tensors
|
||||
current_idx = 0
|
||||
flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)
|
||||
|
||||
for i, (name, tensor) in enumerate(named_tensors):
|
||||
flattened = tensor.flatten()
|
||||
flattened_tensors[i] = flattened
|
||||
|
||||
# Store metadata
|
||||
|
||||
numel = flattened.numel()
|
||||
metadata_obj = FlattenedTensorMetadata(
|
||||
name=name,
|
||||
shape=tensor.shape,
|
||||
dtype=tensor.dtype,
|
||||
start_idx=current_idx,
|
||||
end_idx=current_idx + numel,
|
||||
numel=numel,
|
||||
)
|
||||
self.metadata[i] = metadata_obj
|
||||
current_idx += numel
|
||||
|
||||
# Concatenate all flattened tensors
|
||||
self.flattened_tensor = torch.cat(flattened_tensors, dim=0)
|
||||
else:
|
||||
# Initialize from pre-flattened data
|
||||
if flattened_tensor is None or metadata is None:
|
||||
raise ValueError(
|
||||
"Must provide either named_tensors or both flattened_tensor and metadata"
|
||||
)
|
||||
self.flattened_tensor = flattened_tensor
|
||||
self.metadata = metadata
|
||||
|
||||
def get_flattened_tensor(self) -> torch.Tensor:
|
||||
"""Get the flattened tensor containing all bucket tensors"""
|
||||
return self.flattened_tensor
|
||||
|
||||
def get_metadata(self) -> List[FlattenedTensorMetadata]:
|
||||
"""Get metadata for all tensors in the bucket"""
|
||||
return self.metadata
|
||||
|
||||
def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
|
||||
"""
|
||||
Reconstruct original tensors from flattened tensor with optimized performance.
|
||||
Uses memory-efficient operations to minimize allocations and copies.
|
||||
"""
|
||||
# preallocate the result list
|
||||
reconstructed = [None] * len(self.metadata)
|
||||
|
||||
for i, meta in enumerate(self.metadata):
|
||||
tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].reshape(
|
||||
meta.shape
|
||||
)
|
||||
|
||||
# batch dtype conversion (if needed)
|
||||
if tensor.dtype != meta.dtype:
|
||||
tensor = tensor.to(meta.dtype)
|
||||
|
||||
reconstructed[i] = (meta.name, tensor)
|
||||
|
||||
return reconstructed
|
||||
@@ -5,6 +5,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
|
||||
|
||||
|
||||
@@ -112,6 +113,59 @@ class TestUpdateWeightsFromTensor(CustomTestCase):
|
||||
|
||||
engine.shutdown()
|
||||
|
||||
def test_update_weights_from_tensor_load_format_flattened_bucket(self):
|
||||
"""Test updating weights using flattened_bucket format"""
|
||||
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
# Create a small set of parameters for testing
|
||||
param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 10)]
|
||||
|
||||
# Check original values
|
||||
_check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])
|
||||
|
||||
# Create new tensors with different values
|
||||
new_tensors = []
|
||||
for _, name in enumerate(param_names):
|
||||
# Create tensors with different values for each parameter
|
||||
value = 2.0 # Different value for each parameter
|
||||
new_tensor = torch.full((16384, 2048), value, device="cuda")
|
||||
new_tensors.append((name, new_tensor))
|
||||
|
||||
# Create a flattened bucket
|
||||
flattened_bucket = FlattenedTensorBucket(named_tensors=new_tensors)
|
||||
|
||||
# Extract the flattened tensor and metadata in the format expected by model_runner
|
||||
flattened_tensor = flattened_bucket.get_flattened_tensor()
|
||||
metadata = flattened_bucket.get_metadata()
|
||||
|
||||
# Create the dict format expected by _update_weights_from_flattened_bucket
|
||||
bucket_dict = {"flattened_tensor": flattened_tensor, "metadata": metadata}
|
||||
|
||||
# Serialize the bucket data
|
||||
from sglang.srt.utils import MultiprocessingSerializer
|
||||
|
||||
serialized_bucket = MultiprocessingSerializer.serialize(
|
||||
bucket_dict, output_str=True
|
||||
)
|
||||
|
||||
# Create a list where each rank contains the same serialized data
|
||||
# This simulates the distributed environment where each rank has the same data
|
||||
serialized_bucket_list = [serialized_bucket]
|
||||
|
||||
# Update weights using flattened_bucket format
|
||||
time_start = time.perf_counter()
|
||||
engine.update_weights_from_tensor(
|
||||
named_tensors=serialized_bucket_list, load_format="flattened_bucket"
|
||||
)
|
||||
update_time = time.perf_counter() - time_start
|
||||
print(f"Flattened bucket update time: {update_time:.03f}")
|
||||
|
||||
# Verify the weights were updated correctly
|
||||
for i, param_name in enumerate(param_names):
|
||||
_check_param(engine, param_name, [2.0] * 5)
|
||||
|
||||
engine.shutdown()
|
||||
|
||||
|
||||
def _check_param(engine, param_name, expect_values):
|
||||
actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
|
||||
|
||||
Reference in New Issue
Block a user