fix: fix broadcast_pyobj breaking VerlEngine (#5997)
@@ -37,6 +37,7 @@ class VerlEngine:
         monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
+        self._rank = device_mesh_cpu.get_rank()
         self._tp_size = device_mesh_cpu.size()
         tp_size_per_node = self._tp_size // nnodes
         node_rank = self._tp_rank // tp_size_per_node
@@ -114,7 +115,7 @@ class VerlEngine:
         # Most naive implementation, can extract tensor and send via gloo if too slow
         [output] = broadcast_pyobj(
             data=[output],
-            rank=self._tp_rank,
+            rank=self._rank,
             dist_group=self._device_mesh_cpu.get_group(),
             src=self._device_mesh_cpu.mesh[0].item(),
             force_cpu_device=False,
@@ -157,7 +158,7 @@ class VerlEngine:
         )

         if self._tp_rank == 0:
-            self._engine.tokenizer_manager.flush_cache()
+            self._engine.flush_cache()

     def release_memory_occupation(self):
         if self._tp_rank == 0:
@@ -897,7 +897,10 @@ def broadcast_pyobj(
     src: int = 0,
     force_cpu_device: bool = True,
 ):
-    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+    """Broadcast inputs from the src rank to all other ranks with torch.dist backend.
+    The `rank` here refers to the source rank on the global process group
+    (regardless of the dist_group argument).
+    """
     device = torch.device(
         "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
     )
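Why the `rank` argument matters: `device_mesh_cpu.get_local_rank()` is the caller's index inside the TP mesh, while `src=self._device_mesh_cpu.mesh[0].item()` is a global rank. For any TP mesh whose first member is not global rank 0 (e.g. several TP groups under a larger training world), the mesh-local rank can never equal that global src, so the process that should send never identifies itself as the source; passing `get_rank()` keeps the comparison on the global process group. Below is a minimal, assumption-based sketch of a broadcast_pyobj-style helper; the name `broadcast_pyobj_sketch`, the pickle-based serialization, and the concrete rank numbers are illustrative, not sglang's actual implementation.

import pickle

import torch
import torch.distributed as dist


def broadcast_pyobj_sketch(data, rank, dist_group=None, src=0, force_cpu_device=True):
    """Hypothetical stand-in: broadcast one picklable object from global rank `src`.

    `rank` must be the caller's rank on the global process group, because
    torch.distributed's `src` is a global rank even when `group` is a sub-group.
    """
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
    )
    if rank == src:
        # Sender: serialize the object, broadcast its size, then its bytes.
        serialized = bytearray(pickle.dumps(data))
        payload = torch.frombuffer(serialized, dtype=torch.uint8).to(device)
        size = torch.tensor([payload.numel()], dtype=torch.long, device=device)
        dist.broadcast(size, src=src, group=dist_group)
        dist.broadcast(payload, src=src, group=dist_group)
        return data
    # Receiver: learn the size, then receive and deserialize the bytes.
    size = torch.tensor([0], dtype=torch.long, device=device)
    dist.broadcast(size, src=src, group=dist_group)
    payload = torch.empty(size.item(), dtype=torch.uint8, device=device)
    dist.broadcast(payload, src=src, group=dist_group)
    return pickle.loads(bytes(payload.cpu().tolist()))


# Illustrative failure mode (hypothetical 2-node setup: world size 16, tp_size 8):
# the second TP mesh holds global ranks 8..15, so src = mesh[0].item() == 8.
#   old call: rank = device_mesh_cpu.get_local_rank()  -> 0..7, never equals 8
#   new call: rank = device_mesh_cpu.get_rank()        -> 8..15, global rank 8 matches src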