diff --git a/python/sglang/srt/distributed/naive_distributed.py b/python/sglang/srt/distributed/naive_distributed.py
new file mode 100644
index 000000000..61165d90c
--- /dev/null
+++ b/python/sglang/srt/distributed/naive_distributed.py
@@ -0,0 +1,112 @@
+import base64
+import os
+import pickle
+import time
+from pathlib import Path
+from typing import Any, List, Optional
+
+import torch
+
+from sglang.srt.utils import MultiprocessingSerializer
+
+
+class NaiveDistributed:
+    def __init__(self, rank: int, world_size: int, rendezvous: str):
+        self._rank = rank
+        self._world_size = world_size
+        self._operation_index = 0
+        self._directory = Path(rendezvous)
+        self._directory.mkdir(parents=True, exist_ok=True)
+        assert 0 <= rank < world_size
+
+        # barrier both to be safe and as a sanity check
+        self.barrier()
+
+    def get_rank(self):
+        return self._rank
+
+    def get_world_size(self):
+        return self._world_size
+
+    def scatter(
+        self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0
+    ):
+        if self._rank == src:
+            assert len(scatter_list) == self._world_size
+        else:
+            assert scatter_list is None
+
+        gathered_objects = self.all_gather_object(
+            dict(
+                serialized_scatter_list=[
+                    (
+                        None
+                        if item_rank == src
+                        else MultiprocessingSerializer.serialize(item)
+                    )
+                    for item_rank, item in enumerate(scatter_list)
+                ]
+            )
+            if self._rank == src
+            else dict()
+        )
+
+        remote_serialized_tensor = gathered_objects[src]["serialized_scatter_list"][
+            self._rank
+        ]
+        if self._rank == src:
+            assert remote_serialized_tensor is None
+            remote_tensor = scatter_list[self._rank]
+        else:
+            remote_tensor = MultiprocessingSerializer.deserialize(
+                remote_serialized_tensor
+            )
+        tensor.copy_(remote_tensor)
+
+        # avoid the src tensor being deleted too early
+        self.barrier()
+
+    def all_gather_object(self, obj: Any) -> List[Any]:
+        self._operation_index += 1
+
+        text_postfix = "\n"
+
+        def _get_path(interesting_rank: int):
+            return (
+                self._directory
+                / f"rank{interesting_rank}_op{self._operation_index}.txt"
+            )
+
+        _get_path(self._rank).write_text(
+            base64.b64encode(pickle.dumps(obj)).decode("utf-8") + text_postfix
+        )
+
+        def _read_one(interesting_rank: int):
+            p = _get_path(interesting_rank)
+            while True:
+                if p.exists() and (text := p.read_text()).endswith(text_postfix):
+                    return pickle.loads(base64.b64decode(text[: -len(text_postfix)]))
+                time.sleep(0.001)
+
+        return [
+            _read_one(interesting_rank) for interesting_rank in range(self._world_size)
+        ]
+
+    def barrier(self):
+        actual_objs = self.all_gather_object(self._rank)
+        assert actual_objs == list(range(self._world_size)), f"{actual_objs=}"
+
+
+# Can have multiple instances if needed
+_instance: Optional[NaiveDistributed] = None
+
+
+def get_naive_distributed():
+    assert _instance is not None
+    return _instance
+
+
+def set_naive_distributed(instance: NaiveDistributed):
+    global _instance
+    assert _instance is None
+    _instance = instance
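For reference, a minimal sketch of exercising `NaiveDistributed` on its own: ranks rendezvous through a shared directory and exchange picklable objects by polling files. The two-process layout and the temp-directory rendezvous are this sketch's assumptions, not part of the PR.

```python
import multiprocessing as mp
import tempfile

from sglang.srt.distributed.naive_distributed import NaiveDistributed


def _worker(rank: int, world_size: int, rendezvous: str):
    dist = NaiveDistributed(rank=rank, world_size=world_size, rendezvous=rendezvous)
    gathered = dist.all_gather_object({"rank": rank})
    # Results come back in rank order, one entry per rank.
    assert [g["rank"] for g in gathered] == list(range(world_size))
    dist.barrier()


if __name__ == "__main__":
    world_size = 2
    rendezvous = tempfile.mkdtemp()
    procs = [
        mp.Process(target=_worker, args=(r, world_size, rendezvous))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```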
os.environ["TRTLLM_ENABLE_PDL"] = "1" + # Can also be passed as argument + os.environ["SGLANG_RUN_ID"] = ( + f"sglang-run-{time.time()}-{random.randint(0, 100000000)}" + ) + # Set prometheus env vars if server_args.enable_metrics: set_prometheus_multiproc_dir() diff --git a/python/sglang/srt/host_shared_memory.py b/python/sglang/srt/host_shared_memory.py new file mode 100644 index 000000000..c599527f9 --- /dev/null +++ b/python/sglang/srt/host_shared_memory.py @@ -0,0 +1,83 @@ +import logging +import os +from dataclasses import dataclass +from multiprocessing import shared_memory +from pathlib import Path +from typing import List, Optional + +import numpy as np +import torch + +from sglang.srt.distributed.naive_distributed import get_naive_distributed +from sglang.srt.utils import check_cuda_result + +logger = logging.getLogger(__name__) + + +class HostSharedMemoryManager: + def __init__(self, base_name: str): + self._base_name = Path(base_name) + self._operation_index = 0 + self._records: List[_Record] = [] + + def malloc(self, *, shape, dtype): + meta_tensor = torch.empty(size=shape, dtype=dtype, device="meta") + raw = self._malloc_raw(num_bytes=meta_tensor.nbytes) + return raw.view(dtype).view(*shape) + + def _malloc_raw(self, *, num_bytes: int) -> torch.Tensor: + import cuda.bindings.runtime as cuda_rt + + self._operation_index += 1 + shm_name = f"{self._base_name}_op{self._operation_index}" + + # TODO handle dispose + if get_naive_distributed().get_rank() == 0: + shm = shared_memory.SharedMemory(name=shm_name, create=True, size=num_bytes) + + get_naive_distributed().barrier() + + if get_naive_distributed().get_rank() != 0: + shm = shared_memory.SharedMemory(name=shm_name) + + np_array = np.ndarray((num_bytes,), dtype=np.uint8, buffer=shm.buf) + tensor = torch.from_numpy(np_array) + + check_cuda_result( + cuda_rt.cudaHostRegister( + tensor.data_ptr(), num_bytes, cuda_rt.cudaHostRegisterPortable + ) + ) + + get_naive_distributed().barrier() + + self._records.append( + _Record( + shm=shm, + np_array=np_array, + tensor=tensor, + ) + ) + return tensor + + +@dataclass +class _Record: + shm: shared_memory.SharedMemory + np_array: np.ndarray + tensor: torch.Tensor + + +# Can have multi instances if needed +_instance: Optional[HostSharedMemoryManager] = None + + +def get_host_shared_memory_manager(): + assert _instance is not None + return _instance + + +def set_host_shared_memory_manager(instance: HostSharedMemoryManager): + global _instance + assert _instance is None + _instance = instance diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 77dac1ea6..968be171d 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -92,6 +92,7 @@ class TpModelWorker: pp_rank=pp_rank, pp_size=server_args.pp_size, nccl_port=nccl_port, + dp_rank=dp_rank, server_args=server_args, is_draft_worker=is_draft_worker, req_to_token_pool=req_to_token_pool, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index acfeaee3d..293dba061 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -172,6 +172,7 @@ class ModelRunner: pp_size: int, nccl_port: int, server_args: ServerArgs, + dp_rank: Optional[int] = None, is_draft_worker: bool = False, req_to_token_pool: Optional[ReqToTokenPool] = None, token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None, @@ -234,7 +235,7 @@ class ModelRunner: 
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 77dac1ea6..968be171d 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -92,6 +92,7 @@ class TpModelWorker:
             pp_rank=pp_rank,
             pp_size=server_args.pp_size,
             nccl_port=nccl_port,
+            dp_rank=dp_rank,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
             req_to_token_pool=req_to_token_pool,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index acfeaee3d..293dba061 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -172,6 +172,7 @@ class ModelRunner:
         pp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        dp_rank: Optional[int] = None,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
         token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
@@ -234,7 +235,7 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()
 
         # CPU offload
-        set_offloader(create_offloader_from_server_args(server_args))
+        set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
 
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 95b962fa3..bf22528f0 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1996,6 +1996,23 @@ class DeepseekV2Model(nn.Module):
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
+            offloader_kwargs=dict(
+                submodule_accessor=lambda layer: (
+                    layer.mlp.experts
+                    if isinstance(layer.mlp, DeepseekV2MoE)
+                    else layer.mlp
+                ),
+                whitelist_param_names_creator=lambda module: (
+                    [
+                        "w13_weight",
+                        "w2_weight",
+                        "w13_blockscale_swizzled",
+                        "w2_blockscale_swizzled",
+                    ]
+                    if isinstance(module, FusedMoE)
+                    else []
+                ),
+            ),
         )
         if self.pp_group.is_last_rank:
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
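The two `offloader_kwargs` callbacks form a small contract: the accessor picks the heavy submodule of each layer (the fused experts for MoE layers), and the name creator whitelists which of its parameters participate in offloading. A toy illustration of that contract (`ToyLayer`, `accessor`, and `names_creator` are hypothetical stand-ins):

```python
from torch import nn


class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(4, 4, bias=False)


def accessor(layer):
    return layer.mlp  # analogous to layer.mlp.experts for MoE layers


def names_creator(module):
    return ["weight"] if isinstance(module, nn.Linear) else []


layer = ToyLayer()
submodule = accessor(layer)
print(names_creator(submodule))  # ['weight'] -> only this tensor is offloaded
```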
diff --git a/python/sglang/srt/offloader.py b/python/sglang/srt/offloader.py
index f7bf4082b..b7b06cf71 100644
--- a/python/sglang/srt/offloader.py
+++ b/python/sglang/srt/offloader.py
@@ -1,12 +1,24 @@
 import logging
+import os
 from abc import ABC
 from typing import Callable, Generator, List, Optional
 
 import torch
 from torch.func import functional_call
 
+from sglang.srt.distributed.naive_distributed import (
+    NaiveDistributed,
+    get_naive_distributed,
+    set_naive_distributed,
+)
+from sglang.srt.host_shared_memory import (
+    HostSharedMemoryManager,
+    get_host_shared_memory_manager,
+    set_host_shared_memory_manager,
+)
+from sglang.srt.layers.parameter import ModelWeightParameter
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import is_pin_memory_available
+from sglang.srt.utils import MultiprocessingSerializer, is_pin_memory_available
 
 logger = logging.getLogger(__name__)
 
@@ -45,11 +57,23 @@ def set_offloader(instance: BaseOffloader):
     _instance = instance
 
 
-def create_offloader_from_server_args(server_args: ServerArgs):
+def create_offloader_from_server_args(server_args: ServerArgs, dp_rank: int):
     if server_args.cpu_offload_gb > 0:
         return OffloaderV1(
             cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
         )
+    if server_args.offload_group_size > 0:
+        assert (
+            server_args.cpu_offload_gb == 0
+        ), "V2 offload does not support cpu_offload_gb yet"
+        return OffloaderV2(
+            group_size=server_args.offload_group_size,
+            num_in_group=server_args.offload_num_in_group,
+            prefetch_step=server_args.offload_prefetch_step,
+            mode=server_args.offload_mode,
+            dp_rank=dp_rank,
+            dp_size=server_args.dp_size,
+        )
     return NoopOffloader()
 
 
@@ -120,3 +144,290 @@ class OffloaderV1(BaseOffloader):
         module.forward = forward
         return module
+
+
+class OffloaderV2(BaseOffloader):
+    def __init__(
+        self,
+        group_size: int,
+        num_in_group: int,
+        prefetch_step: int,
+        mode: str,
+        dp_rank: int,
+        dp_size: int,
+    ):
+        self.group_size = group_size
+        self.num_in_group = num_in_group
+        self.prefetch_step = prefetch_step
+        self.mode = mode
+
+        run_id = os.environ["SGLANG_RUN_ID"]
+
+        # Temporarily initialized inside the Offloader; can be moved if other modules also need this
+        if self.mode in {"sharded_gpu", "shm_cpu"}:
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+
+            assert (
+                get_tensor_model_parallel_world_size() == 1
+            ), "tp_size != 1 is not yet supported"
+            set_naive_distributed(
+                NaiveDistributed(
+                    rank=dp_rank,
+                    world_size=dp_size,
+                    rendezvous=f"/tmp/{run_id}",
+                )
+            )
+        if self.mode in {"shm_cpu"}:
+            set_host_shared_memory_manager(
+                HostSharedMemoryManager(
+                    base_name=run_id,
+                )
+            )
+
+        self.offloaders = []
+
+    def wrap_modules(
+        self,
+        all_modules_generator: Generator[torch.nn.Module, None, None],
+        submodule_accessor: Optional[_SubmoduleAccessor] = None,
+        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
+    ):
+        assert len(self.offloaders) == 0, "should only call wrap_modules once"
+
+        alt_stream = torch.cuda.Stream()
+
+        all_modules = []
+        offload_submodules = []
+        for module_index, module in enumerate(all_modules_generator):
+            all_modules.append(module)
+            if module_index % self.group_size >= self.group_size - self.num_in_group:
+                submodule = submodule_accessor(module)
+                whitelist_param_names = whitelist_param_names_creator(submodule)
+                logger.info(
+                    f"[offloader] offload {module_index=} submodule={type(submodule)} params={whitelist_param_names} memory_allocated={torch.cuda.memory_allocated()}"
+                )
+                offload_submodules.append(submodule)
+                self.offloaders.append(
+                    _ModuleOffloader(
+                        mode=self.mode,
+                        module=submodule,
+                        alt_stream=alt_stream,
+                        whitelist_param_names=whitelist_param_names,
+                    )
+                )
+
+        for index, module in enumerate(offload_submodules):
+            _hook_module_forward_for_offloader(
+                index=index,
+                module=module,
+                offloaders=self.offloaders,
+                prefetch_step=self.prefetch_step,
+            )
+
+        return all_modules
+
+    def post_init(self):
+        for offloader in self.offloaders:
+            offloader.post_init()
+
+        for i in range(self.prefetch_step):
+            self.offloaders[i].start_onload()
+
+
+def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
+    def _on_forward_end():
+        offloaders[(index + prefetch_step) % len(offloaders)].start_onload()
+        offloaders[index].offload()
+
+    _hook_module_forward_raw(
+        module,
+        on_forward_end=_on_forward_end,
+        get_parameter_and_buffer_dicts=lambda: offloaders[
+            index
+        ].wait_and_get_device_tensors(),
+    )
+
+
+def _hook_module_forward_raw(module, on_forward_end, get_parameter_and_buffer_dicts):
+    original_forward = module.forward
+
+    def forward(*args, **kwargs):
+        module.forward = original_forward
+        output = functional_call(
+            module, get_parameter_and_buffer_dicts(), args=args, kwargs=kwargs
+        )
+        on_forward_end()
+        module.forward = forward
+        return output
+
+    module.forward = forward
+
+
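The hook above temporarily restores the original `forward`, routes the call through `torch.func.functional_call` with externally supplied tensors (the freshly onloaded weights), then re-installs itself. A self-contained toy showing that mechanism, a sketch rather than the PR's exact wiring:

```python
import torch
from torch import nn
from torch.func import functional_call

module = nn.Linear(2, 2, bias=False)
onloaded_weight = torch.eye(2)  # stands in for a freshly onloaded GPU tensor

x = torch.tensor([[1.0, 2.0]])
# The module's registered weight is ignored for this call; the supplied
# tensor is used instead.
out = functional_call(module, {"weight": onloaded_weight}, args=(x,))
assert torch.allclose(out, x)  # identity weight was used, not module.weight
```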
+class _ModuleOffloader(ABC):
+    def __init__(
+        self,
+        mode: str,
+        module: torch.nn.Module,
+        alt_stream: torch.cuda.Stream,
+        whitelist_param_names: List[str],
+    ):
+        self.mode = mode
+        self.module = module
+        self.device = next(module.parameters()).device
+        self.alt_stream = alt_stream
+
+        assert self.device != torch.device(
+            "cpu"
+        ), "device=cpu is not handled yet (should skip this tensor)"
+
+        self._device_tensors = None
+        self._load_event = None
+
+        param_dict = dict(self.module.named_parameters())
+        assert all(
+            name in param_dict for name in whitelist_param_names
+        ), f"{whitelist_param_names=} {list(param_dict.keys())=}"
+
+        self._param_offloaders = {
+            name: _BaseParamOffloader.create(mode, module=module, param_name=name)
+            for name in whitelist_param_names
+        }
+
+    def post_init(self):
+        for name, param_offloader in self._param_offloaders.items():
+            param_offloader.post_init()
+
+    def start_onload(self):
+        self.alt_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(self.alt_stream):
+            self._device_tensors = self._create_device_tensors()
+            self._load_event = torch.cuda.Event()
+            self._load_event.record()
+
+    def offload(self):
+        self._device_tensors = None
+        self._load_event = None
+
+    def wait_and_get_device_tensors(self):
+        assert self._device_tensors is not None
+        self._load_event.wait()
+        return self._device_tensors
+
+    def _create_device_tensors(self):
+        return {k: v.create_device_tensor() for k, v in self._param_offloaders.items()}
+
+
+class _BaseParamOffloader(ABC):
+    @staticmethod
+    def create(mode: str, **kwargs) -> "_BaseParamOffloader":
+        return {
+            "cpu": _CpuParamOffloader,
+            "shm_cpu": _ShmCpuParamOffloader,
+            "sharded_gpu": _ShardedGpuParamOffloader,
+        }[mode](**kwargs)
+
+    def __init__(self, module, param_name):
+        self._module = module
+        self._param_name = param_name
+
+    @property
+    def _param(self):
+        return getattr(self._module, self._param_name)
+
+    def post_init(self):
+        pass
+
+    def create_device_tensor(self):
+        raise NotImplementedError
+
+
+class _CpuParamOffloader(_BaseParamOffloader):
+    def __init__(self, module, param_name):
+        super().__init__(module, param_name)
+        _move_param_to_cpu(self._param, pin_memory=True)
+
+    def create_device_tensor(self):
+        return self._param.to("cuda", non_blocking=True)
+
+
+class _ShmCpuParamOffloader(_BaseParamOffloader):
+    def __init__(self, module, param_name):
+        super().__init__(module, param_name)
+        self._rank = get_naive_distributed().get_rank()
+        self._world_size = get_naive_distributed().get_world_size()
+
+        from sglang.srt.distributed import get_tensor_model_parallel_world_size
+
+        assert get_tensor_model_parallel_world_size() == 1, "tp_size != 1 is not yet supported"
+        assert (
+            self._param.data.is_contiguous()
+        ), f"non-contiguous tensors are not yet supported {self._param.shape=} {self._param.stride()=}"
+
+        self.shm_cpu_data = get_host_shared_memory_manager().malloc(
+            shape=self._param.shape, dtype=self._param.dtype
+        )
+
+        if self._rank == 0:
+            self.shm_cpu_data.copy_(self._param.data.to("cpu"))
+            self._param.data = self.shm_cpu_data
+        else:
+            _move_param_to_meta(self._module, self._param_name)
+        get_naive_distributed().barrier()
+
+    def post_init(self):
+        if self._rank == 0:
+            assert (
+                self.shm_cpu_data.data_ptr() == self._param.data.data_ptr()
+            ), f"{self.shm_cpu_data.data_ptr()=} {self._param.data.data_ptr()=} {self.shm_cpu_data=} {self._param.data=}"
+
+            _move_param_to_meta(self._module, self._param_name)
+
+    def create_device_tensor(self):
+        return self.shm_cpu_data.to("cuda", non_blocking=True)
+
+
+def _move_param_to_cpu(param, pin_memory: bool):
+    cpu_data = _empty_strided_like(
+        param.data,
+        device="cpu",
+        pin_memory=pin_memory,
+    )
+    cpu_data.copy_(param.data)
+    param.data = cpu_data
+
+
+def _move_param_to_meta(module, param_name):
+    old_param = getattr(module, param_name)
+    old_param_type = type(old_param)
+
+    new_data = old_param.data.to("meta")
+
+    if old_param_type == ModelWeightParameter:
+        # manually checked how `w13_weight` and `w2_weight` are constructed
+        new_param = ModelWeightParameter(
+            data=new_data,
+            **{
+                k: getattr(old_param, k)
+                for k in ["input_dim", "output_dim", "weight_loader"]
+            },
+        )
+    elif old_param_type == torch.nn.Parameter:
+        new_param = torch.nn.Parameter(
+            data=new_data,
+            requires_grad=False,
+        )
+    else:
+        raise ValueError(f"Unknown {old_param_type=} {old_param=}")
+
+    setattr(module, param_name, new_param)
+
+
+def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
+    return torch.empty_strided(
+        size=x.size(),
+        stride=x.stride(),
+        dtype=x.dtype,
+        layout=x.layout,
+        device=device,
+        pin_memory=pin_memory,
+    )
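For intuition about the grouping knobs: with `group_size` G and `num_in_group` K, the last K modules of every G-module group are offloaded, and after offloaded module i finishes its forward, module (i + `prefetch_step`) mod n begins onloading on the side stream. A quick check of the selection predicate used in `wrap_modules` (the 12-layer count is illustrative):

```python
# Which module indices are offloaded for group_size=4, num_in_group=1:
group_size, num_in_group = 4, 1
offloaded = [i for i in range(12) if i % group_size >= group_size - num_in_group]
print(offloaded)  # [3, 7, 11] -> the last layer of every group of four
```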
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index d32227390..a2e532096 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -85,7 +85,6 @@ class ServerArgs:
     max_prefill_tokens: int = 16384
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
-    cpu_offload_gb: int = 0
     page_size: Optional[int] = None
     hybrid_kvcache_ratio: Optional[float] = None
     swa_full_tokens_ratio: float = 0.8
@@ -226,6 +225,13 @@ class ServerArgs:
     ds_heavy_channel_type: str = "qk"
     ds_sparse_decode_threshold: int = 4096
 
+    # Offloading
+    cpu_offload_gb: int = 0
+    offload_group_size: int = -1
+    offload_num_in_group: int = 1
+    offload_prefetch_step: int = 1
+    offload_mode: str = "cpu"
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     cuda_graph_max_bs: Optional[int] = None
@@ -976,12 +982,6 @@ class ServerArgs:
             default=ServerArgs.schedule_conservativeness,
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
-        parser.add_argument(
-            "--cpu-offload-gb",
-            type=int,
-            default=ServerArgs.cpu_offload_gb,
-            help="How many GBs of RAM to reserve for CPU offloading.",
-        )
         parser.add_argument(
             "--page-size",
             type=int,
@@ -1683,6 +1683,38 @@ class ServerArgs:
             help="The type of heavy channels in double sparsity attention",
         )
 
+        # Offloading
+        parser.add_argument(
+            "--cpu-offload-gb",
+            type=int,
+            default=ServerArgs.cpu_offload_gb,
+            help="How many GBs of RAM to reserve for CPU offloading.",
+        )
+        parser.add_argument(
+            "--offload-group-size",
+            type=int,
+            default=ServerArgs.offload_group_size,
+            help="Number of layers per group in offloading.",
+        )
+        parser.add_argument(
+            "--offload-num-in-group",
+            type=int,
+            default=ServerArgs.offload_num_in_group,
+            help="Number of layers to offload within each group.",
+        )
+        parser.add_argument(
+            "--offload-prefetch-step",
+            type=int,
+            default=ServerArgs.offload_prefetch_step,
+            help="Number of steps to prefetch ahead in offloading.",
+        )
+        parser.add_argument(
+            "--offload-mode",
+            type=str,
+            default=ServerArgs.offload_mode,
+            help="Offloading mode: cpu, shm_cpu, or sharded_gpu.",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 6979be0d4..cb5e4cd1e 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -2954,3 +2954,13 @@ class ConcurrentCounter:
 @lru_cache(maxsize=1)
 def is_triton_kernels_available() -> bool:
     return importlib.util.find_spec("triton_kernels") is not None
+
+
+def check_cuda_result(raw_output):
+    import cuda.bindings.runtime as cuda_rt
+
+    err, *results = raw_output
+    if err != cuda_rt.cudaError_t.cudaSuccess:
+        raise Exception(f"CUDA error: {err}")
+
+    return results
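A usage sketch for the new `check_cuda_result` helper, assuming the `cuda-python` package and a visible CUDA device (`cudaMemGetInfo` is just a convenient call that returns an error code followed by results):

```python
import cuda.bindings.runtime as cuda_rt

from sglang.srt.utils import check_cuda_result

# cudaMemGetInfo returns (err, free, total); the helper strips the error
# code, raises if it is not cudaSuccess, and returns the remaining values.
free, total = check_cuda_result(cuda_rt.cudaMemGetInfo())
print(f"{free=} {total=}")
```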