[RL] support update_weights_from_distributed with different group and multiple weights (#7292)

This commit is contained in:
Zilin Zhu
2025-07-03 10:29:11 +08:00
committed by GitHub
parent 09e699bba4
commit 0626f678de
6 changed files with 73 additions and 38 deletions

View File

@@ -294,22 +294,27 @@ def init_process_sgl(
update_parameters.remove("lm_head.weight")
# Get weights from the training engine and update the inference engine.
for parameter_name in update_parameters:
if backend == "Engine":
engine.update_weights_from_distributed(
parameter_name,
dtype=torch.bfloat16,
shape=state_dict_key_to_shape[parameter_name],
)
else:
requests.post(
f"{url}/update_weights_from_distributed",
json={
"name": parameter_name,
"dtype": "bfloat16",
"shape": state_dict_key_to_shape[parameter_name],
},
)
names = [parameter_name for parameter_name in update_parameters]
dtypes = [torch.bfloat16 if backend == "Engine" else "bfloat16"] * len(names)
shapes = [state_dict_key_to_shape[parameter_name] for parameter_name in names]
if backend == "Engine":
engine.update_weights_from_distributed(
names,
dtypes=dtypes,
shapes=shapes,
group_name="test_parameter_update_group",
)
else:
requests.post(
f"{url}/update_weights_from_distributed",
json={
"names": names,
"dtypes": dtypes,
"shapes": shapes,
"group_name": "test_parameter_update_group",
},
)
torch.cuda.synchronize()
time_end_update = time.perf_counter()