From 71a7f1d86fbd361ad145c9220318b4ae3a2d4998 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Mon, 25 Aug 2025 15:02:49 +0800
Subject: [PATCH] Offload tensors by sharding on GPU (#9536)

---
 python/sglang/srt/offloader.py | 115 +++++++++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)

diff --git a/python/sglang/srt/offloader.py b/python/sglang/srt/offloader.py
index b7b06cf71..aea7d7f23 100644
--- a/python/sglang/srt/offloader.py
+++ b/python/sglang/srt/offloader.py
@@ -321,6 +321,7 @@ class _BaseParamOffloader(ABC):
     @staticmethod
     def create(mode: str, **kwargs) -> "_BaseParamOffloader":
         return {
+            "meta": _MetaParamOffloader,
             "cpu": _CpuParamOffloader,
             "shm_cpu": _ShmCpuParamOffloader,
             "sharded_gpu": _ShardedGpuParamOffloader,
@@ -341,6 +342,17 @@ class _BaseParamOffloader(ABC):
         raise NotImplementedError
 
 
+class _MetaParamOffloader(_BaseParamOffloader):
+    """Usually used for debugging."""
+
+    def __init__(self, module, param_name):
+        super().__init__(module, param_name)
+        _move_param_to_meta(module, param_name)
+
+    def create_device_tensor(self):
+        return torch.empty_like(self._param.data, device="cuda")
+
+
 class _CpuParamOffloader(_BaseParamOffloader):
     def __init__(self, module, param_name):
         super().__init__(module, param_name)
@@ -431,3 +443,106 @@ def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
         device=device,
         pin_memory=pin_memory,
     )
+
+
+# ----------------------------------------- ShardedGpu ------------------------------------------------------
+
+
+# TODO unify with ShmCpu mode
+class _ShardedGpuParamOffloader(_BaseParamOffloader):
+    def __init__(self, module, param_name):
+        super().__init__(module, param_name)
+        self._rank = get_naive_distributed().get_rank()
+        self._world_size = get_naive_distributed().get_world_size()
+
+        from sglang.srt.distributed import get_tensor_model_parallel_world_size
+
+        assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
+        assert (
+            self._param.data.is_contiguous()
+        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"
+
+        if self._rank == 0:
+            _move_param_to_cpu(self._param, pin_memory=True)
+        else:
+            _move_param_to_meta(self._module, self._param_name)
+
+        self.sharded_param_handles = None
+
+    def post_init(self):
+        # check again since it may be changed
+        assert (
+            self._param.data.is_contiguous()
+        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"
+
+        scatter_src = self._param.data
+
+        logger.info(
+            f"[offloader] post_init {scatter_src.nbytes=} {scatter_src.dtype=} {scatter_src.shape=} {torch.cuda.memory_allocated()=}"
+        )
+
+        if self._rank == 0:
+            scatter_src = scatter_src.to("cuda")
+        scatter_list = _even_chunk(scatter_src, self._world_size)
+
+        sharded_param = torch.empty(
+            scatter_list[0].shape, dtype=scatter_list[0].dtype, device="cuda"
+        )
+        self.sharded_param_handles = _create_shared_buffer_tensors(
+            local_tensor=sharded_param
+        )
+
+        get_naive_distributed().scatter(
+            sharded_param, scatter_list if self._rank == 0 else None
+        )
+
+        _move_param_to_meta(self._module, self._param_name)
+
+    def create_device_tensor(self):
+        output = _empty_strided_like(self._param, device="cuda")
+        output_chunks = output.chunk(self._world_size)
+
+        for index in range(self._world_size):
+            src_rank = (self._rank + index) % self._world_size
+            src_buf = self.sharded_param_handles[src_rank]
+            output_chunks[src_rank].copy_(src_buf)
+
+        return output
+
+
+def _even_chunk(x: torch.Tensor, chunks: int):
+    assert x.shape[0] % chunks == 0, f"{x.shape=} {chunks=}"
+    return list(x.chunk(chunks))
+
+
+def _create_shared_buffer_tensors(local_tensor: torch.Tensor) -> List[torch.Tensor]:
+    self_rank = get_naive_distributed().get_rank()
+    world_size = get_naive_distributed().get_world_size()
+
+    object_list = get_naive_distributed().all_gather_object(
+        dict(
+            dup_serialized_local_tensor=[
+                (
+                    None
+                    if interesting_rank == self_rank
+                    else MultiprocessingSerializer.serialize(local_tensor)
+                )
+                for interesting_rank in range(world_size)
+            ]
+        )
+    )
+
+    output_tensors = []
+    for output_rank in range(world_size):
+        remote_serialized_tensor = object_list[output_rank][
+            "dup_serialized_local_tensor"
+        ][self_rank]
+        if output_rank == self_rank:
+            assert remote_serialized_tensor is None
+            output_tensors.append(local_tensor)
+        else:
+            output_tensors.append(
+                MultiprocessingSerializer.deserialize(remote_serialized_tensor)
+            )
+
+    return output_tensors
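
Reviewer note, not part of the commit: the create() factory in the first hunk
implies the call pattern below. This is a hypothetical sketch -- `linear` and
`"weight"` are illustrative placeholders, and the real call sites elsewhere in
sglang may pass different arguments.

    # Hypothetical usage inferred from the create() factory in this diff;
    # real call sites live elsewhere in sglang.
    offloader = _BaseParamOffloader.create(
        "sharded_gpu", module=linear, param_name="weight"
    )
    offloader.post_init()  # upload on rank 0, scatter shards, publish IPC handles
    weight = offloader.create_device_tensor()  # reassemble full weight on demand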
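For readers without the rest of offloader.py, the mechanism in the last hunk is:
(1) rank 0 pins the full weight on CPU while the other ranks hold only a meta
tensor; (2) post_init uploads the weight on rank 0, scatters one contiguous
chunk to each rank, and exchanges CUDA-IPC handles once via
MultiprocessingSerializer plus all_gather_object, so every rank can read every
peer's shard directly; (3) create_device_tensor rebuilds the full tensor with
plain device-to-device copies, staggering the source as
(self._rank + index) % world_size so all ranks do not read the same peer at
once. Below is a minimal runnable sketch of the same shard-then-rebuild
pattern, assuming plain torch.distributed collectives in place of sglang's
NaiveDistributed and the IPC-handle exchange; all names are illustrative.

    # sketch.py -- run with: torchrun --nproc-per-node=2 sketch.py
    # Shows the scatter-then-rebuild pattern only; the patch itself exchanges
    # CUDA-IPC handles so the rebuild path needs no collective.
    import torch
    import torch.distributed as dist


    def shard_param(full_cpu: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        # Same invariant as _even_chunk: dim 0 must split evenly.
        assert full_cpu.shape[0] % world_size == 0
        shard = torch.empty(
            (full_cpu.shape[0] // world_size, *full_cpu.shape[1:]),
            dtype=full_cpu.dtype,
            device="cuda",
        )
        # Like post_init: only rank 0 holds the full weight and provides chunks.
        scatter_list = list(full_cpu.cuda().chunk(world_size)) if rank == 0 else None
        dist.scatter(shard, scatter_list, src=0)
        return shard


    def rebuild(shard: torch.Tensor, world_size: int) -> torch.Tensor:
        # Like create_device_tensor, but via a collective instead of peer buffers.
        chunks = [torch.empty_like(shard) for _ in range(world_size)]
        dist.all_gather(chunks, shard)
        return torch.cat(chunks, dim=0)


    if __name__ == "__main__":
        dist.init_process_group("nccl")
        rank, world_size = dist.get_rank(), dist.get_world_size()
        torch.cuda.set_device(rank)

        full = torch.randn(8, 4) if rank == 0 else torch.empty(8, 4)
        shard = shard_param(full, rank, world_size)
        rebuilt = rebuild(shard, world_size)
        if rank == 0:
            torch.testing.assert_close(rebuilt.cpu(), full)
        dist.destroy_process_group()

The design point the sketch cannot show: because the patch exchanges handles
once at load time, every later rebuild in create_device_tensor is pure copy
bandwidth, whereas the all_gather above synchronizes all ranks on each rebuild.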