[RL] support update_weights_from_distributed with different group and multiple weights (#7292)
@@ -294,22 +294,27 @@ def init_process_sgl(
     update_parameters.remove("lm_head.weight")

     # Get weights from the training engine and update the inference engine.
-    for parameter_name in update_parameters:
-        if backend == "Engine":
-            engine.update_weights_from_distributed(
-                parameter_name,
-                dtype=torch.bfloat16,
-                shape=state_dict_key_to_shape[parameter_name],
-            )
-        else:
-            requests.post(
-                f"{url}/update_weights_from_distributed",
-                json={
-                    "name": parameter_name,
-                    "dtype": "bfloat16",
-                    "shape": state_dict_key_to_shape[parameter_name],
-                },
-            )
+    names = [parameter_name for parameter_name in update_parameters]
+    dtypes = [torch.bfloat16 if backend == "Engine" else "bfloat16"] * len(names)
+    shapes = [state_dict_key_to_shape[parameter_name] for parameter_name in names]
+
+    if backend == "Engine":
+        engine.update_weights_from_distributed(
+            names,
+            dtypes=dtypes,
+            shapes=shapes,
+            group_name="test_parameter_update_group",
+        )
+    else:
+        requests.post(
+            f"{url}/update_weights_from_distributed",
+            json={
+                "names": names,
+                "dtypes": dtypes,
+                "shapes": shapes,
+                "group_name": "test_parameter_update_group",
+            },
+        )
     torch.cuda.synchronize()
     time_end_update = time.perf_counter()

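The hunk uses state_dict_key_to_shape without defining it. A minimal sketch of how such a name-to-shape mapping is typically built from the training-side model; hf_model is a hypothetical stand-in, not a name from this commit:

    # hf_model is a hypothetical stand-in for the training-side model.
    # Map every parameter name to its shape so the inference side knows
    # what size of buffer to receive each tensor into.
    state_dict_key_to_shape = {
        name: tuple(tensor.shape) for name, tensor in hf_model.state_dict().items()
    }
    update_parameters = list(state_dict_key_to_shape.keys())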
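For update_weights_from_distributed to complete, the training side has to send the same tensors, in the same order, over the process group named "test_parameter_update_group". A minimal sketch of that sender side using plain torch.distributed broadcasts; the function and variable names here are assumptions for illustration, not APIs introduced by this commit:

    import torch
    import torch.distributed as dist

    def broadcast_updated_weights(model, names, group):
        # Assumed setup: rank 0 is the training engine, and `group` is the
        # torch.distributed process group both sides joined under the name
        # "test_parameter_update_group".
        state_dict = model.state_dict()
        for name in names:
            # Match the dtype the inference side was told to expect.
            tensor = state_dict[name].to(torch.bfloat16).cuda()
            dist.broadcast(tensor, src=0, group=group)

One broadcast per tensor keeps the receive order deterministic, which is why names, dtypes, and shapes travel as parallel lists in a single request instead of one request per parameter.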