Offload tensors by sharding on GPU (#9536)
This commit is contained in:
@@ -321,6 +321,7 @@ class _BaseParamOffloader(ABC):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def create(mode: str, **kwargs) -> "_BaseParamOffloader":
|
def create(mode: str, **kwargs) -> "_BaseParamOffloader":
|
||||||
return {
|
return {
|
||||||
|
"meta": _MetaParamOffloader,
|
||||||
"cpu": _CpuParamOffloader,
|
"cpu": _CpuParamOffloader,
|
||||||
"shm_cpu": _ShmCpuParamOffloader,
|
"shm_cpu": _ShmCpuParamOffloader,
|
||||||
"sharded_gpu": _ShardedGpuParamOffloader,
|
"sharded_gpu": _ShardedGpuParamOffloader,
|
||||||
@@ -341,6 +342,17 @@ class _BaseParamOffloader(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class _MetaParamOffloader(_BaseParamOffloader):
    """Offloader that moves the parameter to the meta device. Usually used for debugging.

    ``create_device_tensor`` hands back an uninitialized CUDA tensor, so the
    values it returns are not the original weights.
    """

    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        # Replace the module's parameter with its meta-device counterpart.
        _move_param_to_meta(module, param_name)

    def create_device_tensor(self):
        """Return an uninitialized CUDA tensor laid out like the parameter's data."""
        template = self._param.data
        return torch.empty_like(template, device="cuda")
|
||||||
|
|
||||||
|
|
||||||
class _CpuParamOffloader(_BaseParamOffloader):
|
class _CpuParamOffloader(_BaseParamOffloader):
|
||||||
def __init__(self, module, param_name):
|
def __init__(self, module, param_name):
|
||||||
super().__init__(module, param_name)
|
super().__init__(module, param_name)
|
||||||
@@ -431,3 +443,106 @@ def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
|
|||||||
device=device,
|
device=device,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------- ShardedGpu ------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# TODO unify with ShmCpu mode
|
||||||
|
class _ShardedGpuParamOffloader(_BaseParamOffloader):
    """Offloader that keeps one equal-sized shard of the parameter per rank on GPU.

    Flow, as implemented below:

    * ``__init__``: rank 0 moves the parameter to pinned CPU memory, every other
      rank moves it to the meta device.
    * ``post_init``: the parameter is cut into ``world_size`` even chunks along
      dim 0; each rank allocates a CUDA buffer for one chunk, the buffers are
      exchanged through ``_create_shared_buffer_tensors`` and rank 0 scatters
      the real data into them.
    * ``create_device_tensor``: a fresh CUDA tensor is filled by copying every
      rank's shard into the matching chunk.

    Only tensor-parallel size 1 and contiguous parameters are supported.
    """

    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        self._rank = get_naive_distributed().get_rank()
        self._world_size = get_naive_distributed().get_world_size()

        # Function-scope import.
        # NOTE(review): presumably avoids an import cycle - confirm.
        from sglang.srt.distributed import get_tensor_model_parallel_world_size

        assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
        self._assert_param_contiguous()

        # Only rank 0 keeps the real values (on pinned host memory); the other
        # ranks receive their shard later via the scatter in post_init.
        if self._rank != 0:
            _move_param_to_meta(self._module, self._param_name)
        else:
            _move_param_to_cpu(self._param, pin_memory=True)

        # Filled in by post_init: one tensor per rank, indexed by rank.
        self.sharded_param_handles = None

    def _assert_param_contiguous(self):
        """Fail with a descriptive message if the parameter data is not contiguous."""
        assert (
            self._param.data.is_contiguous()
        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"

    def post_init(self):
        """Distribute rank 0's copy of the parameter as one GPU shard per rank."""
        # check again since it may be changed
        self._assert_param_contiguous()

        scatter_src = self._param.data

        logger.info(
            f"[offloader] post_init {scatter_src.nbytes=} {scatter_src.dtype=} {scatter_src.shape=} {torch.cuda.memory_allocated()=}"
        )

        is_root = self._rank == 0

        # Every rank chunks its view of the parameter so that the shard
        # shape/dtype is known everywhere; only rank 0 chunks a CUDA copy
        # holding real data.
        scatter_list = _even_chunk(
            scatter_src.to("cuda") if is_root else scatter_src,
            self._world_size,
        )

        shard_template = scatter_list[0]
        sharded_param = torch.empty(
            shard_template.shape, dtype=shard_template.dtype, device="cuda"
        )

        # Exchange the buffers first, then fill them through the scatter.
        self.sharded_param_handles = _create_shared_buffer_tensors(
            local_tensor=sharded_param
        )
        get_naive_distributed().scatter(
            sharded_param, scatter_list if is_root else None
        )

        # The full parameter is no longer needed on any rank.
        _move_param_to_meta(self._module, self._param_name)

    def create_device_tensor(self):
        """Assemble the full parameter on CUDA from every rank's shard."""
        output = _empty_strided_like(self._param, device="cuda")
        output_chunks = output.chunk(self._world_size)

        num_ranks = self._world_size
        # Start with this rank's own shard and wrap around, so different ranks
        # visit the shards in a different (rotated) order.
        for step in range(num_ranks):
            src_rank = (self._rank + step) % num_ranks
            output_chunks[src_rank].copy_(self.sharded_param_handles[src_rank])

        return output
|
||||||
|
|
||||||
|
|
||||||
|
def _even_chunk(x: torch.Tensor, chunks: int):
    """Split ``x`` along dim 0 into exactly ``chunks`` equally sized pieces.

    Asserts that the size of dim 0 is a multiple of ``chunks``; returns the
    pieces as a list.
    """
    remainder = x.shape[0] % chunks
    assert remainder == 0, f"{x.shape=} {chunks=}"
    pieces = torch.chunk(x, chunks, dim=0)
    return [*pieces]
|
||||||
|
|
||||||
|
|
||||||
|
def _create_shared_buffer_tensors(local_tensor: torch.Tensor) -> List[torch.Tensor]:
    """Exchange ``local_tensor`` with every other rank and return one tensor per rank.

    Each rank serializes ``local_tensor`` once for every *other* rank (its own
    slot is ``None``) and publishes the list through ``all_gather_object``.
    Each rank then picks, from every peer's list, the entry addressed to itself
    and deserializes it.

    Returns:
        A list indexed by rank. The entry for the calling rank is
        ``local_tensor`` itself; every other entry is the deserialized tensor
        published by that rank.
    """
    self_rank = get_naive_distributed().get_rank()
    world_size = get_naive_distributed().get_world_size()

    # One separately serialized copy per receiving rank.
    # NOTE(review): the "dup" naming suggests the duplication is deliberate
    # (one handle per consumer) - confirm before collapsing to a single call.
    outgoing = []
    for receiver_rank in range(world_size):
        if receiver_rank == self_rank:
            outgoing.append(None)
        else:
            outgoing.append(MultiprocessingSerializer.serialize(local_tensor))

    gathered = get_naive_distributed().all_gather_object(
        {"dup_serialized_local_tensor": outgoing}
    )

    tensors_by_rank = []
    for sender_rank in range(world_size):
        # The entry that rank `sender_rank` prepared for this rank.
        payload = gathered[sender_rank]["dup_serialized_local_tensor"][self_rank]
        if sender_rank != self_rank:
            tensors_by_rank.append(MultiprocessingSerializer.deserialize(payload))
            continue
        assert payload is None
        tensors_by_rank.append(local_tensor)

    return tensors_by_rank
|
||||||
|
|||||||
Reference in New Issue
Block a user