diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 5ab2e3758..509c71531 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -699,14 +699,14 @@ class GroupCoordinator: ) # Serialize object to tensor and get the size as well - object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).to( - device=self.device + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda( + device=torch.cuda.current_device() ) size_tensor = torch.tensor( [object_tensor.numel()], dtype=torch.long, - device=self.device, + device=torch.cuda.current_device(), ) # Send object size @@ -731,7 +731,9 @@ class GroupCoordinator: src != self.rank_in_group ), "Invalid source rank. Source rank is the same as the current rank." - size_tensor = torch.empty(1, dtype=torch.long, device=self.device) + size_tensor = torch.empty( + 1, dtype=torch.long, device=torch.cuda.current_device() + ) # Receive object size rank_size = torch.distributed.recv( @@ -742,7 +744,7 @@ class GroupCoordinator: object_tensor = torch.empty( # type: ignore[call-overload] size_tensor.item(), # type: ignore[arg-type] dtype=torch.uint8, - device=self.device, + device=torch.cuda.current_device(), ) rank_object = torch.distributed.recv( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c79e296f6..748cb7322 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -975,7 +975,6 @@ class Scheduler( self.world_group.device_group, self.pp_rank * self.tp_size + dp_offset, (self.pp_rank + 1) * self.tp_size + dp_offset, - device=self.device, ) # send out proxy tensors to the next stage @@ -1024,7 +1023,6 @@ class Scheduler( self.world_group.device_group, (self.pp_rank - 1) * self.tp_size + dp_offset, self.pp_rank * self.tp_size + dp_offset, - device=self.device, ) else: recv_reqs = None @@ -1055,7 +1053,6 @@ class Scheduler( self.attn_tp_group.rank, self.attn_tp_cpu_group, src=self.attn_tp_group.ranks[0], - device=self.device, ) if self.tp_size != 1: control_reqs = broadcast_pyobj( @@ -1063,7 +1060,6 @@ class Scheduler( self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0], - device=self.device, ) recv_reqs = work_reqs + control_reqs elif self.tp_size != 1: @@ -1072,7 +1068,6 @@ class Scheduler( self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0], - device=self.device, ) return recv_reqs diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index daeed4faf..ff20ea01e 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -144,7 +144,6 @@ class TpModelWorker: self.tp_size * self.pp_rank + tp_rank, self.world_group.cpu_group, src=self.world_group.ranks[0], - device=self.device, )[0] set_random_seed(self.random_seed) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 37e06b8dc..ce159a4da 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1100,15 +1100,15 @@ def broadcast_pyobj( rank: int, dist_group: Optional[torch.distributed.ProcessGroup] = None, src: int = 0, - device: Optional[str] = None, + force_cpu_device: bool = True, ): """Broadcast inputs from src rank to all other ranks with torch.dist backend. The `rank` here refer to the source rank on global process group (regardless of dist_group argument). """ - - if device is None: - device = get_device() + device = torch.device( + "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu" + ) if rank == src: if len(data) == 0: @@ -1148,38 +1148,44 @@ def point_to_point_pyobj( group: Optional[torch.distributed.ProcessGroup] = None, src: int = 0, dst: int = 1, - device: Optional[str] = None, ): """Send data from src to dst in group using DeviceToDevice communication.""" - if device is None: - device = get_device() + if rank == src: if len(data) == 0: - tensor_size = torch.tensor([0], dtype=torch.long, device=device) + tensor_size = torch.tensor( + [0], dtype=torch.long, device=torch.cuda.current_device() + ) dist.send(tensor_size, dst=dst, group=group) else: serialized_data = pickle.dumps(data) size = len(serialized_data) tensor_data = torch.ByteTensor( np.frombuffer(serialized_data, dtype=np.uint8) - ).to( - device=device - ) # Move to Device - tensor_size = torch.tensor([size], dtype=torch.long, device=device) + ).cuda( + device=torch.cuda.current_device() + ) # Move to GPU + tensor_size = torch.tensor( + [size], dtype=torch.long, device=torch.cuda.current_device() + ) dist.send(tensor_size, dst=dst, group=group) dist.send(tensor_data, dst=dst, group=group) return data elif rank == dst: - tensor_size = torch.tensor([0], dtype=torch.long, device=device) + tensor_size = torch.tensor( + [0], dtype=torch.long, device=torch.cuda.current_device() + ) dist.recv(tensor_size, src=src, group=group) size = tensor_size.item() if size == 0: return [] - tensor_data = torch.empty(size, dtype=torch.uint8, device=device) + tensor_data = torch.empty( + size, dtype=torch.uint8, device=torch.cuda.current_device() + ) dist.recv(tensor_data, src=src, group=group) serialized_data = bytes(