[RL] support update_weights_from_distributed with different group and multiple weights (#7292)

This commit is contained in:
Zilin Zhu
2025-07-03 10:29:11 +08:00
committed by GitHub
parent 09e699bba4
commit 0626f678de
6 changed files with 73 additions and 38 deletions

View File

@@ -418,12 +418,21 @@ class Engine(EngineBase):
self.tokenizer_manager.init_weights_update_group(obj, None)
)
def update_weights_from_distributed(self, name: str, dtype, shape):
def update_weights_from_distributed(
self,
names: list[str],
dtypes: list[str],
shapes: list[list[int]],
group_name: str = "weight_update_group",
flush_cache: bool = True,
):
"""Update weights from distributed source."""
obj = UpdateWeightsFromDistributedReqInput(
name=name,
dtype=dtype,
shape=shape,
names=names,
dtypes=dtypes,
shapes=shapes,
group_name=group_name,
flush_cache=flush_cache,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(