Support Flatten Tensor Update Weights to speed up MOE Update Weights by 20% (#8079)
This commit is contained in:
@@ -451,15 +451,20 @@ class Engine(EngineBase):
|
||||
):
|
||||
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
|
||||
to avoid duplicated cache cleaning operation."""
|
||||
obj = UpdateWeightsFromTensorReqInput(
|
||||
serialized_named_tensors=[
|
||||
if load_format == "flattened_bucket":
|
||||
serialized_named_tensors = named_tensors
|
||||
else:
|
||||
serialized_named_tensors = [
|
||||
MultiprocessingSerializer.serialize(named_tensors)
|
||||
for _ in range(self.server_args.tp_size)
|
||||
],
|
||||
]
|
||||
obj = UpdateWeightsFromTensorReqInput(
|
||||
serialized_named_tensors=serialized_named_tensors,
|
||||
load_format=load_format,
|
||||
flush_cache=flush_cache,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
return loop.run_until_complete(
|
||||
self.tokenizer_manager.update_weights_from_tensor(obj, None)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user