Overlapped weight offload (#8034)
This commit is contained in:
83
python/sglang/srt/host_shared_memory.py
Normal file
83
python/sglang/srt/host_shared_memory.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import shared_memory
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.distributed.naive_distributed import get_naive_distributed
|
||||
from sglang.srt.utils import check_cuda_result
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HostSharedMemoryManager:
|
||||
def __init__(self, base_name: str):
|
||||
self._base_name = Path(base_name)
|
||||
self._operation_index = 0
|
||||
self._records: List[_Record] = []
|
||||
|
||||
def malloc(self, *, shape, dtype):
|
||||
meta_tensor = torch.empty(size=shape, dtype=dtype, device="meta")
|
||||
raw = self._malloc_raw(num_bytes=meta_tensor.nbytes)
|
||||
return raw.view(dtype).view(*shape)
|
||||
|
||||
def _malloc_raw(self, *, num_bytes: int) -> torch.Tensor:
|
||||
import cuda.bindings.runtime as cuda_rt
|
||||
|
||||
self._operation_index += 1
|
||||
shm_name = f"{self._base_name}_op{self._operation_index}"
|
||||
|
||||
# TODO handle dispose
|
||||
if get_naive_distributed().get_rank() == 0:
|
||||
shm = shared_memory.SharedMemory(name=shm_name, create=True, size=num_bytes)
|
||||
|
||||
get_naive_distributed().barrier()
|
||||
|
||||
if get_naive_distributed().get_rank() != 0:
|
||||
shm = shared_memory.SharedMemory(name=shm_name)
|
||||
|
||||
np_array = np.ndarray((num_bytes,), dtype=np.uint8, buffer=shm.buf)
|
||||
tensor = torch.from_numpy(np_array)
|
||||
|
||||
check_cuda_result(
|
||||
cuda_rt.cudaHostRegister(
|
||||
tensor.data_ptr(), num_bytes, cuda_rt.cudaHostRegisterPortable
|
||||
)
|
||||
)
|
||||
|
||||
get_naive_distributed().barrier()
|
||||
|
||||
self._records.append(
|
||||
_Record(
|
||||
shm=shm,
|
||||
np_array=np_array,
|
||||
tensor=tensor,
|
||||
)
|
||||
)
|
||||
return tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Record:
|
||||
shm: shared_memory.SharedMemory
|
||||
np_array: np.ndarray
|
||||
tensor: torch.Tensor
|
||||
|
||||
|
||||
# Can have multi instances if needed
|
||||
_instance: Optional[HostSharedMemoryManager] = None
|
||||
|
||||
|
||||
def get_host_shared_memory_manager():
|
||||
assert _instance is not None
|
||||
return _instance
|
||||
|
||||
|
||||
def set_host_shared_memory_manager(instance: HostSharedMemoryManager):
|
||||
global _instance
|
||||
assert _instance is None
|
||||
_instance = instance
|
||||
Reference in New Issue
Block a user