Patch PyTorch's bug that cross-process tensor transfer will lead to wrong device (#4565)
This commit is contained in:
@@ -19,6 +19,7 @@ import torch.distributed as dist

 from torch.distributed.tensor import DeviceMesh, DTensor
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server import Engine
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj

@@ -30,6 +31,7 @@ class VerlEngine:
         nnodes: int = 1,
         **kwargs,
     ):
+        monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
         self._tp_size = device_mesh_cpu.size()
Reference in New Issue
Block a user