Patch PyTorch bug where cross-process tensor transfer lands on the wrong device (#4565)

fzyzcjy
2025-03-27 15:22:33 +08:00
committed by GitHub
parent 6f5cc5eb05
commit 92bb49a7f9
5 changed files with 211 additions and 2 deletions


@@ -19,6 +19,7 @@ import torch.distributed as dist
 from torch.distributed.tensor import DeviceMesh, DTensor
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server import Engine
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
@@ -30,6 +31,7 @@ class VerlEngine:
         nnodes: int = 1,
         **kwargs,
     ):
+        monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
         self._tp_size = device_mesh_cpu.size()
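
For context: `torch.multiprocessing`'s CUDA reductions serialize a tensor's integer device index, so a tensor rehydrated in a process with a different `CUDA_VISIBLE_DEVICES` mapping can end up on the wrong GPU. The diff only shows the import and the call site; the sketch below is a minimal, illustrative take on how such a `monkey_patch_torch_reductions` could work, translating the device index to a GPU UUID on the sending side and back to a local index on the receiving side. It is an assumption-based sketch, not a copy of `sglang.srt.patch_torch`; the helper names, the argument index, and the reliance on `get_device_properties(...).uuid` (newer PyTorch builds) are all assumptions.

```python
# Illustrative sketch (assumption), not the actual sglang.srt.patch_torch code.
import torch
from torch.multiprocessing import reductions

# Position of the device field in the rebuild args produced by reduce_tensor
# for CUDA tensors (assumed; may vary across PyTorch versions).
_DEVICE_ARG_INDEX = 6


def _device_to_uuid(device_index: int) -> str:
    # Requires a PyTorch build that exposes `uuid` on device properties.
    return str(torch.cuda.get_device_properties(device_index).uuid)


def _device_from_uuid(device_uuid: str) -> int:
    # Map the UUID back to whatever index that GPU has in this process.
    for index in range(torch.cuda.device_count()):
        if str(torch.cuda.get_device_properties(index).uuid) == device_uuid:
            return index
    raise RuntimeError(f"GPU with UUID {device_uuid} is not visible in this process")


def _replace_arg(args, index, fn):
    return (*args[:index], fn(args[index]), *args[index + 1 :])


def _reduce_tensor_patched(*args, **kwargs):
    # Serialize as usual, then swap the integer device index for a stable UUID.
    rebuild_fn, rebuild_args = reductions._reduce_tensor_original(*args, **kwargs)
    if rebuild_fn in (reductions.rebuild_cuda_tensor, reductions._rebuild_cuda_tensor_original):
        rebuild_args = _replace_arg(rebuild_args, _DEVICE_ARG_INDEX, _device_to_uuid)
    return rebuild_fn, rebuild_args


def _rebuild_cuda_tensor_patched(*args):
    # On the receiving side, resolve the UUID to the local device index first.
    args = _replace_arg(args, _DEVICE_ARG_INDEX, _device_from_uuid)
    return reductions._rebuild_cuda_tensor_original(*args)


def monkey_patch_torch_reductions():
    if hasattr(reductions, "_reduce_tensor_original"):
        return  # already patched
    reductions._reduce_tensor_original = reductions.reduce_tensor
    reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor
    reductions.reduce_tensor = _reduce_tensor_patched
    reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_patched
    # Re-register the ForkingPickler reducers so the patched functions take effect.
    reductions.init_reductions()
```

As in the second hunk above, `VerlEngine.__init__` invokes the patch before anything else, so the patched reducers are registered before any tensors are serialized across processes via `MultiprocessingSerializer`.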