[RL] support update_weights_from_distributed with different group and multiple weights (#7292)
This commit is contained in:
@@ -418,12 +418,21 @@ class Engine(EngineBase):
|
||||
self.tokenizer_manager.init_weights_update_group(obj, None)
|
||||
)
|
||||
|
||||
def update_weights_from_distributed(self, name: str, dtype, shape):
|
||||
def update_weights_from_distributed(
|
||||
self,
|
||||
names: list[str],
|
||||
dtypes: list[str],
|
||||
shapes: list[list[int]],
|
||||
group_name: str = "weight_update_group",
|
||||
flush_cache: bool = True,
|
||||
):
|
||||
"""Update weights from distributed source."""
|
||||
obj = UpdateWeightsFromDistributedReqInput(
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
shape=shape,
|
||||
names=names,
|
||||
dtypes=dtypes,
|
||||
shapes=shapes,
|
||||
group_name=group_name,
|
||||
flush_cache=flush_cache,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
|
||||
Reference in New Issue
Block a user