[RL] support update_weights_from_distributed with different group and multiple weights (#7292)
This commit is contained in:
@@ -418,12 +418,21 @@ class Engine(EngineBase):
|
||||
self.tokenizer_manager.init_weights_update_group(obj, None)
|
||||
)
|
||||
|
||||
def update_weights_from_distributed(self, name: str, dtype, shape):
|
||||
def update_weights_from_distributed(
|
||||
self,
|
||||
names: list[str],
|
||||
dtypes: list[str],
|
||||
shapes: list[list[int]],
|
||||
group_name: str = "weight_update_group",
|
||||
flush_cache: bool = True,
|
||||
):
|
||||
"""Update weights from distributed source."""
|
||||
obj = UpdateWeightsFromDistributedReqInput(
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
shape=shape,
|
||||
names=names,
|
||||
dtypes=dtypes,
|
||||
shapes=shapes,
|
||||
group_name=group_name,
|
||||
flush_cache=flush_cache,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
|
||||
@@ -752,9 +752,13 @@ class UpdateWeightFromDiskReqOutput:
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromDistributedReqInput:
|
||||
name: str
|
||||
dtype: str
|
||||
shape: List[int]
|
||||
names: List[str]
|
||||
dtypes: List[str]
|
||||
shapes: List[List[int]]
|
||||
# The group name
|
||||
group_name: str = "weight_update_group"
|
||||
# Whether to flush the cache after updating weights
|
||||
flush_cache: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -2303,8 +2303,9 @@ class Scheduler(
|
||||
"""Update the online model parameter."""
|
||||
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
|
||||
if success:
|
||||
flush_cache_success = self.flush_cache()
|
||||
assert flush_cache_success, "Cache flush failed after updating weights"
|
||||
if recv_req.flush_cache:
|
||||
flush_cache_success = self.flush_cache()
|
||||
assert flush_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return UpdateWeightsFromDistributedReqOutput(success, message)
|
||||
|
||||
@@ -259,7 +259,7 @@ class TpModelWorker:
|
||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||
):
|
||||
success, message = self.model_runner.update_weights_from_distributed(
|
||||
recv_req.name, recv_req.dtype, recv_req.shape
|
||||
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
|
||||
)
|
||||
return success, message
|
||||
|
||||
|
||||
@@ -225,6 +225,7 @@ class ModelRunner:
|
||||
self.support_pp = (
|
||||
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
|
||||
)
|
||||
self._model_update_group = {}
|
||||
|
||||
def initialize(self, min_per_gpu_memory: float):
|
||||
server_args = self.server_args
|
||||
@@ -744,7 +745,7 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
try:
|
||||
self._model_update_group = init_custom_process_group(
|
||||
self._model_update_group[group_name] = init_custom_process_group(
|
||||
backend=backend,
|
||||
init_method=f"tcp://{master_address}:{master_port}",
|
||||
world_size=world_size,
|
||||
@@ -757,7 +758,7 @@ class ModelRunner:
|
||||
logger.error(message)
|
||||
return False, message
|
||||
|
||||
def update_weights_from_distributed(self, name, dtype, shape):
|
||||
def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
|
||||
"""
|
||||
Update specific parameter in the model weights online
|
||||
through `_model_update_group` process group.
|
||||
@@ -767,19 +768,34 @@ class ModelRunner:
|
||||
dtype: the data type of the parameter to be updated.
|
||||
shape: the shape of the parameter to be updated.
|
||||
"""
|
||||
target_dtype = (
|
||||
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
|
||||
|
||||
assert group_name in self._model_update_group, (
|
||||
f"Group {group_name} not in {list(self._model_update_group.keys())}. "
|
||||
"Please call `init_weights_update_group` first."
|
||||
)
|
||||
|
||||
assert (
|
||||
self._model_update_group is not None
|
||||
), "model update group must be initialized"
|
||||
|
||||
try:
|
||||
weights = torch.empty(shape, dtype=target_dtype, device=self.device)
|
||||
torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
|
||||
self.model.load_weights([(name, weights)])
|
||||
return True, f"Succeeded to update parameter {name} online."
|
||||
weights = []
|
||||
handles = []
|
||||
for name, dtype, shape in zip(names, dtypes, shapes):
|
||||
target_dtype = (
|
||||
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
|
||||
)
|
||||
weight = torch.empty(shape, dtype=target_dtype, device=self.device)
|
||||
handles.append(
|
||||
torch.distributed.broadcast(
|
||||
weight,
|
||||
src=0,
|
||||
group=self._model_update_group[group_name],
|
||||
async_op=True,
|
||||
)
|
||||
)
|
||||
weights.append((name, weight))
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
|
||||
self.model.load_weights(weights)
|
||||
return True, f"Succeeded to update parameter online."
|
||||
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
|
||||
Reference in New Issue
Block a user