Support server based rollout in Verlengine (#4848)

Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Jinn <47354855+jhinpan@users.noreply.github.com>
This commit is contained in:
tianlian yi
2025-04-13 01:07:52 +08:00
committed by GitHub
parent 3e4794aad8
commit bc92107b03
10 changed files with 720 additions and 29 deletions

View File

@@ -12,16 +12,18 @@
# limitations under the License.
# ==============================================================================
import os
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.distributed as dist
from PIL.Image import Image
from torch.distributed.tensor import DeviceMesh, DTensor
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.server import Engine
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
@@ -30,6 +32,7 @@ class VerlEngine:
self,
device_mesh_cpu: DeviceMesh,
nnodes: int = 1,
backend: Literal["engine", "server"] = "engine",
**kwargs,
):
monkey_patch_torch_reductions()
@@ -40,13 +43,25 @@ class VerlEngine:
node_rank = self._tp_rank // tp_size_per_node
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
if first_rank_in_node:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
self._engine = Engine(
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
)
# Common engine keyword arguments
engine_kwargs = dict(
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
)
if backend == "engine":
if first_rank_in_node:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
self._engine = Engine(**engine_kwargs)
else:
self._engine = None
elif backend == "server":
if self._tp_rank == 0:
self._engine = HttpServerEngineAdapter(**engine_kwargs)
else:
self._engine = None
else:
self._engine = None
raise ValueError(f"Unsupported backend: {backend}")
dist.barrier(group=self._device_mesh_cpu.get_group())