[RL] support weight update with DP attention (#11669)
This commit is contained in:
@@ -146,6 +146,13 @@ class _Communicator(Generic[T]):
|
|||||||
if len(self._result_values) == self._fan_out:
|
if len(self._result_values) == self._fan_out:
|
||||||
self._result_event.set()
|
self._result_event.set()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_results(results):
|
||||||
|
all_success = all([r.success for r in results])
|
||||||
|
all_message = [r.message for r in results]
|
||||||
|
all_message = " | ".join(all_message)
|
||||||
|
return all_success, all_message
|
||||||
|
|
||||||
|
|
||||||
class TokenizerCommunicatorMixin:
|
class TokenizerCommunicatorMixin:
|
||||||
"""Mixin class for TokenizerManager to handle communication with the scheduler."""
|
"""Mixin class for TokenizerManager to handle communication with the scheduler."""
|
||||||
@@ -358,10 +365,11 @@ class TokenizerCommunicatorMixin:
|
|||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
assert (
|
assert (
|
||||||
self.server_args.dp_size == 1
|
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
||||||
), "dp_size must be 1 for init parameter update group"
|
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
|
||||||
result = (await self.init_weights_update_group_communicator(obj))[0]
|
|
||||||
return result.success, result.message
|
results = await self.init_weights_update_group_communicator(obj)
|
||||||
|
return _Communicator.merge_results(results)
|
||||||
|
|
||||||
async def destroy_weights_update_group(
|
async def destroy_weights_update_group(
|
||||||
self,
|
self,
|
||||||
@@ -370,10 +378,11 @@ class TokenizerCommunicatorMixin:
|
|||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
assert (
|
assert (
|
||||||
self.server_args.dp_size == 1
|
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
||||||
), "dp_size must be 1 for destroy parameter update group"
|
), "dp_size must be 1 or dp attention must be enabled for destroy parameter update group"
|
||||||
result = (await self.destroy_weights_update_group_communicator(obj))[0]
|
|
||||||
return result.success, result.message
|
results = await self.destroy_weights_update_group_communicator(obj)
|
||||||
|
return _Communicator.merge_results(results)
|
||||||
|
|
||||||
async def update_weights_from_distributed(
|
async def update_weights_from_distributed(
|
||||||
self: TokenizerManager,
|
self: TokenizerManager,
|
||||||
@@ -391,8 +400,8 @@ class TokenizerCommunicatorMixin:
|
|||||||
# This means that weight sync
|
# This means that weight sync
|
||||||
# cannot run while requests are in progress.
|
# cannot run while requests are in progress.
|
||||||
async with self.model_update_lock.writer_lock:
|
async with self.model_update_lock.writer_lock:
|
||||||
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
results = await self.update_weights_from_distributed_communicator(obj)
|
||||||
return result.success, result.message
|
return _Communicator.merge_results(results)
|
||||||
|
|
||||||
async def init_weights_send_group_for_remote_instance(
|
async def init_weights_send_group_for_remote_instance(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user