Support flattened-tensor weight updates to speed up MoE weight updates by 20% (#8079)
This commit is contained in:
@@ -121,6 +121,10 @@ from sglang.srt.utils import (
|
||||
set_cpu_offload_max_bytes,
|
||||
set_cuda_arch,
|
||||
)
|
||||
from sglang.srt.weight_sync.tensor_bucket import (
|
||||
FlattenedTensorBucket,
|
||||
FlattenedTensorMetadata,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_npu = is_npu()
|
||||
@@ -896,6 +900,12 @@ class ModelRunner:
|
||||
load_format: Optional[str] = None,
|
||||
):
|
||||
monkey_patch_torch_reductions()
|
||||
if load_format == "flattened_bucket":
|
||||
# Handle flattened bucket format
|
||||
return self._update_weights_from_flattened_bucket(
|
||||
flattened_tensor_bucket_dict=named_tensors
|
||||
)
|
||||
|
||||
# We need to get device after patch otherwise the device would be wrong
|
||||
infered_device = torch.cuda.current_device()
|
||||
|
||||
@@ -914,6 +924,38 @@ class ModelRunner:
|
||||
raise NotImplementedError(f"Unknown load_format={load_format}")
|
||||
return True, "Success"
|
||||
|
||||
def _update_weights_from_flattened_bucket(
    self,
    flattened_tensor_bucket_dict,
):
    """Load model weights that arrive packed in a flattened tensor bucket.

    Args:
        flattened_tensor_bucket_dict: Mapping that provides a
            "flattened_tensor" entry and a "metadata" entry. The metadata
            entry is iterated, and each item is read through its ``name``,
            ``shape``, ``dtype``, ``start_idx``, ``end_idx`` and ``numel``
            attributes.

    Returns:
        The tuple ``(True, "Success")`` after ``self.model.load_weights``
        has returned.
    """
    flat_buffer = flattened_tensor_bucket_dict["flattened_tensor"]
    incoming_entries = flattened_tensor_bucket_dict["metadata"]

    # Re-wrap every incoming record as a FlattenedTensorMetadata instance so
    # the bucket below only ever sees the locally imported metadata type.
    local_entries = [
        FlattenedTensorMetadata(
            name=entry.name,
            shape=entry.shape,
            dtype=entry.dtype,
            start_idx=entry.start_idx,
            end_idx=entry.end_idx,
            numel=entry.numel,
        )
        for entry in incoming_entries
    ]

    # Let the bucket rebuild the individual tensors from the flat buffer.
    # NOTE(review): presumably yields (name, tensor) pairs in the form
    # load_weights expects - confirm against FlattenedTensorBucket.
    restored_weights = FlattenedTensorBucket(
        flattened_tensor=flat_buffer, metadata=local_entries
    ).reconstruct_tensors()

    # Hand the rebuilt tensors to the model's regular weight-loading path.
    self.model.load_weights(restored_weights)

    return True, "Success"
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user