[RL] support update_weights_from_distributed with different group and multiple weights (#7292)
This commit is contained in:
@@ -418,12 +418,21 @@ class Engine(EngineBase):
|
|||||||
self.tokenizer_manager.init_weights_update_group(obj, None)
|
self.tokenizer_manager.init_weights_update_group(obj, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_weights_from_distributed(self, name: str, dtype, shape):
|
def update_weights_from_distributed(
|
||||||
|
self,
|
||||||
|
names: list[str],
|
||||||
|
dtypes: list[str],
|
||||||
|
shapes: list[list[int]],
|
||||||
|
group_name: str = "weight_update_group",
|
||||||
|
flush_cache: bool = True,
|
||||||
|
):
|
||||||
"""Update weights from distributed source."""
|
"""Update weights from distributed source."""
|
||||||
obj = UpdateWeightsFromDistributedReqInput(
|
obj = UpdateWeightsFromDistributedReqInput(
|
||||||
name=name,
|
names=names,
|
||||||
dtype=dtype,
|
dtypes=dtypes,
|
||||||
shape=shape,
|
shapes=shapes,
|
||||||
|
group_name=group_name,
|
||||||
|
flush_cache=flush_cache,
|
||||||
)
|
)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(
|
return loop.run_until_complete(
|
||||||
|
|||||||
@@ -752,9 +752,13 @@ class UpdateWeightFromDiskReqOutput:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightsFromDistributedReqInput:
|
class UpdateWeightsFromDistributedReqInput:
|
||||||
name: str
|
names: List[str]
|
||||||
dtype: str
|
dtypes: List[str]
|
||||||
shape: List[int]
|
shapes: List[List[int]]
|
||||||
|
# The group name
|
||||||
|
group_name: str = "weight_update_group"
|
||||||
|
# Whether to flush the cache after updating weights
|
||||||
|
flush_cache: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -2303,6 +2303,7 @@ class Scheduler(
|
|||||||
"""Update the online model parameter."""
|
"""Update the online model parameter."""
|
||||||
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
|
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
|
||||||
if success:
|
if success:
|
||||||
|
if recv_req.flush_cache:
|
||||||
flush_cache_success = self.flush_cache()
|
flush_cache_success = self.flush_cache()
|
||||||
assert flush_cache_success, "Cache flush failed after updating weights"
|
assert flush_cache_success, "Cache flush failed after updating weights"
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -259,7 +259,7 @@ class TpModelWorker:
|
|||||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||||
):
|
):
|
||||||
success, message = self.model_runner.update_weights_from_distributed(
|
success, message = self.model_runner.update_weights_from_distributed(
|
||||||
recv_req.name, recv_req.dtype, recv_req.shape
|
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
|
||||||
)
|
)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
|||||||
@@ -225,6 +225,7 @@ class ModelRunner:
|
|||||||
self.support_pp = (
|
self.support_pp = (
|
||||||
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
|
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
|
||||||
)
|
)
|
||||||
|
self._model_update_group = {}
|
||||||
|
|
||||||
def initialize(self, min_per_gpu_memory: float):
|
def initialize(self, min_per_gpu_memory: float):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
@@ -744,7 +745,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._model_update_group = init_custom_process_group(
|
self._model_update_group[group_name] = init_custom_process_group(
|
||||||
backend=backend,
|
backend=backend,
|
||||||
init_method=f"tcp://{master_address}:{master_port}",
|
init_method=f"tcp://{master_address}:{master_port}",
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
@@ -757,7 +758,7 @@ class ModelRunner:
|
|||||||
logger.error(message)
|
logger.error(message)
|
||||||
return False, message
|
return False, message
|
||||||
|
|
||||||
def update_weights_from_distributed(self, name, dtype, shape):
|
def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
|
||||||
"""
|
"""
|
||||||
Update specific parameter in the model weights online
|
Update specific parameter in the model weights online
|
||||||
through `_model_update_group` process group.
|
through `_model_update_group` process group.
|
||||||
@@ -767,19 +768,34 @@ class ModelRunner:
|
|||||||
dtype: the data type of the parameter to be updated.
|
dtype: the data type of the parameter to be updated.
|
||||||
shape: the shape of the parameter to be updated.
|
shape: the shape of the parameter to be updated.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
assert group_name in self._model_update_group, (
|
||||||
|
f"Group {group_name} not in {list(self._model_update_group.keys())}. "
|
||||||
|
"Please call `init_weights_update_group` first."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
weights = []
|
||||||
|
handles = []
|
||||||
|
for name, dtype, shape in zip(names, dtypes, shapes):
|
||||||
target_dtype = (
|
target_dtype = (
|
||||||
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
|
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
|
||||||
)
|
)
|
||||||
|
weight = torch.empty(shape, dtype=target_dtype, device=self.device)
|
||||||
|
handles.append(
|
||||||
|
torch.distributed.broadcast(
|
||||||
|
weight,
|
||||||
|
src=0,
|
||||||
|
group=self._model_update_group[group_name],
|
||||||
|
async_op=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
weights.append((name, weight))
|
||||||
|
for handle in handles:
|
||||||
|
handle.wait()
|
||||||
|
|
||||||
assert (
|
self.model.load_weights(weights)
|
||||||
self._model_update_group is not None
|
return True, f"Succeeded to update parameter online."
|
||||||
), "model update group must be initialized"
|
|
||||||
|
|
||||||
try:
|
|
||||||
weights = torch.empty(shape, dtype=target_dtype, device=self.device)
|
|
||||||
torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
|
|
||||||
self.model.load_weights([(name, weights)])
|
|
||||||
return True, f"Succeeded to update parameter {name} online."
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
|
|||||||
@@ -294,20 +294,25 @@ def init_process_sgl(
|
|||||||
update_parameters.remove("lm_head.weight")
|
update_parameters.remove("lm_head.weight")
|
||||||
|
|
||||||
# Get weights from the training engine and update the inference engine.
|
# Get weights from the training engine and update the inference engine.
|
||||||
for parameter_name in update_parameters:
|
names = [parameter_name for parameter_name in update_parameters]
|
||||||
|
dtypes = [torch.bfloat16 if backend == "Engine" else "bfloat16"] * len(names)
|
||||||
|
shapes = [state_dict_key_to_shape[parameter_name] for parameter_name in names]
|
||||||
|
|
||||||
if backend == "Engine":
|
if backend == "Engine":
|
||||||
engine.update_weights_from_distributed(
|
engine.update_weights_from_distributed(
|
||||||
parameter_name,
|
names,
|
||||||
dtype=torch.bfloat16,
|
dtypes=dtypes,
|
||||||
shape=state_dict_key_to_shape[parameter_name],
|
shapes=shapes,
|
||||||
|
group_name="test_parameter_update_group",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
requests.post(
|
requests.post(
|
||||||
f"{url}/update_weights_from_distributed",
|
f"{url}/update_weights_from_distributed",
|
||||||
json={
|
json={
|
||||||
"name": parameter_name,
|
"names": names,
|
||||||
"dtype": "bfloat16",
|
"dtypes": dtypes,
|
||||||
"shape": state_dict_key_to_shape[parameter_name],
|
"shapes": shapes,
|
||||||
|
"group_name": "test_parameter_update_group",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|||||||
Reference in New Issue
Block a user