diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index 11cb4a246..ab1ce8e16 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -37,6 +37,7 @@ class VerlEngine: monkey_patch_torch_reductions() self._device_mesh_cpu = device_mesh_cpu self._tp_rank = device_mesh_cpu.get_local_rank() + self._rank = device_mesh_cpu.get_rank() self._tp_size = device_mesh_cpu.size() tp_size_per_node = self._tp_size // nnodes node_rank = self._tp_rank // tp_size_per_node @@ -114,7 +115,7 @@ class VerlEngine: # Most naive implementation, can extract tensor and send via gloo if too slow [output] = broadcast_pyobj( data=[output], - rank=self._tp_rank, + rank=self._rank, dist_group=self._device_mesh_cpu.get_group(), src=self._device_mesh_cpu.mesh[0].item(), force_cpu_device=False, @@ -157,7 +158,7 @@ class VerlEngine: ) if self._tp_rank == 0: - self._engine.tokenizer_manager.flush_cache() + self._engine.flush_cache() def release_memory_occupation(self): if self._tp_rank == 0: diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 3c68e6057..f581ffd55 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -897,7 +897,10 @@ def broadcast_pyobj( src: int = 0, force_cpu_device: bool = True, ): - """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" + """Broadcast inputs from src rank to all other ranks with torch.dist backend. + The `rank` here refers to the rank of the current process in the global + process group (regardless of the dist_group argument). + """ device = torch.device( "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu" )