Overlapped weight offload (#8034)
This commit is contained in:
112
python/sglang/srt/distributed/naive_distributed.py
Normal file
112
python/sglang/srt/distributed/naive_distributed.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.utils import MultiprocessingSerializer
|
||||||
|
|
||||||
|
|
||||||
|
class NaiveDistributed:
|
||||||
|
def __init__(self, rank: int, world_size: int, rendezvous: str):
|
||||||
|
self._rank = rank
|
||||||
|
self._world_size = world_size
|
||||||
|
self._operation_index = 0
|
||||||
|
self._directory = Path(rendezvous)
|
||||||
|
self._directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
assert 0 <= rank < world_size
|
||||||
|
|
||||||
|
# both barrier to be safe, and as a sanity check
|
||||||
|
self.barrier()
|
||||||
|
|
||||||
|
def get_rank(self):
|
||||||
|
return self._rank
|
||||||
|
|
||||||
|
def get_world_size(self):
|
||||||
|
return self._world_size
|
||||||
|
|
||||||
|
def scatter(
|
||||||
|
self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0
|
||||||
|
):
|
||||||
|
if self._rank == src:
|
||||||
|
assert len(scatter_list) == self._world_size
|
||||||
|
else:
|
||||||
|
assert scatter_list is None
|
||||||
|
|
||||||
|
gathered_objects = self.all_gather_object(
|
||||||
|
dict(
|
||||||
|
serialized_scatter_list=[
|
||||||
|
(
|
||||||
|
None
|
||||||
|
if item_rank == src
|
||||||
|
else MultiprocessingSerializer.serialize(item)
|
||||||
|
)
|
||||||
|
for item_rank, item in enumerate(scatter_list)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self._rank == src
|
||||||
|
else dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
remote_serialized_tensor = gathered_objects[src]["serialized_scatter_list"][
|
||||||
|
self._rank
|
||||||
|
]
|
||||||
|
if self._rank == src:
|
||||||
|
assert remote_serialized_tensor is None
|
||||||
|
remote_tensor = scatter_list[self._rank]
|
||||||
|
else:
|
||||||
|
remote_tensor = MultiprocessingSerializer.deserialize(
|
||||||
|
remote_serialized_tensor
|
||||||
|
)
|
||||||
|
tensor.copy_(remote_tensor)
|
||||||
|
|
||||||
|
# avoid src tensor be deleted too early
|
||||||
|
self.barrier()
|
||||||
|
|
||||||
|
def all_gather_object(self, obj: Any) -> List[Any]:
|
||||||
|
self._operation_index += 1
|
||||||
|
|
||||||
|
text_postfix = "\n"
|
||||||
|
|
||||||
|
def _get_path(interesting_rank: int):
|
||||||
|
return (
|
||||||
|
self._directory
|
||||||
|
/ f"rank{interesting_rank}_op{self._operation_index}.txt"
|
||||||
|
)
|
||||||
|
|
||||||
|
_get_path(self._rank).write_text(
|
||||||
|
base64.b64encode(pickle.dumps(obj)).decode("utf-8") + text_postfix
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read_one(interesting_rank: int):
|
||||||
|
p = _get_path(interesting_rank)
|
||||||
|
while True:
|
||||||
|
if p.exists() and (text := p.read_text()).endswith(text_postfix):
|
||||||
|
return pickle.loads(base64.b64decode(text[: -len(text_postfix)]))
|
||||||
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
return [
|
||||||
|
_read_one(interesting_rank) for interesting_rank in range(self._world_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
def barrier(self):
|
||||||
|
actual_objs = self.all_gather_object(self._rank)
|
||||||
|
assert actual_objs == list(range(self._world_size)), f"{actual_objs=}"
|
||||||
|
|
||||||
|
|
||||||
|
# Can have multi instances if needed
|
||||||
|
_instance: Optional[NaiveDistributed] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_naive_distributed():
|
||||||
|
assert _instance is not None
|
||||||
|
return _instance
|
||||||
|
|
||||||
|
|
||||||
|
def set_naive_distributed(instance: NaiveDistributed):
|
||||||
|
global _instance
|
||||||
|
assert _instance is None
|
||||||
|
_instance = instance
|
||||||
@@ -23,8 +23,10 @@ import dataclasses
|
|||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import signal
|
import signal
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
|
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
@@ -654,6 +656,11 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
# flashinfer uses this environment variable for various kernels from MoE to quant kernels
|
# flashinfer uses this environment variable for various kernels from MoE to quant kernels
|
||||||
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
||||||
|
|
||||||
|
# Can also be passed as argument
|
||||||
|
os.environ["SGLANG_RUN_ID"] = (
|
||||||
|
f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Set prometheus env vars
|
# Set prometheus env vars
|
||||||
if server_args.enable_metrics:
|
if server_args.enable_metrics:
|
||||||
set_prometheus_multiproc_dir()
|
set_prometheus_multiproc_dir()
|
||||||
|
|||||||
83
python/sglang/srt/host_shared_memory.py
Normal file
83
python/sglang/srt/host_shared_memory.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from multiprocessing import shared_memory
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.distributed.naive_distributed import get_naive_distributed
|
||||||
|
from sglang.srt.utils import check_cuda_result
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HostSharedMemoryManager:
|
||||||
|
def __init__(self, base_name: str):
|
||||||
|
self._base_name = Path(base_name)
|
||||||
|
self._operation_index = 0
|
||||||
|
self._records: List[_Record] = []
|
||||||
|
|
||||||
|
def malloc(self, *, shape, dtype):
|
||||||
|
meta_tensor = torch.empty(size=shape, dtype=dtype, device="meta")
|
||||||
|
raw = self._malloc_raw(num_bytes=meta_tensor.nbytes)
|
||||||
|
return raw.view(dtype).view(*shape)
|
||||||
|
|
||||||
|
def _malloc_raw(self, *, num_bytes: int) -> torch.Tensor:
|
||||||
|
import cuda.bindings.runtime as cuda_rt
|
||||||
|
|
||||||
|
self._operation_index += 1
|
||||||
|
shm_name = f"{self._base_name}_op{self._operation_index}"
|
||||||
|
|
||||||
|
# TODO handle dispose
|
||||||
|
if get_naive_distributed().get_rank() == 0:
|
||||||
|
shm = shared_memory.SharedMemory(name=shm_name, create=True, size=num_bytes)
|
||||||
|
|
||||||
|
get_naive_distributed().barrier()
|
||||||
|
|
||||||
|
if get_naive_distributed().get_rank() != 0:
|
||||||
|
shm = shared_memory.SharedMemory(name=shm_name)
|
||||||
|
|
||||||
|
np_array = np.ndarray((num_bytes,), dtype=np.uint8, buffer=shm.buf)
|
||||||
|
tensor = torch.from_numpy(np_array)
|
||||||
|
|
||||||
|
check_cuda_result(
|
||||||
|
cuda_rt.cudaHostRegister(
|
||||||
|
tensor.data_ptr(), num_bytes, cuda_rt.cudaHostRegisterPortable
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
get_naive_distributed().barrier()
|
||||||
|
|
||||||
|
self._records.append(
|
||||||
|
_Record(
|
||||||
|
shm=shm,
|
||||||
|
np_array=np_array,
|
||||||
|
tensor=tensor,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Record:
|
||||||
|
shm: shared_memory.SharedMemory
|
||||||
|
np_array: np.ndarray
|
||||||
|
tensor: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
# Can have multi instances if needed
|
||||||
|
_instance: Optional[HostSharedMemoryManager] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_host_shared_memory_manager():
|
||||||
|
assert _instance is not None
|
||||||
|
return _instance
|
||||||
|
|
||||||
|
|
||||||
|
def set_host_shared_memory_manager(instance: HostSharedMemoryManager):
|
||||||
|
global _instance
|
||||||
|
assert _instance is None
|
||||||
|
_instance = instance
|
||||||
@@ -92,6 +92,7 @@ class TpModelWorker:
|
|||||||
pp_rank=pp_rank,
|
pp_rank=pp_rank,
|
||||||
pp_size=server_args.pp_size,
|
pp_size=server_args.pp_size,
|
||||||
nccl_port=nccl_port,
|
nccl_port=nccl_port,
|
||||||
|
dp_rank=dp_rank,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
is_draft_worker=is_draft_worker,
|
is_draft_worker=is_draft_worker,
|
||||||
req_to_token_pool=req_to_token_pool,
|
req_to_token_pool=req_to_token_pool,
|
||||||
|
|||||||
@@ -172,6 +172,7 @@ class ModelRunner:
|
|||||||
pp_size: int,
|
pp_size: int,
|
||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
|
dp_rank: Optional[int] = None,
|
||||||
is_draft_worker: bool = False,
|
is_draft_worker: bool = False,
|
||||||
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
||||||
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
|
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
|
||||||
@@ -234,7 +235,7 @@ class ModelRunner:
|
|||||||
min_per_gpu_memory = self.init_torch_distributed()
|
min_per_gpu_memory = self.init_torch_distributed()
|
||||||
|
|
||||||
# CPU offload
|
# CPU offload
|
||||||
set_offloader(create_offloader_from_server_args(server_args))
|
set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
|
||||||
|
|
||||||
# Update deep gemm configure
|
# Update deep gemm configure
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||||
|
|||||||
@@ -1996,6 +1996,23 @@ class DeepseekV2Model(nn.Module):
|
|||||||
pp_rank=self.pp_group.rank_in_group,
|
pp_rank=self.pp_group.rank_in_group,
|
||||||
pp_size=self.pp_group.world_size,
|
pp_size=self.pp_group.world_size,
|
||||||
prefix=add_prefix("layers", prefix),
|
prefix=add_prefix("layers", prefix),
|
||||||
|
offloader_kwargs=dict(
|
||||||
|
submodule_accessor=lambda layer: (
|
||||||
|
layer.mlp.experts
|
||||||
|
if isinstance(layer.mlp, DeepseekV2MoE)
|
||||||
|
else layer.mlp
|
||||||
|
),
|
||||||
|
whitelist_param_names_creator=lambda module: (
|
||||||
|
[
|
||||||
|
"w13_weight",
|
||||||
|
"w2_weight",
|
||||||
|
"w13_blockscale_swizzled",
|
||||||
|
"w2_blockscale_swizzled",
|
||||||
|
]
|
||||||
|
if isinstance(module, FusedMoE)
|
||||||
|
else []
|
||||||
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if self.pp_group.is_last_rank:
|
if self.pp_group.is_last_rank:
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|||||||
@@ -1,12 +1,24 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Callable, Generator, List, Optional
|
from typing import Callable, Generator, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.func import functional_call
|
from torch.func import functional_call
|
||||||
|
|
||||||
|
from sglang.srt.distributed.naive_distributed import (
|
||||||
|
NaiveDistributed,
|
||||||
|
get_naive_distributed,
|
||||||
|
set_naive_distributed,
|
||||||
|
)
|
||||||
|
from sglang.srt.host_shared_memory import (
|
||||||
|
HostSharedMemoryManager,
|
||||||
|
get_host_shared_memory_manager,
|
||||||
|
set_host_shared_memory_manager,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.parameter import ModelWeightParameter
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import is_pin_memory_available
|
from sglang.srt.utils import MultiprocessingSerializer, is_pin_memory_available
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -45,11 +57,23 @@ def set_offloader(instance: BaseOffloader):
|
|||||||
_instance = instance
|
_instance = instance
|
||||||
|
|
||||||
|
|
||||||
def create_offloader_from_server_args(server_args: ServerArgs):
|
def create_offloader_from_server_args(server_args: ServerArgs, dp_rank: int):
|
||||||
if server_args.cpu_offload_gb > 0:
|
if server_args.cpu_offload_gb > 0:
|
||||||
return OffloaderV1(
|
return OffloaderV1(
|
||||||
cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
|
cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
|
||||||
)
|
)
|
||||||
|
if server_args.offload_group_size > 0:
|
||||||
|
assert (
|
||||||
|
server_args.cpu_offload_gb == 0
|
||||||
|
), "V2 offload does not support cpu_offload_gb yet"
|
||||||
|
return OffloaderV2(
|
||||||
|
group_size=server_args.offload_group_size,
|
||||||
|
num_in_group=server_args.offload_num_in_group,
|
||||||
|
prefetch_step=server_args.offload_prefetch_step,
|
||||||
|
mode=server_args.offload_mode,
|
||||||
|
dp_rank=dp_rank,
|
||||||
|
dp_size=server_args.dp_size,
|
||||||
|
)
|
||||||
return NoopOffloader()
|
return NoopOffloader()
|
||||||
|
|
||||||
|
|
||||||
@@ -120,3 +144,290 @@ class OffloaderV1(BaseOffloader):
|
|||||||
module.forward = forward
|
module.forward = forward
|
||||||
|
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
class OffloaderV2(BaseOffloader):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
group_size: int,
|
||||||
|
num_in_group: int,
|
||||||
|
prefetch_step: int,
|
||||||
|
mode: str,
|
||||||
|
dp_rank: int,
|
||||||
|
dp_size: int,
|
||||||
|
):
|
||||||
|
self.group_size = group_size
|
||||||
|
self.num_in_group = num_in_group
|
||||||
|
self.prefetch_step = prefetch_step
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
run_id = os.environ["SGLANG_RUN_ID"]
|
||||||
|
|
||||||
|
# Temporarily init inside Offloader, can move if other modules also need this
|
||||||
|
if self.mode in {"sharded_gpu", "shm_cpu"}:
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
|
|
||||||
|
assert (
|
||||||
|
get_tensor_model_parallel_world_size() == 1
|
||||||
|
), "not yet support tp_size!=1"
|
||||||
|
set_naive_distributed(
|
||||||
|
NaiveDistributed(
|
||||||
|
rank=dp_rank,
|
||||||
|
world_size=dp_size,
|
||||||
|
rendezvous=f"/tmp/{run_id}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if self.mode in {"shm_cpu"}:
|
||||||
|
set_host_shared_memory_manager(
|
||||||
|
HostSharedMemoryManager(
|
||||||
|
base_name=run_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.offloaders = []
|
||||||
|
|
||||||
|
def wrap_modules(
|
||||||
|
self,
|
||||||
|
all_modules_generator: Generator[torch.nn.Module, None, None],
|
||||||
|
submodule_accessor: Optional[_SubmoduleAccessor] = None,
|
||||||
|
whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
|
||||||
|
):
|
||||||
|
assert len(self.offloaders) == 0, "should only call wrap_modules once"
|
||||||
|
|
||||||
|
alt_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
|
all_modules = []
|
||||||
|
offload_submodules = []
|
||||||
|
for module_index, module in enumerate(all_modules_generator):
|
||||||
|
all_modules.append(module)
|
||||||
|
if module_index % self.group_size >= self.group_size - self.num_in_group:
|
||||||
|
submodule = submodule_accessor(module)
|
||||||
|
whitelist_param_names = whitelist_param_names_creator(submodule)
|
||||||
|
logger.info(
|
||||||
|
f"[offloader] offload {module_index=} submodule={type(submodule)} params={whitelist_param_names} memory_allocated={torch.cuda.memory_allocated()}"
|
||||||
|
)
|
||||||
|
offload_submodules.append(submodule)
|
||||||
|
self.offloaders.append(
|
||||||
|
_ModuleOffloader(
|
||||||
|
mode=self.mode,
|
||||||
|
module=submodule,
|
||||||
|
alt_stream=alt_stream,
|
||||||
|
whitelist_param_names=whitelist_param_names,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for index, module in enumerate(offload_submodules):
|
||||||
|
_hook_module_forward_for_offloader(
|
||||||
|
index=index,
|
||||||
|
module=module,
|
||||||
|
offloaders=self.offloaders,
|
||||||
|
prefetch_step=self.prefetch_step,
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_modules
|
||||||
|
|
||||||
|
def post_init(self):
|
||||||
|
for offloader in self.offloaders:
|
||||||
|
offloader.post_init()
|
||||||
|
|
||||||
|
for i in range(self.prefetch_step):
|
||||||
|
self.offloaders[i].start_onload()
|
||||||
|
|
||||||
|
|
||||||
|
def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
|
||||||
|
def _on_forward_end():
|
||||||
|
offloaders[(index + prefetch_step) % len(offloaders)].start_onload()
|
||||||
|
offloaders[index].offload()
|
||||||
|
|
||||||
|
_hook_module_forward_raw(
|
||||||
|
module,
|
||||||
|
on_forward_end=_on_forward_end,
|
||||||
|
get_parameter_and_buffer_dicts=lambda: offloaders[
|
||||||
|
index
|
||||||
|
].wait_and_get_device_tensors(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _hook_module_forward_raw(module, on_forward_end, get_parameter_and_buffer_dicts):
|
||||||
|
original_forward = module.forward
|
||||||
|
|
||||||
|
def forward(*args, **kwargs):
|
||||||
|
module.forward = original_forward
|
||||||
|
output = functional_call(
|
||||||
|
module, get_parameter_and_buffer_dicts(), args=args, kwargs=kwargs
|
||||||
|
)
|
||||||
|
on_forward_end()
|
||||||
|
module.forward = forward
|
||||||
|
return output
|
||||||
|
|
||||||
|
module.forward = forward
|
||||||
|
|
||||||
|
|
||||||
|
class _ModuleOffloader(ABC):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mode: str,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
alt_stream: torch.cuda.Stream,
|
||||||
|
whitelist_param_names: List[str],
|
||||||
|
):
|
||||||
|
self.mode = mode
|
||||||
|
self.module = module
|
||||||
|
self.device = next(module.parameters()).device
|
||||||
|
self.alt_stream = alt_stream
|
||||||
|
|
||||||
|
assert self.device != torch.device(
|
||||||
|
"cpu"
|
||||||
|
), "not handled device=cpu case yet (should skip this tensor)"
|
||||||
|
|
||||||
|
self._device_tensors = None
|
||||||
|
self._load_event = None
|
||||||
|
|
||||||
|
param_dict = dict(self.module.named_parameters())
|
||||||
|
assert all(
|
||||||
|
name in param_dict for name in whitelist_param_names
|
||||||
|
), f"{whitelist_param_names=} {list(param_dict.keys())=}"
|
||||||
|
|
||||||
|
self._param_offloaders = {
|
||||||
|
name: _BaseParamOffloader.create(mode, module=module, param_name=name)
|
||||||
|
for name in whitelist_param_names
|
||||||
|
}
|
||||||
|
|
||||||
|
def post_init(self):
|
||||||
|
for name, param_offloader in self._param_offloaders.items():
|
||||||
|
param_offloader.post_init()
|
||||||
|
|
||||||
|
def start_onload(self):
|
||||||
|
self.alt_stream.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(self.alt_stream):
|
||||||
|
self._device_tensors = self._create_device_tensors()
|
||||||
|
self._load_event = torch.cuda.Event()
|
||||||
|
self._load_event.record()
|
||||||
|
|
||||||
|
def offload(self):
|
||||||
|
self._device_tensors = None
|
||||||
|
self._load_event = None
|
||||||
|
|
||||||
|
def wait_and_get_device_tensors(self):
|
||||||
|
assert self._device_tensors is not None
|
||||||
|
self._load_event.wait()
|
||||||
|
return self._device_tensors
|
||||||
|
|
||||||
|
def _create_device_tensors(self):
|
||||||
|
return {k: v.create_device_tensor() for k, v in self._param_offloaders.items()}
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseParamOffloader(ABC):
|
||||||
|
@staticmethod
|
||||||
|
def create(mode: str, **kwargs) -> "_BaseParamOffloader":
|
||||||
|
return {
|
||||||
|
"cpu": _CpuParamOffloader,
|
||||||
|
"shm_cpu": _ShmCpuParamOffloader,
|
||||||
|
"sharded_gpu": _ShardedGpuParamOffloader,
|
||||||
|
}[mode](**kwargs)
|
||||||
|
|
||||||
|
def __init__(self, module, param_name):
|
||||||
|
self._module = module
|
||||||
|
self._param_name = param_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _param(self):
|
||||||
|
return getattr(self._module, self._param_name)
|
||||||
|
|
||||||
|
def post_init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def create_device_tensor(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class _CpuParamOffloader(_BaseParamOffloader):
|
||||||
|
def __init__(self, module, param_name):
|
||||||
|
super().__init__(module, param_name)
|
||||||
|
_move_param_to_cpu(self._param, pin_memory=True)
|
||||||
|
|
||||||
|
def create_device_tensor(self):
|
||||||
|
return self._param.to("cuda", non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
|
class _ShmCpuParamOffloader(_BaseParamOffloader):
|
||||||
|
def __init__(self, module, param_name):
|
||||||
|
super().__init__(module, param_name)
|
||||||
|
self._rank = get_naive_distributed().get_rank()
|
||||||
|
self._world_size = get_naive_distributed().get_world_size()
|
||||||
|
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
|
|
||||||
|
assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
|
||||||
|
assert (
|
||||||
|
self._param.data.is_contiguous()
|
||||||
|
), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"
|
||||||
|
|
||||||
|
self.shm_cpu_data = get_host_shared_memory_manager().malloc(
|
||||||
|
shape=self._param.shape, dtype=self._param.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._rank == 0:
|
||||||
|
self.shm_cpu_data.copy_(self._param.data.to("cpu"))
|
||||||
|
self._param.data = self.shm_cpu_data
|
||||||
|
else:
|
||||||
|
_move_param_to_meta(self._module, self._param_name)
|
||||||
|
get_naive_distributed().barrier()
|
||||||
|
|
||||||
|
def post_init(self):
|
||||||
|
if self._rank == 0:
|
||||||
|
assert (
|
||||||
|
self.shm_cpu_data.data_ptr() == self._param.data.data_ptr()
|
||||||
|
), f"{self.shm_cpu_data.data_ptr()=} {self._param.data.data_ptr()=} {self.shm_cpu_data=} {self._param.data=}"
|
||||||
|
|
||||||
|
_move_param_to_meta(self._module, self._param_name)
|
||||||
|
|
||||||
|
def create_device_tensor(self):
|
||||||
|
return self.shm_cpu_data.to("cuda", non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _move_param_to_cpu(param, pin_memory: bool):
|
||||||
|
cpu_data = _empty_strided_like(
|
||||||
|
param.data,
|
||||||
|
device="cpu",
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
cpu_data.copy_(param.data)
|
||||||
|
param.data = cpu_data
|
||||||
|
|
||||||
|
|
||||||
|
def _move_param_to_meta(module, param_name):
|
||||||
|
old_param = getattr(module, param_name)
|
||||||
|
old_param_type = type(old_param)
|
||||||
|
|
||||||
|
new_data = old_param.data.to("meta")
|
||||||
|
|
||||||
|
if old_param_type == ModelWeightParameter:
|
||||||
|
# manually checked how `w13_weight` and `w2_weight` are constructed
|
||||||
|
new_param = ModelWeightParameter(
|
||||||
|
data=new_data,
|
||||||
|
**{
|
||||||
|
k: getattr(old_param, k)
|
||||||
|
for k in ["input_dim", "output_dim", "weight_loader"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif old_param_type == torch.nn.Parameter:
|
||||||
|
new_param = torch.nn.Parameter(
|
||||||
|
data=new_data,
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown {old_param_type=} {old_param=}")
|
||||||
|
|
||||||
|
setattr(module, param_name, new_param)
|
||||||
|
|
||||||
|
|
||||||
|
def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
|
||||||
|
return torch.empty_strided(
|
||||||
|
size=x.size(),
|
||||||
|
stride=x.stride(),
|
||||||
|
dtype=x.dtype,
|
||||||
|
layout=x.layout,
|
||||||
|
device=device,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
|||||||
@@ -85,7 +85,6 @@ class ServerArgs:
|
|||||||
max_prefill_tokens: int = 16384
|
max_prefill_tokens: int = 16384
|
||||||
schedule_policy: str = "fcfs"
|
schedule_policy: str = "fcfs"
|
||||||
schedule_conservativeness: float = 1.0
|
schedule_conservativeness: float = 1.0
|
||||||
cpu_offload_gb: int = 0
|
|
||||||
page_size: Optional[int] = None
|
page_size: Optional[int] = None
|
||||||
hybrid_kvcache_ratio: Optional[float] = None
|
hybrid_kvcache_ratio: Optional[float] = None
|
||||||
swa_full_tokens_ratio: float = 0.8
|
swa_full_tokens_ratio: float = 0.8
|
||||||
@@ -226,6 +225,13 @@ class ServerArgs:
|
|||||||
ds_heavy_channel_type: str = "qk"
|
ds_heavy_channel_type: str = "qk"
|
||||||
ds_sparse_decode_threshold: int = 4096
|
ds_sparse_decode_threshold: int = 4096
|
||||||
|
|
||||||
|
# Offloading
|
||||||
|
cpu_offload_gb: int = 0
|
||||||
|
offload_group_size: int = -1
|
||||||
|
offload_num_in_group: int = 1
|
||||||
|
offload_prefetch_step: int = 1
|
||||||
|
offload_mode: str = "cpu"
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
disable_radix_cache: bool = False
|
disable_radix_cache: bool = False
|
||||||
cuda_graph_max_bs: Optional[int] = None
|
cuda_graph_max_bs: Optional[int] = None
|
||||||
@@ -976,12 +982,6 @@ class ServerArgs:
|
|||||||
default=ServerArgs.schedule_conservativeness,
|
default=ServerArgs.schedule_conservativeness,
|
||||||
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
|
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--cpu-offload-gb",
|
|
||||||
type=int,
|
|
||||||
default=ServerArgs.cpu_offload_gb,
|
|
||||||
help="How many GBs of RAM to reserve for CPU offloading.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--page-size",
|
"--page-size",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -1683,6 +1683,38 @@ class ServerArgs:
|
|||||||
help="The type of heavy channels in double sparsity attention",
|
help="The type of heavy channels in double sparsity attention",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Offloading
|
||||||
|
parser.add_argument(
|
||||||
|
"--cpu-offload-gb",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.cpu_offload_gb,
|
||||||
|
help="How many GBs of RAM to reserve for CPU offloading.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload-group-size",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.offload_group_size,
|
||||||
|
help="Number of layers per group in offloading.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload-num-in-group",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.offload_num_in_group,
|
||||||
|
help="Number of layers to be offloaded within a group.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload-prefetch-step",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.offload_prefetch_step,
|
||||||
|
help="Steps to prefetch in offloading.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload-mode",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.offload_mode,
|
||||||
|
help="Mode of offloading.",
|
||||||
|
)
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-radix-cache",
|
"--disable-radix-cache",
|
||||||
|
|||||||
@@ -2954,3 +2954,13 @@ class ConcurrentCounter:
|
|||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def is_triton_kernels_available() -> bool:
|
def is_triton_kernels_available() -> bool:
|
||||||
return importlib.util.find_spec("triton_kernels") is not None
|
return importlib.util.find_spec("triton_kernels") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def check_cuda_result(raw_output):
|
||||||
|
import cuda.bindings.runtime as cuda_rt
|
||||||
|
|
||||||
|
err, *results = raw_output
|
||||||
|
if err != cuda_rt.cudaError_t.cudaSuccess:
|
||||||
|
raise Exception(f"CUDA error: {err}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|||||||
Reference in New Issue
Block a user