[PD] Add custom memory pool option to support Mooncake PD with NVLink (#7264)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -26,6 +26,8 @@ KVCache actually holds the physical kv cache.
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -34,7 +36,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
|
||||
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -260,6 +262,22 @@ class MHATokenToKVPool(KVCache):
|
||||
|
||||
self.head_num = head_num
|
||||
self.head_dim = head_dim
|
||||
|
||||
# Optional custom memory pool for disaggregation over NVLink (enabled via SGLANG_MOONCAKE_CUSTOM_MEM_POOL)
|
||||
self.enable_custom_mem_pool = get_bool_env_var(
|
||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||
)
|
||||
if self.enable_custom_mem_pool:
|
||||
from sglang.srt.disaggregation.mooncake.memory_pool import (
|
||||
MooncakeNVLinkAllocator,
|
||||
)
|
||||
|
||||
# TODO(shangming): abstract custom allocator class for more backends
|
||||
allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
|
||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||
else:
|
||||
self.custom_mem_pool = None
|
||||
|
||||
self._create_buffers()
|
||||
|
||||
# used for chunked cpu-offloading
|
||||
@@ -275,24 +293,29 @@ class MHATokenToKVPool(KVCache):
|
||||
|
||||
def _create_buffers(self):
|
||||
with self.memory_saver_adapter.region():
|
||||
# [size, head_num, head_dim] for each layer
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.enable_custom_mem_pool
|
||||
else nullcontext()
|
||||
):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
|
||||
self.data_ptrs = torch.tensor(
|
||||
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
|
||||
@@ -349,6 +372,9 @@ class MHATokenToKVPool(KVCache):
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def maybe_get_custom_mem_pool(self):
    """Return the custom memory pool stored on this KV pool.

    Returns:
        The value of ``self.custom_mem_pool`` unchanged. The constructor sets
        it to a ``torch.cuda.MemPool`` when the custom pool is enabled and to
        ``None`` otherwise, so callers must handle ``None``.
    """
    pool = self.custom_mem_pool
    return pool
|
||||
|
||||
def get_cpu_copy(self, indices):
|
||||
torch.cuda.synchronize()
|
||||
kv_cache_cpu = []
|
||||
@@ -569,16 +595,36 @@ class MLATokenToKVPool(KVCache):
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
|
||||
# Optional custom memory pool for disaggregation over NVLink (enabled via SGLANG_MOONCAKE_CUSTOM_MEM_POOL)
|
||||
self.enable_custom_mem_pool = get_bool_env_var(
|
||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||
)
|
||||
if self.enable_custom_mem_pool:
|
||||
from sglang.srt.disaggregation.mooncake.memory_pool import (
|
||||
MooncakeNVLinkAllocator,
|
||||
)
|
||||
|
||||
# TODO(shangming): abstract custom allocator class for more backends
|
||||
allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
|
||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||
else:
|
||||
self.custom_mem_pool = None
|
||||
|
||||
with self.memory_saver_adapter.region():
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.kv_buffer = [
|
||||
torch.zeros(
|
||||
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.custom_mem_pool
|
||||
else nullcontext()
|
||||
):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.kv_buffer = [
|
||||
torch.zeros(
|
||||
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
self.layer_transfer_counter = None
|
||||
|
||||
@@ -604,6 +650,9 @@ class MLATokenToKVPool(KVCache):
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def maybe_get_custom_mem_pool(self):
    """Return the custom memory pool stored on this KV pool.

    Returns:
        The value of ``self.custom_mem_pool`` unchanged. The constructor sets
        it to a ``torch.cuda.MemPool`` when the custom pool is enabled and to
        ``None`` otherwise, so callers must handle ``None``.
    """
    pool = self.custom_mem_pool
    return pool
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
if self.layer_transfer_counter is not None:
|
||||
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||
|
||||
Reference in New Issue
Block a user