Offload tensors by sharding on GPU (#9536)
This commit is contained in:
@@ -321,6 +321,7 @@ class _BaseParamOffloader(ABC):
|
||||
@staticmethod
|
||||
def create(mode: str, **kwargs) -> "_BaseParamOffloader":
|
||||
return {
|
||||
"meta": _MetaParamOffloader,
|
||||
"cpu": _CpuParamOffloader,
|
||||
"shm_cpu": _ShmCpuParamOffloader,
|
||||
"sharded_gpu": _ShardedGpuParamOffloader,
|
||||
@@ -341,6 +342,17 @@ class _BaseParamOffloader(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _MetaParamOffloader(_BaseParamOffloader):
    """Offloader that hands the parameter to ``_move_param_to_meta``.

    Usually used for debugging: ``create_device_tensor`` returns an
    uninitialized CUDA tensor shaped like the parameter, so no real weight
    values are produced.
    """

    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        # Swap the parameter out right away; nothing is retained by this class.
        _move_param_to_meta(module, param_name)

    def create_device_tensor(self):
        """Return an uninitialized CUDA tensor mirroring the parameter's data."""
        template = self._param.data
        return torch.empty_like(template, device="cuda")
|
||||
|
||||
|
||||
class _CpuParamOffloader(_BaseParamOffloader):
|
||||
def __init__(self, module, param_name):
|
||||
super().__init__(module, param_name)
|
||||
@@ -431,3 +443,106 @@ def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------- ShardedGpu ------------------------------------------------------
|
||||
|
||||
|
||||
# TODO unify with ShmCpu mode
|
||||
class _ShardedGpuParamOffloader(_BaseParamOffloader):
    """Offloader that keeps one even shard of the parameter on each rank's GPU.

    Lifecycle, as implemented below:

    * ``__init__``: rank 0 moves the parameter to pinned CPU memory, every
      other rank hands it to ``_move_param_to_meta``.
    * ``post_init``: rank 0 uploads the tensor, splits it evenly along dim 0
      and scatters one shard to each rank; every rank publishes its shard
      buffer through ``_create_shared_buffer_tensors``.
    * ``create_device_tensor``: rebuilds the full tensor on the local GPU by
      copying every rank's shard into its slot.

    Only tensor-parallel size 1 and contiguous parameters are supported.
    """

    def __init__(self, module, param_name):
        super().__init__(module, param_name)

        naive_dist = get_naive_distributed()
        self._rank = naive_dist.get_rank()
        self._world_size = naive_dist.get_world_size()

        from sglang.srt.distributed import get_tensor_model_parallel_world_size

        tp_size = get_tensor_model_parallel_world_size()
        assert tp_size == 1, "not yet support tp_size!=1"
        self._assert_param_contiguous()

        if self._rank != 0:
            _move_param_to_meta(self._module, self._param_name)
        else:
            # Rank 0 is the scatter source in post_init, so it keeps the data.
            _move_param_to_cpu(self._param, pin_memory=True)

        # Filled in by post_init: one tensor per rank, indexed by rank.
        self.sharded_param_handles = None

    def _assert_param_contiguous(self):
        """Fail fast when the parameter's data is not contiguous."""
        assert (
            self._param.data.is_contiguous()
        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"

    def post_init(self):
        """Scatter rank 0's tensor into per-rank GPU shards and share them."""
        # check again since it may be changed
        self._assert_param_contiguous()

        # The local name is kept as-is: the log line below prints it verbatim.
        scatter_src = self._param.data

        logger.info(
            f"[offloader] post_init {scatter_src.nbytes=} {scatter_src.dtype=} {scatter_src.shape=} {torch.cuda.memory_allocated()=}"
        )

        is_root = self._rank == 0
        if is_root:
            scatter_src = scatter_src.to("cuda")
        # Chunked on every rank: the first chunk supplies the shard's
        # shape/dtype for the local buffer even where no data is sent.
        shards = _even_chunk(scatter_src, self._world_size)
        shard_template = shards[0]

        local_shard = torch.empty(
            shard_template.shape, dtype=shard_template.dtype, device="cuda"
        )
        self.sharded_param_handles = _create_shared_buffer_tensors(
            local_tensor=local_shard
        )

        get_naive_distributed().scatter(local_shard, shards if is_root else None)

        _move_param_to_meta(self._module, self._param_name)

    def create_device_tensor(self):
        """Assemble the full parameter on the local GPU from all shards."""
        gathered = _empty_strided_like(self._param, device="cuda")
        slots = gathered.chunk(self._world_size)

        # Walk the ranks cyclically starting from our own shard.
        # NOTE(review): presumably this staggers which peer each rank reads
        # from at a given step - confirm.
        for offset in range(self._world_size):
            peer = (self._rank + offset) % self._world_size
            slots[peer].copy_(self.sharded_param_handles[peer])

        return gathered
|
||||
|
||||
|
||||
def _even_chunk(x: torch.Tensor, chunks: int):
|
||||
assert x.shape[0] % chunks == 0, f"{x.shape=} {chunks=}"
|
||||
return list(x.chunk(chunks))
|
||||
|
||||
|
||||
def _create_shared_buffer_tensors(local_tensor: torch.Tensor) -> List[torch.Tensor]:
    """Exchange serialized handles of ``local_tensor`` with every rank.

    Each rank contributes a list with one entry per destination rank: ``None``
    for itself and a serialized copy of ``local_tensor`` for every other rank.
    After the all-gather, each rank picks the entry addressed to it from every
    peer and deserializes it.

    Returns:
        A list indexed by rank. The entry for the calling rank is
        ``local_tensor`` itself; every other entry is the tensor deserialized
        from what that rank sent.
    """
    naive_dist = get_naive_distributed()
    my_rank = naive_dist.get_rank()
    num_ranks = naive_dist.get_world_size()

    # NOTE(review): one serialize call per peer looks deliberate (the key is
    # named "dup_..."); presumably every consumer needs its own handle -
    # confirm before hoisting this out of the loop.
    per_peer_handles = []
    for peer_rank in range(num_ranks):
        if peer_rank == my_rank:
            per_peer_handles.append(None)
        else:
            per_peer_handles.append(MultiprocessingSerializer.serialize(local_tensor))

    gathered = naive_dist.all_gather_object(
        {"dup_serialized_local_tensor": per_peer_handles}
    )

    shared_tensors = []
    for peer_rank in range(num_ranks):
        handle_for_me = gathered[peer_rank]["dup_serialized_local_tensor"][my_rank]
        if peer_rank != my_rank:
            shared_tensors.append(MultiprocessingSerializer.deserialize(handle_for_me))
            continue
        # Our own slot never carries a payload; use the local buffer directly.
        assert handle_for_me is None
        shared_tensors.append(local_tensor)

    return shared_tensors
|
||||
|
||||
Reference in New Issue
Block a user