Support Flatten Tensor Update Weights to speed up MOE Update Weights by 20% (#8079)

This commit is contained in:
Stefan He
2025-08-10 16:08:59 -07:00
committed by GitHub
parent 0418b9d4ea
commit 8ecf6b9d24
4 changed files with 210 additions and 3 deletions

View File

@@ -121,6 +121,10 @@ from sglang.srt.utils import (
set_cpu_offload_max_bytes,
set_cuda_arch,
)
from sglang.srt.weight_sync.tensor_bucket import (
FlattenedTensorBucket,
FlattenedTensorMetadata,
)
_is_hip = is_hip()
_is_npu = is_npu()
@@ -896,6 +900,12 @@ class ModelRunner:
load_format: Optional[str] = None,
):
monkey_patch_torch_reductions()
if load_format == "flattened_bucket":
# Handle flattened bucket format
return self._update_weights_from_flattened_bucket(
flattened_tensor_bucket_dict=named_tensors
)
# We need to get device after patch otherwise the device would be wrong
infered_device = torch.cuda.current_device()
@@ -914,6 +924,38 @@ class ModelRunner:
raise NotImplementedError(f"Unknown load_format={load_format}")
return True, "Success"
def _update_weights_from_flattened_bucket(
    self,
    flattened_tensor_bucket_dict,
):
    """Load model weights delivered in the "flattened_bucket" format.

    Args:
        flattened_tensor_bucket_dict: Mapping that is indexed here with the
            keys ``"flattened_tensor"`` and ``"metadata"``. ``"metadata"``
            is iterated, and every item is read through the attributes
            ``name``, ``shape``, ``dtype``, ``start_idx``, ``end_idx`` and
            ``numel`` (attribute access, not dict lookup).

    Returns:
        The tuple ``(True, "Success")`` after ``self.model.load_weights``
        has returned.

    Raises:
        KeyError: If either expected key is absent from the mapping.
    """
    # Look the two entries up in this order so a missing key is reported
    # for "flattened_tensor" before "metadata".
    flat = flattened_tensor_bucket_dict["flattened_tensor"]
    incoming_metadata = flattened_tensor_bucket_dict["metadata"]

    # Re-wrap every incoming entry in the local FlattenedTensorMetadata
    # class so the bucket is handed instances of the type imported here.
    bucket_metadata = [
        FlattenedTensorMetadata(
            name=entry.name,
            shape=entry.shape,
            dtype=entry.dtype,
            start_idx=entry.start_idx,
            end_idx=entry.end_idx,
            numel=entry.numel,
        )
        for entry in incoming_metadata
    ]

    # Let the bucket rebuild the individual tensors, then feed them through
    # the model's regular weight-loading entry point.
    tensor_bucket = FlattenedTensorBucket(
        flattened_tensor=flat, metadata=bucket_metadata
    )
    self.model.load_weights(tensor_bucket.reconstruct_tensors())
    return True, "Success"
def get_weights_by_name(
self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]: