Support server based rollout in Verlengine (#4848)
Co-authored-by: Jin Pan <jpan236@wisc.edu> Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: Jinn <47354855+jhinpan@users.noreply.github.com>
This commit is contained in:
@@ -12,16 +12,18 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from PIL.Image import Image
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor
|
||||
|
||||
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
|
||||
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||
from sglang.srt.server import Engine
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
|
||||
|
||||
|
||||
@@ -30,6 +32,7 @@ class VerlEngine:
|
||||
self,
|
||||
device_mesh_cpu: DeviceMesh,
|
||||
nnodes: int = 1,
|
||||
backend: Literal["engine", "server"] = "engine",
|
||||
**kwargs,
|
||||
):
|
||||
monkey_patch_torch_reductions()
|
||||
@@ -40,13 +43,25 @@ class VerlEngine:
|
||||
node_rank = self._tp_rank // tp_size_per_node
|
||||
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
|
||||
|
||||
if first_rank_in_node:
|
||||
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
|
||||
self._engine = Engine(
|
||||
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
|
||||
)
|
||||
# Common engine keyword arguments
|
||||
engine_kwargs = dict(
|
||||
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
|
||||
)
|
||||
|
||||
if backend == "engine":
|
||||
if first_rank_in_node:
|
||||
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
|
||||
self._engine = Engine(**engine_kwargs)
|
||||
else:
|
||||
self._engine = None
|
||||
|
||||
elif backend == "server":
|
||||
if self._tp_rank == 0:
|
||||
self._engine = HttpServerEngineAdapter(**engine_kwargs)
|
||||
else:
|
||||
self._engine = None
|
||||
else:
|
||||
self._engine = None
|
||||
raise ValueError(f"Unsupported backend: {backend}")
|
||||
|
||||
dist.barrier(group=self._device_mesh_cpu.get_group())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user