[PD] Add custom memory pool option to support Mooncake PD with NVLink (#7264)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
shangmingc authored 2025-06-18 08:21:37 +08:00, committed by GitHub
parent ceaa85c9e6, commit c26d7349d3
4 changed files with 163 additions and 47 deletions
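For context, the new behavior is opt-in: both pools read the `SGLANG_MOONCAKE_CUSTOM_MEM_POOL` environment variable (default `"false"`) via `get_bool_env_var`. A minimal way to enable it from Python, sketched under the assumption that the variable only needs to be visible before the KV pools are constructed:

```python
import os

# Must be set before MHATokenToKVPool / MLATokenToKVPool are instantiated,
# e.g. before launching the server process; exporting it in the shell
# (SGLANG_MOONCAKE_CUSTOM_MEM_POOL=true) works equally well.
os.environ["SGLANG_MOONCAKE_CUSTOM_MEM_POOL"] = "true"
```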


@@ -26,6 +26,8 @@ KVCache actually holds the physical kv cache.
import abc
import logging
import os
from contextlib import nullcontext
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -34,7 +36,7 @@ import triton
import triton.language as tl
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
logger = logging.getLogger(__name__)
@@ -260,6 +262,22 @@ class MHATokenToKVPool(KVCache):
        self.head_num = head_num
        self.head_dim = head_dim

        # for disagg with nvlink
        self.enable_custom_mem_pool = get_bool_env_var(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
        )
        if self.enable_custom_mem_pool:
            from sglang.srt.disaggregation.mooncake.memory_pool import (
                MooncakeNVLinkAllocator,
            )

            # TODO(shangming): abstract custom allocator class for more backends
            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
        else:
            self.custom_mem_pool = None

        self._create_buffers()

        # used for chunked cpu-offloading
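The hunk above imports `MooncakeNVLinkAllocator` from `sglang.srt.disaggregation.mooncake.memory_pool`, which is not shown in this diff. A minimal sketch of what such an allocator factory could look like, assuming it caches one `torch.cuda.memory.CUDAPluggableAllocator` per device built from a Mooncake-provided shared library; the `.so` path and the exported symbol names below are placeholders, not the actual Mooncake API:

```python
from typing import Dict

import torch
from torch.cuda.memory import CUDAPluggableAllocator


class MooncakeNVLinkAllocatorSketch:
    # Hypothetical stand-in for MooncakeNVLinkAllocator (illustration only).
    _allocators: Dict[str, CUDAPluggableAllocator] = {}

    @classmethod
    def get_allocator(cls, device) -> CUDAPluggableAllocator:
        key = str(device)
        if key not in cls._allocators:
            # Placeholder library path and malloc/free symbol names.
            cls._allocators[key] = CUDAPluggableAllocator(
                "/path/to/mooncake_nvlink_hook.so", "mc_malloc", "mc_free"
            )
        return cls._allocators[key]


# Matches how the diff consumes the factory:
#   pool = torch.cuda.MemPool(
#       MooncakeNVLinkAllocatorSketch.get_allocator("cuda:0").allocator()
#   )
```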
@@ -275,24 +293,29 @@ class MHATokenToKVPool(KVCache):
    def _create_buffers(self):
        with self.memory_saver_adapter.region():
            # [size, head_num, head_dim] for each layer
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.k_buffer = [
                torch.zeros(
                    (self.size + self.page_size, self.head_num, self.head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (self.size + self.page_size, self.head_num, self.head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
            with (
                torch.cuda.use_mem_pool(self.custom_mem_pool)
                if self.enable_custom_mem_pool
                else nullcontext()
            ):
                # [size, head_num, head_dim] for each layer
                # The padded slot 0 is used for writing dummy outputs from padded tokens.
                self.k_buffer = [
                    torch.zeros(
                        (self.size + self.page_size, self.head_num, self.head_dim),
                        dtype=self.store_dtype,
                        device=self.device,
                    )
                    for _ in range(self.layer_num)
                ]
                self.v_buffer = [
                    torch.zeros(
                        (self.size + self.page_size, self.head_num, self.head_dim),
                        dtype=self.store_dtype,
                        device=self.device,
                    )
                    for _ in range(self.layer_num)
                ]

        self.data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
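The `with (torch.cuda.use_mem_pool(...) if ... else nullcontext())` construct above routes the buffer allocations into the custom pool only when the feature is enabled, and otherwise leaves the default caching allocator in charge. A standalone sketch of the same gating pattern, assuming a PyTorch build that exposes `torch.cuda.MemPool` and `torch.cuda.use_mem_pool` (which the patch itself requires):

```python
from contextlib import nullcontext

import torch

# A pool backed by the default allocator; the diff instead passes a
# Mooncake-backed pluggable allocator to MemPool().
pool = torch.cuda.MemPool()


def alloc_buffers(use_custom_pool: bool):
    # Every CUDA allocation inside the block is served from `pool` when the
    # flag is on; nullcontext() makes the gating a no-op otherwise.
    with torch.cuda.use_mem_pool(pool) if use_custom_pool else nullcontext():
        return [torch.zeros(1024, device="cuda") for _ in range(4)]


buffers = alloc_buffers(use_custom_pool=True)
```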
@@ -349,6 +372,9 @@ class MHATokenToKVPool(KVCache):
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def maybe_get_custom_mem_pool(self):
        return self.custom_mem_pool

    def get_cpu_copy(self, indices):
        torch.cuda.synchronize()
        kv_cache_cpu = []
@@ -569,16 +595,36 @@ class MLATokenToKVPool(KVCache):
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

        # for disagg with nvlink
        self.enable_custom_mem_pool = get_bool_env_var(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
        )
        if self.enable_custom_mem_pool:
            from sglang.srt.disaggregation.mooncake.memory_pool import (
                MooncakeNVLinkAllocator,
            )

            # TODO(shangming): abstract custom allocator class for more backends
            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
        else:
            self.custom_mem_pool = None

        with self.memory_saver_adapter.region():
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.kv_buffer = [
                torch.zeros(
                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                    dtype=self.store_dtype,
                    device=device,
                )
                for _ in range(layer_num)
            ]
            with (
                torch.cuda.use_mem_pool(self.custom_mem_pool)
                if self.custom_mem_pool
                else nullcontext()
            ):
                # The padded slot 0 is used for writing dummy outputs from padded tokens.
                self.kv_buffer = [
                    torch.zeros(
                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                        dtype=self.store_dtype,
                        device=device,
                    )
                    for _ in range(layer_num)
                ]

        self.layer_transfer_counter = None
@@ -604,6 +650,9 @@ class MLATokenToKVPool(KVCache):
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def maybe_get_custom_mem_pool(self):
        return self.custom_mem_pool

    def get_key_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
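Both pool classes now expose the pool through `maybe_get_custom_mem_pool()`, which returns `None` when the feature is disabled. A hypothetical caller-side check; `kv_pool` and `register_buffers_with_mooncake` below are illustrative names, not part of this diff:

```python
# kv_pool: an MHATokenToKVPool or MLATokenToKVPool instance.
custom_pool = kv_pool.maybe_get_custom_mem_pool()
if custom_pool is not None:
    # KV buffers were allocated through the NVLink-capable allocator, so the
    # transfer side can treat them as directly addressable over NVLink.
    register_buffers_with_mooncake(custom_pool)  # hypothetical helper
```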