[PD] Add custom memory pool option to support Mooncake PD with NVLink (#7264)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
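The feature is opt-in: MHATokenToKVPool and MLATokenToKVPool check the SGLANG_MOONCAKE_CUSTOM_MEM_POOL environment variable and, when it is set, allocate their KV buffers (and the Scheduler's PD metadata buffers) from a torch.cuda.MemPool backed by Mooncake's NVLink allocator, so the transfer engine can move data between prefill and decode instances over NVLink. A minimal, hedged sketch of opting in; only the variable name and its "false" default are taken from the diff, the rest is illustrative:

    # Must be set before the Scheduler constructs its KV pools, since the flag
    # is read via get_bool_env_var("SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false").
    import os

    os.environ["SGLANG_MOONCAKE_CUSTOM_MEM_POOL"] = "true"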
python/sglang/srt/disaggregation/mooncake/memory_pool.py (new file, +47)
@@ -0,0 +1,47 @@
+import os
+import threading
+from importlib import resources
+from typing import Dict, Final, Optional
+
+import torch
+from torch.cuda.memory import CUDAPluggableAllocator
+
+
+# TODO(shangming): move this class into mooncake's package for more general use cases
+class MooncakeNVLinkAllocator:
+    _instances: Dict[torch.device, CUDAPluggableAllocator] = {}
+    _lock: Final = threading.Lock()
+
+    @classmethod
+    def _get_so_path(cls) -> str:
+        """Dynamically locate hook.so in the mooncake package installation"""
+        try:
+            # Attempt to locate package resource
+            with resources.path("mooncake", "hook.so") as so_path:
+                if so_path.exists():
+                    return str(so_path)
+        except (ImportError, FileNotFoundError, TypeError):
+            pass
+
+        # Fallback strategy: check in package location via import metadata
+        try:
+            import mooncake
+
+            base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
+            so_path = os.path.join(base_path, "hook.so")
+            if os.path.exists(so_path):
+                return so_path
+        except (ImportError, FileNotFoundError, TypeError):
+            raise ImportError(
+                "SGLANG_MOONCAKE_CUSTOM_MEM_POOL require mooncake-transfer-engine >= 0.3.3.post2."
+            )
+
+    @classmethod
+    def get_allocator(cls, device: torch.device) -> CUDAPluggableAllocator:
+        with cls._lock:
+            if device not in cls._instances:
+                so_path = cls._get_so_path()
+                cls._instances[device] = CUDAPluggableAllocator(
+                    so_path, "mc_nvlink_malloc", "mc_nvlink_free"
+                )
+            return cls._instances[device]
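Usage note (not part of the diff): a sketch of how the allocator is consumed, mirroring the KV pool changes below. It assumes mooncake-transfer-engine >= 0.3.3.post2 (which ships hook.so) and a PyTorch build that provides torch.cuda.MemPool and torch.cuda.use_mem_pool; the tensor shape is illustrative.

    import torch

    from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator

    device = torch.device("cuda", torch.cuda.current_device())

    # get_allocator caches one CUDAPluggableAllocator per device behind a lock,
    # so repeated calls return the same instance.
    allocator = MooncakeNVLinkAllocator.get_allocator(device)
    pool = torch.cuda.MemPool(allocator.allocator())

    with torch.cuda.use_mem_pool(pool):
        # Tensors created here go through mc_nvlink_malloc / mc_nvlink_free,
        # which is what makes them transferable by Mooncake over NVLink.
        buf = torch.zeros((1024, 16), dtype=torch.int32, device=device)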
@@ -6,6 +6,7 @@ import random
 import threading
 import warnings
 from collections import deque
+from contextlib import nullcontext
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional
 
@@ -84,24 +85,37 @@ class ReqToMetadataIdxAllocator:
 
 
 class MetadataBuffers:
-    def __init__(self, size: int, max_top_logprobs_num: int = 128):
-        # TODO: abort top_logprobs_num > 128 in PD
-
-        # We transfer the metadata of first output token to decode
-        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
-        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
-        self.output_token_logprobs_val = torch.zeros(
-            (size, 16), dtype=torch.float32, device="cpu"
-        )
-        self.output_token_logprobs_idx = torch.zeros(
-            (size, 16), dtype=torch.int32, device="cpu"
-        )
-        self.output_top_logprobs_val = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
-        )
-        self.output_top_logprobs_idx = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
-        )
+    def __init__(
+        self,
+        size: int,
+        max_top_logprobs_num: int = 128,
+        custom_mem_pool: torch.cuda.MemPool = None,
+    ):
+        self.custom_mem_pool = custom_mem_pool
+        device = "cuda" if self.custom_mem_pool else "cpu"
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            # TODO: abort top_logprobs_num > 128 in PD
+
+            # We transfer the metadata of first output token to decode
+            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+            self.output_token_logprobs_val = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_token_logprobs_idx = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
+            self.output_top_logprobs_val = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.float32, device=device
+            )
+            self.output_top_logprobs_idx = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.int32, device=device
+            )
 
     def get_buf_infos(self):
         ptrs = [
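Illustrative note (not from the diff): the new constructor argument decides where the metadata buffers live. Without a pool they keep the old CPU placement; with a pool they are allocated on CUDA inside use_mem_pool so they can ride the same NVLink path as the KV cache. A small sketch, assuming MetadataBuffers is importable from sglang.srt.disaggregation.utils (module path assumed); the pool would normally come from maybe_get_custom_mem_pool(), as in the Scheduler hunks below.

    from sglang.srt.disaggregation.utils import MetadataBuffers  # import path assumed

    # Default path: no custom pool, buffers stay on CPU exactly as before.
    bufs = MetadataBuffers(size=64)
    assert bufs.output_ids.device.type == "cpu"

    # NVLink path (sketch): with a pool obtained from
    # token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
    # MetadataBuffers(size=64, custom_mem_pool=pool) places output_ids and the
    # logprob buffers on "cuda" inside torch.cuda.use_mem_pool(pool).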
@@ -622,7 +622,10 @@ class Scheduler(
             self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )
 
             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
@@ -669,7 +672,10 @@ class Scheduler(
             self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )
 
             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
@@ -26,6 +26,8 @@ KVCache actually holds the physical kv cache.
 
 import abc
 import logging
+import os
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -34,7 +36,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
+from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
@@ -260,6 +262,22 @@ class MHATokenToKVPool(KVCache):
 
         self.head_num = head_num
         self.head_dim = head_dim
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            from sglang.srt.disaggregation.mooncake.memory_pool import (
+                MooncakeNVLinkAllocator,
+            )
+
+            # TODO(shangming): abstract custom allocator class for more backends
+            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self._create_buffers()
 
         # used for chunked cpu-offloading
@@ -275,24 +293,29 @@ class MHATokenToKVPool(KVCache):
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region():
-            # [size, head_num, head_dim] for each layer
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.k_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.k_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
 
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
@@ -349,6 +372,9 @@ class MHATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -569,16 +595,36 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
 
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            from sglang.srt.disaggregation.mooncake.memory_pool import (
+                MooncakeNVLinkAllocator,
+            )
+
+            # TODO(shangming): abstract custom allocator class for more backends
+            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         with self.memory_saver_adapter.region():
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = [
-                torch.zeros(
-                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
-                    dtype=self.store_dtype,
-                    device=device,
-                )
-                for _ in range(layer_num)
-            ]
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.kv_buffer = [
+                    torch.zeros(
+                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        dtype=self.store_dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
         self.layer_transfer_counter = None
 
@@ -604,6 +650,9 @@ class MLATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)