[RL] Add destroy process group api (#9979)
This commit is contained in:
@@ -47,6 +47,7 @@ from sglang.srt.managers.data_parallel_controller import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
@@ -433,6 +434,19 @@ class Engine(EngineBase):
|
|||||||
self.tokenizer_manager.init_weights_update_group(obj, None)
|
self.tokenizer_manager.init_weights_update_group(obj, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def destroy_weights_update_group(
    self,
    group_name: str,
):
    """Destroy the parameter update group called ``group_name`` (blocking).

    Wraps the name in a ``DestroyWeightsUpdateGroupReqInput``, hands it to
    the tokenizer manager, and drives the resulting coroutine to completion
    on the current asyncio event loop.

    Args:
        group_name: Name of the weights update group to destroy.

    Returns:
        Whatever the tokenizer manager's ``destroy_weights_update_group``
        coroutine returns.
    """
    request = DestroyWeightsUpdateGroupReqInput(group_name=group_name)
    event_loop = asyncio.get_event_loop()
    # No HTTP request object exists on the engine path, hence the explicit None.
    pending = self.tokenizer_manager.destroy_weights_update_group(request, None)
    return event_loop.run_until_complete(pending)
|
||||||
|
|
||||||
def update_weights_from_distributed(
|
def update_weights_from_distributed(
|
||||||
self,
|
self,
|
||||||
names: list[str],
|
names: list[str],
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
AbortReq,
|
AbortReq,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
ConfigureLoggingReq,
|
ConfigureLoggingReq,
|
||||||
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
@@ -729,6 +730,20 @@ async def init_weights_update_group(
|
|||||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/destroy_weights_update_group")
async def destroy_weights_update_group(
    obj: DestroyWeightsUpdateGroupReqInput, request: Request
):
    """Destroy the parameter update group.

    Forwards the request to the global tokenizer manager and reports the
    outcome as JSON: ``{"success": ..., "message": ...}`` with status 200 on
    success and ``HTTPStatus.BAD_REQUEST`` otherwise.
    """
    manager = _global_state.tokenizer_manager
    ok, detail = await manager.destroy_weights_update_group(obj, request)
    payload = {"success": ok, "message": detail}
    if ok:
        return ORJSONResponse(payload, status_code=200)
    return ORJSONResponse(payload, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/update_weights_from_tensor")
|
@app.post("/update_weights_from_tensor")
|
||||||
async def update_weights_from_tensor(
|
async def update_weights_from_tensor(
|
||||||
obj: UpdateWeightsFromTensorReqInput, request: Request
|
obj: UpdateWeightsFromTensorReqInput, request: Request
|
||||||
|
|||||||
@@ -1094,6 +1094,17 @@ class InitWeightsUpdateGroupReqOutput:
|
|||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DestroyWeightsUpdateGroupReqInput:
    # Request to destroy a weights update group.
    # Name of the group to destroy.
    group_name: str = "weight_update_group"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DestroyWeightsUpdateGroupReqOutput:
    # Result of a request to destroy a weights update group.
    # Whether the destroy request succeeded.
    success: bool
    # Status or error text accompanying the result.
    message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UpdateWeightVersionReqInput:
|
class UpdateWeightVersionReqInput:
|
||||||
# The new weight version
|
# The new weight version
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ClearHiCacheReqInput,
|
ClearHiCacheReqInput,
|
||||||
ClearHiCacheReqOutput,
|
ClearHiCacheReqOutput,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
ExpertDistributionReq,
|
ExpertDistributionReq,
|
||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
FlushCacheReqInput,
|
FlushCacheReqInput,
|
||||||
@@ -566,6 +567,7 @@ class Scheduler(
|
|||||||
(CloseSessionReqInput, self.close_session),
|
(CloseSessionReqInput, self.close_session),
|
||||||
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
||||||
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
||||||
|
(DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
|
||||||
(
|
(
|
||||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
self.init_weights_send_group_for_remote_instance,
|
self.init_weights_send_group_for_remote_instance,
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
|
DestroyWeightsUpdateGroupReqOutput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
GetWeightsByNameReqOutput,
|
GetWeightsByNameReqOutput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
|
|||||||
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
||||||
return InitWeightsUpdateGroupReqOutput(success, message)
|
return InitWeightsUpdateGroupReqOutput(success, message)
|
||||||
|
|
||||||
|
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
    """Destroy the online model parameter update group.

    Delegates to the TP worker and wraps its ``(success, message)`` pair in
    a ``DestroyWeightsUpdateGroupReqOutput``.
    """
    ok, detail = self.tp_worker.destroy_weights_update_group(recv_req)
    reply = DestroyWeightsUpdateGroupReqOutput(ok, detail)
    return reply
|
||||||
|
|
||||||
def update_weights_from_distributed(
|
def update_weights_from_distributed(
|
||||||
self,
|
self,
|
||||||
recv_req: UpdateWeightsFromDistributedReqInput,
|
recv_req: UpdateWeightsFromDistributedReqInput,
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ import zmq
|
|||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
ClearHiCacheReqInput,
|
ClearHiCacheReqInput,
|
||||||
ClearHiCacheReqOutput,
|
ClearHiCacheReqOutput,
|
||||||
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
|
DestroyWeightsUpdateGroupReqOutput,
|
||||||
ExpertDistributionReq,
|
ExpertDistributionReq,
|
||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
FlushCacheReqInput,
|
FlushCacheReqInput,
|
||||||
@@ -149,6 +151,9 @@ class TokenizerCommunicatorMixin:
|
|||||||
self.init_weights_update_group_communicator = _Communicator(
|
self.init_weights_update_group_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
|
self.destroy_weights_update_group_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
self.update_weights_from_distributed_communicator = _Communicator(
|
self.update_weights_from_distributed_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
@@ -207,6 +212,10 @@ class TokenizerCommunicatorMixin:
|
|||||||
InitWeightsUpdateGroupReqOutput,
|
InitWeightsUpdateGroupReqOutput,
|
||||||
self.init_weights_update_group_communicator.handle_recv,
|
self.init_weights_update_group_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
DestroyWeightsUpdateGroupReqOutput,
|
||||||
|
self.destroy_weights_update_group_communicator.handle_recv,
|
||||||
|
),
|
||||||
(
|
(
|
||||||
UpdateWeightsFromDistributedReqOutput,
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
self.update_weights_from_distributed_communicator.handle_recv,
|
self.update_weights_from_distributed_communicator.handle_recv,
|
||||||
@@ -345,6 +354,18 @@ class TokenizerCommunicatorMixin:
|
|||||||
result = (await self.init_weights_update_group_communicator(obj))[0]
|
result = (await self.init_weights_update_group_communicator(obj))[0]
|
||||||
return result.success, result.message
|
return result.success, result.message
|
||||||
|
|
||||||
|
async def destroy_weights_update_group(
    self,
    obj: DestroyWeightsUpdateGroupReqInput,
    request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
    """Ask the scheduler to destroy a weights update group.

    Args:
        obj: The destroy request to send through the communicator.
        request: Optional originating HTTP request; not used in this method.

    Returns:
        ``(success, message)`` taken from the first response received.

    Raises:
        AssertionError: If ``server_args.dp_size`` is not 1.
    """
    self.auto_create_handle_loop()
    dp_size = self.server_args.dp_size
    assert dp_size == 1, "dp_size must be 1 for destroy parameter update group"
    responses = await self.destroy_weights_update_group_communicator(obj)
    # Only the first response is used; dp_size is asserted to be 1 above.
    first = responses[0]
    return first.success, first.message
|
||||||
|
|
||||||
async def update_weights_from_distributed(
|
async def update_weights_from_distributed(
|
||||||
self: TokenizerManager,
|
self: TokenizerManager,
|
||||||
obj: UpdateWeightsFromDistributedReqInput,
|
obj: UpdateWeightsFromDistributedReqInput,
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
@@ -304,6 +305,12 @@ class TpModelWorker:
|
|||||||
)
|
)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
    """Destroy the weights update group named in ``recv_req``.

    Passes ``recv_req.group_name`` to the model runner and returns its
    ``(success, message)`` pair.
    """
    target = recv_req.group_name
    ok, detail = self.model_runner.destroy_weights_update_group(target)
    return ok, detail
|
||||||
|
|
||||||
def init_weights_send_group_for_remote_instance(
|
def init_weights_send_group_for_remote_instance(
|
||||||
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import psutil
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
@@ -278,6 +279,10 @@ class TpModelWorkerClient:
|
|||||||
success, message = self.worker.init_weights_update_group(recv_req)
|
success, message = self.worker.init_weights_update_group(recv_req)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
    """Forward the destroy request to the wrapped worker.

    Returns the ``(success, message)`` pair produced by
    ``self.worker.destroy_weights_update_group``.
    """
    ok, detail = self.worker.destroy_weights_update_group(recv_req)
    return ok, detail
|
||||||
|
|
||||||
def init_weights_send_group_for_remote_instance(
|
def init_weights_send_group_for_remote_instance(
|
||||||
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1025,6 +1025,19 @@ class ModelRunner:
|
|||||||
logger.error(message)
|
logger.error(message)
|
||||||
return False, message
|
return False, message
|
||||||
|
|
||||||
|
def destroy_weights_update_group(self, group_name):
    """Destroy the custom process group registered under ``group_name``.

    Args:
        group_name: Key of the group in ``self._model_update_group``.

    Returns:
        A ``(success, message)`` tuple. ``success`` is False when no group
        is registered under ``group_name`` or when tearing it down raised.
    """
    try:
        if group_name not in self._model_update_group:
            return False, "The group to be destroyed does not exist."
        # Destroy first and only drop the registry entry afterwards: if the
        # teardown raises, the handle stays registered instead of being lost
        # while the caller is told the destroy failed.
        torch.distributed.destroy_process_group(
            self._model_update_group[group_name]
        )
        del self._model_update_group[group_name]
        return True, "Succeeded to destroy custom process group."
    except Exception as e:
        message = f"Failed to destroy custom process group: {e}."
        logger.error(message)
        return False, message
|
||||||
|
|
||||||
def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
|
def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
|
||||||
"""
|
"""
|
||||||
Update specific parameter in the model weights online
|
Update specific parameter in the model weights online
|
||||||
|
|||||||
@@ -344,6 +344,20 @@ def init_process_sgl(
|
|||||||
)
|
)
|
||||||
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
|
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
|
||||||
|
|
||||||
|
if backend == "Engine":
|
||||||
|
success, _ = engine.destroy_weights_update_group(
|
||||||
|
group_name="test_parameter_update_group",
|
||||||
|
)
|
||||||
|
assert success is True
|
||||||
|
else:
|
||||||
|
response = requests.post(
|
||||||
|
f"{url}/destroy_weights_update_group",
|
||||||
|
json={
|
||||||
|
"group_name": "test_parameter_update_group",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
# Shutdown the engine or terminate the server process.
|
# Shutdown the engine or terminate the server process.
|
||||||
if backend == "Engine":
|
if backend == "Engine":
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
|
|||||||
Reference in New Issue
Block a user