diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 509c71531..5ab2e3758 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -699,14 +699,14 @@ class GroupCoordinator:
         )
 
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
-            device=torch.cuda.current_device()
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).to(
+            device=self.device
         )
 
         size_tensor = torch.tensor(
             [object_tensor.numel()],
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=self.device,
         )
 
         # Send object size
@@ -731,9 +731,7 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
 
-        size_tensor = torch.empty(
-            1, dtype=torch.long, device=torch.cuda.current_device()
-        )
+        size_tensor = torch.empty(1, dtype=torch.long, device=self.device)
 
         # Receive object size
         rank_size = torch.distributed.recv(
@@ -744,7 +742,7 @@ class GroupCoordinator:
         object_tensor = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=torch.cuda.current_device(),
+            device=self.device,
         )
 
         rank_object = torch.distributed.recv(
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index afb4b870d..9a1654343 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -962,6 +962,7 @@ class Scheduler(
                         self.world_group.device_group,
                         self.pp_rank * self.tp_size + dp_offset,
                         (self.pp_rank + 1) * self.tp_size + dp_offset,
+                        device=self.device,
                     )
 
                 # send out proxy tensors to the next stage
@@ -1010,6 +1011,7 @@ class Scheduler(
                     self.world_group.device_group,
                     (self.pp_rank - 1) * self.tp_size + dp_offset,
                     self.pp_rank * self.tp_size + dp_offset,
+                    device=self.device,
                 )
             else:
                 recv_reqs = None
@@ -1040,6 +1042,7 @@ class Scheduler(
                     self.attn_tp_group.rank,
                     self.attn_tp_cpu_group,
                     src=self.attn_tp_group.ranks[0],
+                    device=self.device,
                 )
             if self.tp_size != 1:
                 control_reqs = broadcast_pyobj(
@@ -1047,6 +1050,7 @@ class Scheduler(
                     self.tp_group.rank,
                     self.tp_cpu_group,
                     src=self.tp_group.ranks[0],
+                    device=self.device,
                 )
             recv_reqs = work_reqs + control_reqs
         elif self.tp_size != 1:
@@ -1055,6 +1059,7 @@ class Scheduler(
                 self.tp_group.rank,
                 self.tp_cpu_group,
                 src=self.tp_group.ranks[0],
+                device=self.device,
             )
 
         return recv_reqs
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index ff20ea01e..daeed4faf 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -144,6 +144,7 @@ class TpModelWorker:
             self.tp_size * self.pp_rank + tp_rank,
             self.world_group.cpu_group,
             src=self.world_group.ranks[0],
+            device=self.device,
         )[0]
 
         set_random_seed(self.random_seed)
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 377fa90c8..d055aab5b 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1094,15 +1094,15 @@ def broadcast_pyobj(
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
     src: int = 0,
-    force_cpu_device: bool = True,
+    device: Optional[str] = None,
 ):
     """Broadcast inputs from src rank to all other ranks with torch.dist backend.
     The `rank` here refer to the source rank on global process group (regardless
     of dist_group argument).
     """
-    device = torch.device(
-        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
-    )
+
+    if device is None:
+        device = get_device()
 
     if rank == src:
         if len(data) == 0:
@@ -1142,44 +1142,38 @@ def point_to_point_pyobj(
     group: Optional[torch.distributed.ProcessGroup] = None,
     src: int = 0,
     dst: int = 1,
+    device: Optional[str] = None,
 ):
     """Send data from src to dst in group using DeviceToDevice communication."""
-
+    if device is None:
+        device = get_device()
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor(
-                [0], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            tensor_size = torch.tensor([0], dtype=torch.long, device=device)
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            ).cuda(
-                device=torch.cuda.current_device()
-            )  # Move to GPU
-            tensor_size = torch.tensor(
-                [size], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            ).to(
+                device=device
+            )  # Move to Device
+            tensor_size = torch.tensor([size], dtype=torch.long, device=device)
 
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data
 
     elif rank == dst:
-        tensor_size = torch.tensor(
-            [0], dtype=torch.long, device=torch.cuda.current_device()
-        )
+        tensor_size = torch.tensor([0], dtype=torch.long, device=device)
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()
 
         if size == 0:
             return []
 
-        tensor_data = torch.empty(
-            size, dtype=torch.uint8, device=torch.cuda.current_device()
-        )
+        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
         dist.recv(tensor_data, src=src, group=group)
 
         serialized_data = bytes(