84 lines
2.2 KiB
Python
84 lines
2.2 KiB
Python
import logging
|
|
import os
|
|
from dataclasses import dataclass
|
|
from multiprocessing import shared_memory
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from sglang.srt.distributed.naive_distributed import get_naive_distributed
|
|
from sglang.srt.utils import check_cuda_result
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class HostSharedMemoryManager:
|
|
def __init__(self, base_name: str):
|
|
self._base_name = Path(base_name)
|
|
self._operation_index = 0
|
|
self._records: List[_Record] = []
|
|
|
|
def malloc(self, *, shape, dtype):
|
|
meta_tensor = torch.empty(size=shape, dtype=dtype, device="meta")
|
|
raw = self._malloc_raw(num_bytes=meta_tensor.nbytes)
|
|
return raw.view(dtype).view(*shape)
|
|
|
|
def _malloc_raw(self, *, num_bytes: int) -> torch.Tensor:
|
|
import cuda.bindings.runtime as cuda_rt
|
|
|
|
self._operation_index += 1
|
|
shm_name = f"{self._base_name}_op{self._operation_index}"
|
|
|
|
# TODO handle dispose
|
|
if get_naive_distributed().get_rank() == 0:
|
|
shm = shared_memory.SharedMemory(name=shm_name, create=True, size=num_bytes)
|
|
|
|
get_naive_distributed().barrier()
|
|
|
|
if get_naive_distributed().get_rank() != 0:
|
|
shm = shared_memory.SharedMemory(name=shm_name)
|
|
|
|
np_array = np.ndarray((num_bytes,), dtype=np.uint8, buffer=shm.buf)
|
|
tensor = torch.from_numpy(np_array)
|
|
|
|
check_cuda_result(
|
|
cuda_rt.cudaHostRegister(
|
|
tensor.data_ptr(), num_bytes, cuda_rt.cudaHostRegisterPortable
|
|
)
|
|
)
|
|
|
|
get_naive_distributed().barrier()
|
|
|
|
self._records.append(
|
|
_Record(
|
|
shm=shm,
|
|
np_array=np_array,
|
|
tensor=tensor,
|
|
)
|
|
)
|
|
return tensor
|
|
|
|
|
|
@dataclass
|
|
class _Record:
|
|
shm: shared_memory.SharedMemory
|
|
np_array: np.ndarray
|
|
tensor: torch.Tensor
|
|
|
|
|
|
# Can have multi instances if needed
|
|
_instance: Optional[HostSharedMemoryManager] = None
|
|
|
|
|
|
def get_host_shared_memory_manager():
|
|
assert _instance is not None
|
|
return _instance
|
|
|
|
|
|
def set_host_shared_memory_manager(instance: HostSharedMemoryManager):
|
|
global _instance
|
|
assert _instance is None
|
|
_instance = instance
|