[RL] Add destroy process group api (#9979)
This commit is contained in:
@@ -47,6 +47,7 @@ from sglang.srt.managers.data_parallel_controller import (
|
||||
)
|
||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import (
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
@@ -433,6 +434,19 @@ class Engine(EngineBase):
|
||||
self.tokenizer_manager.init_weights_update_group(obj, None)
|
||||
)
|
||||
|
||||
def destroy_weights_update_group(
    self,
    group_name: str,
):
    """Destroy parameter update group.

    Wraps ``group_name`` in a ``DestroyWeightsUpdateGroupReqInput`` and blocks
    on the tokenizer manager's ``destroy_weights_update_group`` coroutine
    (called with ``None`` as the request argument) until it finishes.

    Args:
        group_name: Name of the parameter update group to destroy.

    Returns:
        Whatever the tokenizer manager coroutine returns.
    """
    request = DestroyWeightsUpdateGroupReqInput(group_name=group_name)
    event_loop = asyncio.get_event_loop()
    pending = self.tokenizer_manager.destroy_weights_update_group(request, None)
    return event_loop.run_until_complete(pending)
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self,
|
||||
names: list[str],
|
||||
|
||||
@@ -70,6 +70,7 @@ from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
CloseSessionReqInput,
|
||||
ConfigureLoggingReq,
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
@@ -729,6 +730,20 @@ async def init_weights_update_group(
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.post("/destroy_weights_update_group")
async def destroy_weights_update_group(
    obj: DestroyWeightsUpdateGroupReqInput, request: Request
):
    """Destroy the parameter update group.

    Forwards the request to the tokenizer manager and replies with a JSON
    body carrying ``success`` and ``message``. The HTTP status is 200 when
    the operation succeeded and ``HTTPStatus.BAD_REQUEST`` otherwise.
    """
    manager = _global_state.tokenizer_manager
    success, message = await manager.destroy_weights_update_group(obj, request)
    payload = {"success": success, "message": message}
    if success:
        status = 200
    else:
        status = HTTPStatus.BAD_REQUEST
    return ORJSONResponse(payload, status_code=status)
|
||||
|
||||
|
||||
@app.post("/update_weights_from_tensor")
|
||||
async def update_weights_from_tensor(
|
||||
obj: UpdateWeightsFromTensorReqInput, request: Request
|
||||
|
||||
@@ -1094,6 +1094,17 @@ class InitWeightsUpdateGroupReqOutput:
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
class DestroyWeightsUpdateGroupReqInput:
    """Request payload asking for a weights update group to be destroyed."""

    # Name of the group to destroy; falls back to the default group name.
    group_name: str = "weight_update_group"
|
||||
|
||||
|
||||
@dataclass
class DestroyWeightsUpdateGroupReqOutput:
    """Reply to a ``DestroyWeightsUpdateGroupReqInput`` request."""

    # True when the group was destroyed, False otherwise.
    success: bool
    # Human-readable description of the outcome.
    message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightVersionReqInput:
|
||||
# The new weight version
|
||||
|
||||
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
|
||||
ClearHiCacheReqInput,
|
||||
ClearHiCacheReqOutput,
|
||||
CloseSessionReqInput,
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
ExpertDistributionReq,
|
||||
ExpertDistributionReqOutput,
|
||||
FlushCacheReqInput,
|
||||
@@ -566,6 +567,7 @@ class Scheduler(
|
||||
(CloseSessionReqInput, self.close_session),
|
||||
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
||||
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
||||
(DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
|
||||
(
|
||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||
self.init_weights_send_group_for_remote_instance,
|
||||
|
||||
@@ -5,6 +5,8 @@ import torch
|
||||
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
||||
from sglang.srt.managers.io_struct import (
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
DestroyWeightsUpdateGroupReqOutput,
|
||||
GetWeightsByNameReqInput,
|
||||
GetWeightsByNameReqOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
|
||||
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
||||
return InitWeightsUpdateGroupReqOutput(success, message)
|
||||
|
||||
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
    """Destroy the online model parameter update group.

    Hands ``recv_req`` to the TP worker and wraps the resulting
    ``(success, message)`` pair in a ``DestroyWeightsUpdateGroupReqOutput``.
    """
    outcome = self.tp_worker.destroy_weights_update_group(recv_req)
    success, message = outcome
    return DestroyWeightsUpdateGroupReqOutput(success, message)
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self,
|
||||
recv_req: UpdateWeightsFromDistributedReqInput,
|
||||
|
||||
@@ -24,6 +24,8 @@ import zmq
|
||||
from sglang.srt.managers.io_struct import (
|
||||
ClearHiCacheReqInput,
|
||||
ClearHiCacheReqOutput,
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
DestroyWeightsUpdateGroupReqOutput,
|
||||
ExpertDistributionReq,
|
||||
ExpertDistributionReqOutput,
|
||||
FlushCacheReqInput,
|
||||
@@ -149,6 +151,9 @@ class TokenizerCommunicatorMixin:
|
||||
self.init_weights_update_group_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.destroy_weights_update_group_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.update_weights_from_distributed_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
@@ -207,6 +212,10 @@ class TokenizerCommunicatorMixin:
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
self.init_weights_update_group_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
DestroyWeightsUpdateGroupReqOutput,
|
||||
self.destroy_weights_update_group_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
self.update_weights_from_distributed_communicator.handle_recv,
|
||||
@@ -345,6 +354,18 @@ class TokenizerCommunicatorMixin:
|
||||
result = (await self.init_weights_update_group_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
async def destroy_weights_update_group(
    self,
    obj: DestroyWeightsUpdateGroupReqInput,
    request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
    """Send a destroy-group request through the communicator and report the result.

    Args:
        obj: The destroy request to pass to the communicator.
        request: Accepted for signature parity with the HTTP layer; not read
            by this method.

    Returns:
        ``(success, message)`` taken from the first reply.

    Raises:
        AssertionError: If ``server_args.dp_size`` is not 1.
    """
    self.auto_create_handle_loop()
    assert (
        self.server_args.dp_size == 1
    ), "dp_size must be 1 for destroy parameter update group"
    replies = await self.destroy_weights_update_group_communicator(obj)
    first_reply = replies[0]
    return first_reply.success, first_reply.message
|
||||
|
||||
async def update_weights_from_distributed(
|
||||
self: TokenizerManager,
|
||||
obj: UpdateWeightsFromDistributedReqInput,
|
||||
|
||||
@@ -29,6 +29,7 @@ from sglang.srt.hf_transformers_utils import (
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
@@ -304,6 +305,12 @@ class TpModelWorker:
|
||||
)
|
||||
return success, message
|
||||
|
||||
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
    """Destroy the weights update group named by ``recv_req.group_name``.

    Delegates to the model runner and returns its ``(success, message)`` pair.
    """
    destroy = self.model_runner.destroy_weights_update_group
    success, message = destroy(recv_req.group_name)
    return success, message
|
||||
|
||||
def init_weights_send_group_for_remote_instance(
|
||||
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||
):
|
||||
|
||||
@@ -25,6 +25,7 @@ import psutil
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.io_struct import (
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
@@ -278,6 +279,10 @@ class TpModelWorkerClient:
|
||||
success, message = self.worker.init_weights_update_group(recv_req)
|
||||
return success, message
|
||||
|
||||
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
    """Forward the destroy request to the wrapped worker.

    Returns the worker's ``(success, message)`` pair unchanged.
    """
    outcome = self.worker.destroy_weights_update_group(recv_req)
    success, message = outcome
    return success, message
|
||||
|
||||
def init_weights_send_group_for_remote_instance(
|
||||
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||
):
|
||||
|
||||
@@ -1025,6 +1025,19 @@ class ModelRunner:
|
||||
logger.error(message)
|
||||
return False, message
|
||||
|
||||
def destroy_weights_update_group(self, group_name):
    """Destroy the custom process group registered under ``group_name``.

    The group is removed from ``self._model_update_group`` and then passed to
    ``torch.distributed.destroy_process_group``.

    Returns:
        ``(True, message)`` on success; ``(False, message)`` when no group is
        registered under that name or when any exception is raised (the
        failure is logged at error level).
    """
    try:
        groups = self._model_update_group
        if group_name not in groups:
            return False, "The group to be destroyed does not exist."
        # Unregister first, then tear down the underlying process group.
        process_group = groups.pop(group_name)
        torch.distributed.destroy_process_group(process_group)
    except Exception as e:
        message = f"Failed to destroy custom process group: {e}."
        logger.error(message)
        return False, message
    return True, "Succeeded to destroy custom process group."
|
||||
|
||||
def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
|
||||
"""
|
||||
Update specific parameter in the model weights online
|
||||
|
||||
Reference in New Issue
Block a user