[RL] support update_weights_from_distributed with different group and multiple weights (#7292)

This commit is contained in:
Zilin Zhu
2025-07-03 10:29:11 +08:00
committed by GitHub
parent 09e699bba4
commit 0626f678de
6 changed files with 73 additions and 38 deletions

View File

@@ -752,9 +752,13 @@ class UpdateWeightFromDiskReqOutput:
 @dataclass
 class UpdateWeightsFromDistributedReqInput:
-    name: str
-    dtype: str
-    shape: List[int]
+    names: List[str]
+    dtypes: List[str]
+    shapes: List[List[int]]
+    # The group name
+    group_name: str = "weight_update_group"
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True

 @dataclass

View File

@@ -2303,8 +2303,9 @@ class Scheduler(
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after updating weights"
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromDistributedReqOutput(success, message)

View File

@@ -259,7 +259,7 @@ class TpModelWorker:
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
         success, message = self.model_runner.update_weights_from_distributed(
-            recv_req.name, recv_req.dtype, recv_req.shape
+            recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
         )
         return success, message