[RL] support update_weights_from_distributed with different group and multiple weights (#7292)
This commit is contained in:
@@ -752,9 +752,13 @@ class UpdateWeightFromDiskReqOutput:
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromDistributedReqInput:
|
||||
name: str
|
||||
dtype: str
|
||||
shape: List[int]
|
||||
names: List[str]
|
||||
dtypes: List[str]
|
||||
shapes: List[List[int]]
|
||||
# The group name
|
||||
group_name: str = "weight_update_group"
|
||||
# Whether to flush the cache after updating weights
|
||||
flush_cache: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -2303,8 +2303,9 @@ class Scheduler(
|
||||
"""Update the online model parameter."""
|
||||
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
|
||||
if success:
|
||||
flush_cache_success = self.flush_cache()
|
||||
assert flush_cache_success, "Cache flush failed after updating weights"
|
||||
if recv_req.flush_cache:
|
||||
flush_cache_success = self.flush_cache()
|
||||
assert flush_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return UpdateWeightsFromDistributedReqOutput(success, message)
|
||||
|
||||
@@ -259,7 +259,7 @@ class TpModelWorker:
|
||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||
):
|
||||
success, message = self.model_runner.update_weights_from_distributed(
|
||||
recv_req.name, recv_req.dtype, recv_req.shape
|
||||
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
|
||||
)
|
||||
return success, message
|
||||
|
||||
|
||||
Reference in New Issue
Block a user