Support Flattened Tensor Update Weights to speed up MoE Update Weights by 20% (#8079)

This commit is contained in:
Stefan He
2025-08-10 16:08:59 -07:00
committed by GitHub
parent 0418b9d4ea
commit 8ecf6b9d24
4 changed files with 210 additions and 3 deletions

View File

@@ -451,15 +451,20 @@ class Engine(EngineBase):
     ):
         """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
         to avoid duplicated cache cleaning operation."""
-        obj = UpdateWeightsFromTensorReqInput(
-            serialized_named_tensors=[
+        if load_format == "flattened_bucket":
+            serialized_named_tensors = named_tensors
+        else:
+            serialized_named_tensors = [
                 MultiprocessingSerializer.serialize(named_tensors)
                 for _ in range(self.server_args.tp_size)
-            ],
+            ]
+
+        obj = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=serialized_named_tensors,
             load_format=load_format,
             flush_cache=flush_cache,
         )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.update_weights_from_tensor(obj, None)
         )