[PD] Add PD support for hybrid model (Qwen3-Next, DeepSeek V3.2 Exp) (#10912)
Signed-off-by: Shangming Cai <csmthu@gmail.com> Co-authored-by: hzh0425 <hzh0425@apache.org> Co-authored-by: ZeldaHuang <hzm414167@alibaba-inc.com>
This commit is contained in:
@@ -20,6 +20,10 @@ class KVArgs:
|
||||
aux_data_ptrs: List[int]
|
||||
aux_data_lens: List[int]
|
||||
aux_item_lens: List[int]
|
||||
state_data_ptrs: List[int]
|
||||
state_data_lens: List[int]
|
||||
state_item_lens: List[int]
|
||||
state_type: str # "none", "mamba", "swa"
|
||||
ib_device: str
|
||||
ib_traffic_class: str
|
||||
gpu_id: int
|
||||
@@ -76,9 +80,13 @@ class BaseKVSender(ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def send(self, kv_indices: npt.NDArray[np.int32]):
|
||||
def send(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
state_indices: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
Send the kv cache at the given kv indices to the decoder server
|
||||
Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -108,9 +116,14 @@ class BaseKVReceiver(ABC):
|
||||
): ...
|
||||
|
||||
@abstractmethod
|
||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
def init(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
aux_index: Optional[int] = None,
|
||||
state_indices: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
Notify the prefill server about the kv indices and aux index
|
||||
Notify the prefill server about the kv indices, aux index, and state_indices.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@@ -201,6 +201,7 @@ class CommonKVSender(BaseKVSender):
|
||||
def send(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
state_indices: Optional[List[int]] = None,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@@ -25,11 +25,12 @@ import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
@@ -47,9 +48,19 @@ from sglang.srt.disaggregation.utils import (
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.allocator import (
|
||||
BaseTokenToKVPoolAllocator,
|
||||
SWATokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
HybridLinearKVPool,
|
||||
HybridReqToTokenPool,
|
||||
KVCache,
|
||||
NSATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
SWAKVPool,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import get_int_env_var, require_mlp_sync
|
||||
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
@@ -124,6 +135,35 @@ class DecodeReqToTokenPool:
|
||||
self.free_slots = list(range(self.size + self.pre_alloc_size))
|
||||
|
||||
|
||||
class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
max_context_len: int,
|
||||
device: str,
|
||||
enable_memory_saver: bool,
|
||||
cache_params: "Mamba2CacheParams",
|
||||
speculative_num_draft_tokens: int,
|
||||
pre_alloc_size: int,
|
||||
):
|
||||
DecodeReqToTokenPool.__init__(
|
||||
self,
|
||||
size=size,
|
||||
max_context_len=max_context_len,
|
||||
device=device,
|
||||
enable_memory_saver=enable_memory_saver,
|
||||
pre_alloc_size=pre_alloc_size,
|
||||
)
|
||||
self._init_mamba_pool(
|
||||
size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
|
||||
)
|
||||
|
||||
def clear(self):
|
||||
self.free_slots = list(range(self.size + self.pre_alloc_size))
|
||||
self.mamba_pool.clear()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodeRequest:
|
||||
req: Req
|
||||
@@ -217,6 +257,28 @@ class DecodePreallocQueue:
|
||||
self.metadata_buffers.get_buf_infos()
|
||||
)
|
||||
|
||||
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
|
||||
state_data_ptrs, state_data_lens, state_item_lens = (
|
||||
self.token_to_kv_pool.get_state_buf_infos()
|
||||
)
|
||||
kv_args.state_data_ptrs = state_data_ptrs
|
||||
kv_args.state_data_lens = state_data_lens
|
||||
kv_args.state_item_lens = state_item_lens
|
||||
|
||||
if isinstance(self.token_to_kv_pool, SWAKVPool):
|
||||
kv_args.state_type = "swa"
|
||||
elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
|
||||
kv_args.state_type = "mamba"
|
||||
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
|
||||
kv_args.state_type = "nsa"
|
||||
else:
|
||||
kv_args.state_type = "none"
|
||||
else:
|
||||
kv_args.state_data_ptrs = []
|
||||
kv_args.state_data_lens = []
|
||||
kv_args.state_item_lens = []
|
||||
kv_args.state_type = "none"
|
||||
|
||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||
kv_args.gpu_id = self.scheduler.gpu_id
|
||||
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
||||
@@ -414,16 +476,56 @@ class DecodePreallocQueue:
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
page_size = self.token_to_kv_pool_allocator.page_size
|
||||
|
||||
# Prepare extra pool indices for hybrid models
|
||||
if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
|
||||
# Mamba hybrid model: single mamba state index
|
||||
state_indices = [
|
||||
self.req_to_token_pool.req_index_to_mamba_index_mapping[
|
||||
decode_req.req.req_pool_idx
|
||||
]
|
||||
.cpu()
|
||||
.numpy()
|
||||
]
|
||||
elif isinstance(self.token_to_kv_pool, SWAKVPool):
|
||||
# SWA hybrid model: send decode-side SWA window indices
|
||||
seq_len = len(decode_req.req.origin_input_ids)
|
||||
window_size = self.scheduler.sliding_window_size
|
||||
|
||||
window_start = max(0, seq_len - window_size)
|
||||
window_start = (window_start // page_size) * page_size
|
||||
window_kv_indices_full = self.req_to_token_pool.req_to_token[
|
||||
decode_req.req.req_pool_idx, window_start:seq_len
|
||||
]
|
||||
|
||||
# Translate to SWA pool indices
|
||||
window_kv_indices_swa = (
|
||||
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||
window_kv_indices_full
|
||||
)
|
||||
)
|
||||
state_indices = window_kv_indices_swa.cpu().numpy()
|
||||
state_indices = kv_to_page_indices(state_indices, page_size)
|
||||
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
|
||||
seq_len = len(decode_req.req.origin_input_ids)
|
||||
kv_indices_full = self.req_to_token_pool.req_to_token[
|
||||
decode_req.req.req_pool_idx, :seq_len
|
||||
]
|
||||
state_indices = kv_indices_full.cpu().numpy()
|
||||
state_indices = kv_to_page_indices(state_indices, page_size)
|
||||
else:
|
||||
state_indices = None
|
||||
|
||||
decode_req.metadata_buffer_index = (
|
||||
self.req_to_metadata_buffer_idx_allocator.alloc()
|
||||
)
|
||||
assert decode_req.metadata_buffer_index is not None
|
||||
page_indices = kv_to_page_indices(
|
||||
kv_indices, self.token_to_kv_pool_allocator.page_size
|
||||
page_indices = kv_to_page_indices(kv_indices, page_size)
|
||||
decode_req.kv_receiver.init(
|
||||
page_indices, decode_req.metadata_buffer_index, state_indices
|
||||
)
|
||||
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
||||
|
||||
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
|
||||
preallocated_reqs.append(decode_req)
|
||||
indices_to_remove.add(i)
|
||||
decode_req.req.time_stats.decode_transfer_queue_entry_time = (
|
||||
@@ -503,7 +605,10 @@ class DecodePreallocQueue:
|
||||
|
||||
def _pre_alloc(self, req: Req) -> torch.Tensor:
|
||||
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
|
||||
req_pool_indices = self.req_to_token_pool.alloc(1)
|
||||
if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
|
||||
req_pool_indices = self.req_to_token_pool.alloc(1, [req])
|
||||
else:
|
||||
req_pool_indices = self.req_to_token_pool.alloc(1)
|
||||
|
||||
assert (
|
||||
req_pool_indices is not None
|
||||
|
||||
@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender):
|
||||
def send(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
state_indices: Optional[List[int]] = None,
|
||||
):
|
||||
self.has_sent = True
|
||||
logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
|
||||
logger.debug(
|
||||
f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
|
||||
)
|
||||
|
||||
def failure_exception(self):
|
||||
raise Exception("Fake KVSender Exception")
|
||||
@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver):
|
||||
logger.debug("FakeKVReceiver poll success")
|
||||
return KVPoll.Success
|
||||
|
||||
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
|
||||
def init(
|
||||
self,
|
||||
kv_indices: list[int],
|
||||
aux_index: Optional[int] = None,
|
||||
state_indices: Optional[List[int]] = None,
|
||||
):
|
||||
self.has_init = True
|
||||
logger.debug(
|
||||
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
||||
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
|
||||
)
|
||||
|
||||
def failure_exception(self):
|
||||
|
||||
@@ -58,6 +58,7 @@ class TransferKVChunk:
|
||||
index_slice: slice
|
||||
is_last: bool
|
||||
prefill_aux_index: Optional[int]
|
||||
state_indices: Optional[List[int]]
|
||||
|
||||
|
||||
# decode
|
||||
@@ -69,6 +70,7 @@ class TransferInfo:
|
||||
mooncake_session_id: str
|
||||
dst_kv_indices: npt.NDArray[np.int32]
|
||||
dst_aux_index: int
|
||||
dst_state_indices: List[int]
|
||||
required_dst_info_num: int
|
||||
is_dummy: bool
|
||||
|
||||
@@ -78,9 +80,14 @@ class TransferInfo:
|
||||
is_dummy = True
|
||||
dst_kv_indices = np.array([], dtype=np.int32)
|
||||
dst_aux_index = None
|
||||
dst_state_indices = []
|
||||
else:
|
||||
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
|
||||
dst_aux_index = int(msg[5].decode("ascii"))
|
||||
if msg[6] == b"":
|
||||
dst_state_indices = []
|
||||
else:
|
||||
dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
|
||||
is_dummy = False
|
||||
return cls(
|
||||
room=int(msg[0].decode("ascii")),
|
||||
@@ -89,7 +96,8 @@ class TransferInfo:
|
||||
mooncake_session_id=msg[3].decode("ascii"),
|
||||
dst_kv_indices=dst_kv_indices,
|
||||
dst_aux_index=dst_aux_index,
|
||||
required_dst_info_num=int(msg[6].decode("ascii")),
|
||||
dst_state_indices=dst_state_indices,
|
||||
required_dst_info_num=int(msg[7].decode("ascii")),
|
||||
is_dummy=is_dummy,
|
||||
)
|
||||
|
||||
@@ -103,6 +111,7 @@ class KVArgsRegisterInfo:
|
||||
mooncake_session_id: str
|
||||
dst_kv_ptrs: list[int]
|
||||
dst_aux_ptrs: list[int]
|
||||
dst_state_data_ptrs: list[int]
|
||||
dst_tp_rank: int
|
||||
dst_attn_tp_size: int
|
||||
dst_kv_item_len: int
|
||||
@@ -116,9 +125,10 @@ class KVArgsRegisterInfo:
|
||||
mooncake_session_id=msg[3].decode("ascii"),
|
||||
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
|
||||
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
||||
dst_tp_rank=int(msg[6].decode("ascii")),
|
||||
dst_attn_tp_size=int(msg[7].decode("ascii")),
|
||||
dst_kv_item_len=int(msg[8].decode("ascii")),
|
||||
dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
|
||||
dst_tp_rank=int(msg[7].decode("ascii")),
|
||||
dst_attn_tp_size=int(msg[8].decode("ascii")),
|
||||
dst_kv_item_len=int(msg[9].decode("ascii")),
|
||||
)
|
||||
|
||||
|
||||
@@ -180,6 +190,9 @@ class MooncakeKVManager(CommonKVManager):
|
||||
)
|
||||
for _ in range(transfer_queue_size)
|
||||
]
|
||||
self.state_executors = concurrent.futures.ThreadPoolExecutor(
|
||||
transfer_thread_pool_size // transfer_queue_size
|
||||
)
|
||||
for queue, executor in zip(self.transfer_queues, self.executors):
|
||||
threading.Thread(
|
||||
target=self.transfer_worker, args=(queue, executor), daemon=True
|
||||
@@ -239,6 +252,12 @@ class MooncakeKVManager(CommonKVManager):
|
||||
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
||||
)
|
||||
|
||||
# Batch register state/extra pool data buffers
|
||||
if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
|
||||
self.engine.batch_register(
|
||||
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
|
||||
)
|
||||
|
||||
def _transfer_data(self, mooncake_session_id, transfer_blocks):
|
||||
if not transfer_blocks:
|
||||
return 0
|
||||
@@ -248,17 +267,23 @@ class MooncakeKVManager(CommonKVManager):
|
||||
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
||||
)
|
||||
|
||||
def send_kvcache(
|
||||
def _send_kvcache_generic(
|
||||
self,
|
||||
mooncake_session_id: str,
|
||||
prefill_kv_indices: npt.NDArray[np.int32],
|
||||
dst_kv_ptrs: list[int],
|
||||
dst_kv_indices: npt.NDArray[np.int32],
|
||||
src_data_ptrs: list[int],
|
||||
dst_data_ptrs: list[int],
|
||||
item_lens: list[int],
|
||||
prefill_data_indices: npt.NDArray[np.int32],
|
||||
dst_data_indices: npt.NDArray[np.int32],
|
||||
executor: concurrent.futures.ThreadPoolExecutor,
|
||||
):
|
||||
# Group by indices
|
||||
) -> int:
|
||||
"""
|
||||
Generic KV cache transfer supporting both MHA and MLA architectures.
|
||||
This method is used by both send_kvcache (full pool) and maybe_send_extra.
|
||||
"""
|
||||
# Group by indices for optimization
|
||||
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
|
||||
prefill_kv_indices, dst_kv_indices
|
||||
prefill_data_indices, dst_data_indices
|
||||
)
|
||||
|
||||
layers_params = None
|
||||
@@ -266,9 +291,9 @@ class MooncakeKVManager(CommonKVManager):
|
||||
# pp is not supported on the decode side yet
|
||||
if self.is_mla_backend:
|
||||
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
||||
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||
self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
|
||||
)
|
||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||
kv_item_len = item_lens[0]
|
||||
layers_params = [
|
||||
(
|
||||
src_kv_ptrs[layer_id],
|
||||
@@ -279,9 +304,9 @@ class MooncakeKVManager(CommonKVManager):
|
||||
]
|
||||
else:
|
||||
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
||||
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||
self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
|
||||
)
|
||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||
kv_item_len = item_lens[0]
|
||||
layers_params = [
|
||||
(
|
||||
src_k_ptrs[layer_id],
|
||||
@@ -345,6 +370,24 @@ class MooncakeKVManager(CommonKVManager):
|
||||
|
||||
return 0
|
||||
|
||||
def send_kvcache(
|
||||
self,
|
||||
mooncake_session_id: str,
|
||||
prefill_kv_indices: npt.NDArray[np.int32],
|
||||
dst_kv_ptrs: list[int],
|
||||
dst_kv_indices: npt.NDArray[np.int32],
|
||||
executor: concurrent.futures.ThreadPoolExecutor,
|
||||
):
|
||||
return self._send_kvcache_generic(
|
||||
mooncake_session_id=mooncake_session_id,
|
||||
src_data_ptrs=self.kv_args.kv_data_ptrs,
|
||||
dst_data_ptrs=dst_kv_ptrs,
|
||||
item_lens=self.kv_args.kv_item_lens,
|
||||
prefill_data_indices=prefill_kv_indices,
|
||||
dst_data_indices=dst_kv_indices,
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
def send_kvcache_slice(
|
||||
self,
|
||||
mooncake_session_id: str,
|
||||
@@ -593,6 +636,58 @@ class MooncakeKVManager(CommonKVManager):
|
||||
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
|
||||
)
|
||||
|
||||
def maybe_send_extra(
|
||||
self,
|
||||
req: TransferInfo,
|
||||
prefill_state_indices: list[int],
|
||||
dst_state_data_ptrs: list[int],
|
||||
):
|
||||
"""Send state or extra pool data with type-specific handling."""
|
||||
state_type = getattr(self.kv_args, "state_type", "none")
|
||||
|
||||
if state_type == "mamba":
|
||||
return self._send_mamba_state(
|
||||
req,
|
||||
prefill_state_indices,
|
||||
dst_state_data_ptrs,
|
||||
)
|
||||
elif state_type in ["swa", "nsa"]:
|
||||
# Reuse _send_kvcache_generic interface to send extra pool data
|
||||
prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
|
||||
dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
|
||||
return self._send_kvcache_generic(
|
||||
mooncake_session_id=req.mooncake_session_id,
|
||||
src_data_ptrs=self.kv_args.state_data_ptrs,
|
||||
dst_data_ptrs=dst_state_data_ptrs,
|
||||
item_lens=self.kv_args.state_item_lens,
|
||||
prefill_data_indices=prefill_state_indices,
|
||||
dst_data_indices=dst_state_indices,
|
||||
executor=self.state_executors,
|
||||
)
|
||||
else:
|
||||
return 0
|
||||
|
||||
def _send_mamba_state(
|
||||
self,
|
||||
req: TransferInfo,
|
||||
prefill_mamba_index: list[int],
|
||||
dst_state_data_ptrs: list[int],
|
||||
):
|
||||
"""Transfer Mamba states."""
|
||||
assert len(prefill_mamba_index) == 1, "Mamba should have single state index"
|
||||
|
||||
transfer_blocks = []
|
||||
prefill_state_data_ptrs = self.kv_args.state_data_ptrs
|
||||
prefill_state_item_lens = self.kv_args.state_item_lens
|
||||
|
||||
for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
|
||||
length = prefill_state_item_lens[i]
|
||||
src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])
|
||||
dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])
|
||||
transfer_blocks.append((src_addr, dst_addr, length))
|
||||
|
||||
return self._transfer_data(req.mooncake_session_id, transfer_blocks)
|
||||
|
||||
def sync_status_to_decode_endpoint(
|
||||
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
||||
):
|
||||
@@ -702,6 +797,21 @@ class MooncakeKVManager(CommonKVManager):
|
||||
break
|
||||
|
||||
if kv_chunk.is_last:
|
||||
if kv_chunk.state_indices is not None:
|
||||
if not self.is_mla_backend and (
|
||||
self.attn_tp_size
|
||||
!= target_rank_registration_info.dst_attn_tp_size
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"PD Disaggregation does NOT support PD different TP sizes for non-MLA hybrid models yet."
|
||||
)
|
||||
|
||||
self.maybe_send_extra(
|
||||
req,
|
||||
kv_chunk.state_indices,
|
||||
target_rank_registration_info.dst_state_data_ptrs,
|
||||
)
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
# Only the last chunk we need to send the aux data
|
||||
ret = self.send_aux(
|
||||
@@ -765,7 +875,7 @@ class MooncakeKVManager(CommonKVManager):
|
||||
)
|
||||
continue
|
||||
else:
|
||||
required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
|
||||
required_dst_info_num = int(waiting_req_bytes[7].decode("ascii"))
|
||||
room = int(room)
|
||||
if room not in self.transfer_infos:
|
||||
self.transfer_infos[room] = {}
|
||||
@@ -876,6 +986,7 @@ class MooncakeKVManager(CommonKVManager):
|
||||
index_slice: slice,
|
||||
is_last: bool,
|
||||
aux_index: Optional[int] = None,
|
||||
state_indices: Optional[List[int]] = None,
|
||||
):
|
||||
assert self.disaggregation_mode == DisaggregationMode.PREFILL
|
||||
assert not is_last or (is_last and aux_index is not None)
|
||||
@@ -909,6 +1020,7 @@ class MooncakeKVManager(CommonKVManager):
|
||||
index_slice=index_slice,
|
||||
is_last=is_last,
|
||||
prefill_aux_index=aux_index,
|
||||
state_indices=state_indices,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -989,6 +1101,7 @@ class MooncakeKVSender(CommonKVSender):
|
||||
def send(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
state_indices: Optional[List[int]] = None,
|
||||
):
|
||||
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
||||
self.curr_idx += len(kv_indices)
|
||||
@@ -1008,6 +1121,7 @@ class MooncakeKVSender(CommonKVSender):
|
||||
index_slice,
|
||||
True,
|
||||
aux_index=self.aux_index,
|
||||
state_indices=state_indices,
|
||||
)
|
||||
|
||||
def poll(self) -> KVPoll:
|
||||
@@ -1110,6 +1224,9 @@ class MooncakeKVReceiver(CommonKVReceiver):
|
||||
packed_aux_data_ptrs = b"".join(
|
||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
||||
)
|
||||
packed_state_data_ptrs = b"".join(
|
||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
|
||||
)
|
||||
# Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
|
||||
tp_rank = self.kv_mgr.kv_args.engine_rank
|
||||
kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
|
||||
@@ -1127,13 +1244,19 @@ class MooncakeKVReceiver(CommonKVReceiver):
|
||||
self.session_id.encode("ascii"),
|
||||
packed_kv_data_ptrs,
|
||||
packed_aux_data_ptrs,
|
||||
packed_state_data_ptrs,
|
||||
dst_tp_rank,
|
||||
dst_attn_tp_size,
|
||||
dst_kv_item_len,
|
||||
]
|
||||
)
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
def init(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
aux_index: Optional[int] = None,
|
||||
state_indices: Optional[List[int]] = None,
|
||||
):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
||||
is_dummy = bootstrap_info["is_dummy"]
|
||||
@@ -1147,6 +1270,14 @@ class MooncakeKVReceiver(CommonKVReceiver):
|
||||
self.session_id.encode("ascii"),
|
||||
kv_indices.tobytes() if not is_dummy else b"",
|
||||
str(aux_index).encode("ascii") if not is_dummy else b"",
|
||||
(
|
||||
np.array(
|
||||
state_indices,
|
||||
dtype=np.int32,
|
||||
).tobytes()
|
||||
if not is_dummy and state_indices is not None
|
||||
else b""
|
||||
),
|
||||
str(self.required_dst_info_num).encode("ascii"),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -704,6 +704,7 @@ class NixlKVSender(CommonKVSender):
|
||||
def send(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
state_indices: Optional[List[int]] = None,
|
||||
):
|
||||
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
||||
self.curr_idx += len(kv_indices)
|
||||
@@ -755,7 +756,12 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
self.bootstrap_room
|
||||
)
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
def init(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
aux_index: Optional[int] = None,
|
||||
state_indices: Optional[List[int]] = None,
|
||||
):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
logger.debug(
|
||||
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||
|
||||
@@ -49,6 +49,11 @@ from sglang.srt.managers.schedule_batch import (
|
||||
RequestStage,
|
||||
ScheduleBatch,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
HybridLinearKVPool,
|
||||
NSATokenToKVPool,
|
||||
SWAKVPool,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||
from sglang.srt.utils import (
|
||||
DynamicGradMode,
|
||||
@@ -146,6 +151,28 @@ class PrefillBootstrapQueue:
|
||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||
kv_args.gpu_id = self.scheduler.gpu_id
|
||||
|
||||
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
|
||||
state_data_ptrs, state_data_lens, state_item_lens = (
|
||||
self.token_to_kv_pool.get_state_buf_infos()
|
||||
)
|
||||
kv_args.state_data_ptrs = state_data_ptrs
|
||||
kv_args.state_data_lens = state_data_lens
|
||||
kv_args.state_item_lens = state_item_lens
|
||||
|
||||
if isinstance(self.token_to_kv_pool, SWAKVPool):
|
||||
kv_args.state_type = "swa"
|
||||
elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
|
||||
kv_args.state_type = "mamba"
|
||||
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
|
||||
kv_args.state_type = "nsa"
|
||||
else:
|
||||
kv_args.state_type = "none"
|
||||
else:
|
||||
kv_args.state_data_ptrs = []
|
||||
kv_args.state_data_lens = []
|
||||
kv_args.state_item_lens = []
|
||||
kv_args.state_type = "none"
|
||||
|
||||
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
||||
self.transfer_backend, KVClassType.MANAGER
|
||||
)
|
||||
@@ -618,15 +645,58 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
.numpy()
|
||||
)
|
||||
req.start_send_idx = end_idx
|
||||
state_indices = None
|
||||
if last_chunk:
|
||||
self.disagg_metadata_buffers.set_buf(req)
|
||||
|
||||
# Prepare extra pool indices for hybrid models
|
||||
if isinstance(
|
||||
self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool
|
||||
):
|
||||
# Mamba hybrid model: send single mamba state index
|
||||
state_indices = [
|
||||
self.req_to_token_pool.req_index_to_mamba_index_mapping[
|
||||
req.req_pool_idx
|
||||
]
|
||||
.cpu()
|
||||
.numpy()
|
||||
]
|
||||
elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool):
|
||||
# SWA hybrid model: send last window KV indices
|
||||
seq_len = len(req.fill_ids)
|
||||
window_size = self.sliding_window_size
|
||||
window_start = max(0, seq_len - window_size)
|
||||
window_start = (window_start // page_size) * page_size
|
||||
|
||||
window_kv_indices_full = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, window_start:seq_len
|
||||
]
|
||||
|
||||
# Translate to SWA pool indices
|
||||
window_kv_indices_swa = (
|
||||
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||
window_kv_indices_full
|
||||
)
|
||||
)
|
||||
state_indices = window_kv_indices_swa.cpu().numpy()
|
||||
state_indices = kv_to_page_indices(state_indices, page_size)
|
||||
elif isinstance(
|
||||
self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool
|
||||
):
|
||||
seq_len = len(req.fill_ids)
|
||||
kv_indices_full = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, :seq_len
|
||||
]
|
||||
state_indices = kv_indices_full.cpu().numpy()
|
||||
state_indices = kv_to_page_indices(state_indices, page_size)
|
||||
|
||||
page_indices = kv_to_page_indices(kv_indices, page_size)
|
||||
if len(page_indices) == 0:
|
||||
logger.info(
|
||||
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
|
||||
)
|
||||
return
|
||||
req.disagg_kv_sender.send(page_indices)
|
||||
req.disagg_kv_sender.send(page_indices, state_indices)
|
||||
|
||||
# PP
|
||||
@DynamicGradMode()
|
||||
|
||||
@@ -807,9 +807,6 @@ class Scheduler(
|
||||
self.tree_cache.cache_controller.layer_done_counter
|
||||
)
|
||||
elif self.is_hybrid:
|
||||
assert (
|
||||
self.server_args.disaggregation_mode == "null"
|
||||
), "Hybrid mode does not support disaggregation yet"
|
||||
self.tree_cache = SWARadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
@@ -819,9 +816,6 @@ class Scheduler(
|
||||
is_eagle=self.spec_algorithm.is_eagle(),
|
||||
)
|
||||
elif self.is_hybrid_gdn:
|
||||
assert (
|
||||
self.server_args.disaggregation_mode == "null"
|
||||
), "Hybrid GDN mode does not support disaggregation yet"
|
||||
self.tree_cache = MambaRadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
|
||||
@@ -142,72 +142,93 @@ class MambaPool:
|
||||
ssm_dtype = cache_params.dtype.temporal
|
||||
num_mamba_layers = len(cache_params.layers)
|
||||
|
||||
# assume conv_state = (dim, state_len)
|
||||
assert conv_state_shape[0] > conv_state_shape[1]
|
||||
conv_state = torch.zeros(
|
||||
size=(num_mamba_layers, size + 1) + conv_state_shape,
|
||||
dtype=conv_dtype,
|
||||
device=device,
|
||||
# for disagg with nvlink
|
||||
self.enable_custom_mem_pool = get_bool_env_var(
|
||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||
)
|
||||
temporal_state = torch.zeros(
|
||||
size=(num_mamba_layers, size + 1) + temporal_state_shape,
|
||||
dtype=ssm_dtype,
|
||||
device=device,
|
||||
)
|
||||
if speculative_num_draft_tokens is not None:
|
||||
# Cache intermediate SSM states per draft token during target verify
|
||||
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
|
||||
intermediate_ssm_state_cache = torch.zeros(
|
||||
size=(
|
||||
num_mamba_layers,
|
||||
size + 1,
|
||||
speculative_num_draft_tokens,
|
||||
temporal_state_shape[0],
|
||||
temporal_state_shape[1],
|
||||
temporal_state_shape[2],
|
||||
),
|
||||
dtype=ssm_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
|
||||
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
|
||||
intermediate_conv_window_cache = torch.zeros(
|
||||
size=(
|
||||
num_mamba_layers,
|
||||
size + 1,
|
||||
speculative_num_draft_tokens,
|
||||
conv_state_shape[0],
|
||||
conv_state_shape[1],
|
||||
),
|
||||
dtype=conv_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
self.mamba_cache = self.SpeculativeState(
|
||||
conv=conv_state,
|
||||
temporal=temporal_state,
|
||||
intermediate_ssm=intermediate_ssm_state_cache,
|
||||
intermediate_conv_window=intermediate_conv_window_cache,
|
||||
)
|
||||
logger.info(
|
||||
f"Mamba Cache is allocated. "
|
||||
f"max_mamba_cache_size: {size}, "
|
||||
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
||||
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
|
||||
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
|
||||
)
|
||||
if self.enable_custom_mem_pool:
|
||||
# TODO(shangming): abstract custom allocator class for more backends
|
||||
from mooncake.allocator import NVLinkAllocator
|
||||
|
||||
allocator = NVLinkAllocator.get_allocator(self.device)
|
||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||
else:
|
||||
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
|
||||
logger.info(
|
||||
f"Mamba Cache is allocated. "
|
||||
f"max_mamba_cache_size: {size}, "
|
||||
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
||||
self.custom_mem_pool = None
|
||||
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.enable_custom_mem_pool
|
||||
else nullcontext()
|
||||
):
|
||||
# assume conv_state = (dim, state_len)
|
||||
assert conv_state_shape[0] > conv_state_shape[1]
|
||||
conv_state = torch.zeros(
|
||||
size=(num_mamba_layers, size + 1) + conv_state_shape,
|
||||
dtype=conv_dtype,
|
||||
device=device,
|
||||
)
|
||||
self.size = size
|
||||
self.device = device
|
||||
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
|
||||
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
|
||||
temporal_state = torch.zeros(
|
||||
size=(num_mamba_layers, size + 1) + temporal_state_shape,
|
||||
dtype=ssm_dtype,
|
||||
device=device,
|
||||
)
|
||||
if speculative_num_draft_tokens is not None:
|
||||
# Cache intermediate SSM states per draft token during target verify
|
||||
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
|
||||
intermediate_ssm_state_cache = torch.zeros(
|
||||
size=(
|
||||
num_mamba_layers,
|
||||
size + 1,
|
||||
speculative_num_draft_tokens,
|
||||
temporal_state_shape[0],
|
||||
temporal_state_shape[1],
|
||||
temporal_state_shape[2],
|
||||
),
|
||||
dtype=ssm_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
|
||||
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
|
||||
intermediate_conv_window_cache = torch.zeros(
|
||||
size=(
|
||||
num_mamba_layers,
|
||||
size + 1,
|
||||
speculative_num_draft_tokens,
|
||||
conv_state_shape[0],
|
||||
conv_state_shape[1],
|
||||
),
|
||||
dtype=conv_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
self.mamba_cache = self.SpeculativeState(
|
||||
conv=conv_state,
|
||||
temporal=temporal_state,
|
||||
intermediate_ssm=intermediate_ssm_state_cache,
|
||||
intermediate_conv_window=intermediate_conv_window_cache,
|
||||
)
|
||||
logger.info(
|
||||
f"Mamba Cache is allocated. "
|
||||
f"max_mamba_cache_size: {size}, "
|
||||
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
||||
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
|
||||
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
|
||||
)
|
||||
else:
|
||||
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
|
||||
logger.info(
|
||||
f"Mamba Cache is allocated. "
|
||||
f"max_mamba_cache_size: {size}, "
|
||||
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
||||
)
|
||||
self.size = size
|
||||
self.device = device
|
||||
self.free_slots = torch.arange(
|
||||
self.size, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
|
||||
self.num_mamba_layers = num_mamba_layers
|
||||
|
||||
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
|
||||
assert isinstance(self.mamba_cache, self.SpeculativeState)
|
||||
@@ -253,6 +274,22 @@ class MambaPool:
|
||||
self.copy_from(src_index, dst_index)
|
||||
return dst_index
|
||||
|
||||
def get_contiguous_buf_infos(self):
|
||||
state_tensors = [
|
||||
getattr(self.mamba_cache, field) for field in vars(self.mamba_cache)
|
||||
]
|
||||
data_ptrs, data_lens, item_lens = [], [], []
|
||||
|
||||
for _, state_tensor in enumerate(state_tensors):
|
||||
data_ptrs += [
|
||||
state_tensor[i].data_ptr() for i in range(self.num_mamba_layers)
|
||||
]
|
||||
data_lens += [state_tensor[i].nbytes for i in range(self.num_mamba_layers)]
|
||||
item_lens += [
|
||||
state_tensor[i][0].nbytes for i in range(self.num_mamba_layers)
|
||||
]
|
||||
return data_ptrs, data_lens, item_lens
|
||||
|
||||
|
||||
class HybridReqToTokenPool(ReqToTokenPool):
|
||||
"""A memory pool that maps a request to its token locations."""
|
||||
@@ -274,13 +311,26 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
||||
device=device,
|
||||
enable_memory_saver=enable_memory_saver,
|
||||
)
|
||||
|
||||
self.mamba_pool = MambaPool(
|
||||
self._init_mamba_pool(
|
||||
size=mamba_size,
|
||||
cache_params=cache_params,
|
||||
device=device,
|
||||
speculative_num_draft_tokens=speculative_num_draft_tokens,
|
||||
)
|
||||
|
||||
def _init_mamba_pool(
|
||||
self,
|
||||
size: int,
|
||||
cache_params: "Mamba2CacheParams",
|
||||
device: str,
|
||||
speculative_num_draft_tokens: int = None,
|
||||
):
|
||||
self.mamba_pool = MambaPool(
|
||||
size=size,
|
||||
cache_params=cache_params,
|
||||
device=device,
|
||||
speculative_num_draft_tokens=speculative_num_draft_tokens,
|
||||
)
|
||||
self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
|
||||
|
||||
self.device = device
|
||||
@@ -375,6 +425,19 @@ class KVCache(abc.ABC):
|
||||
# default state for optional layer-wise transfer control
|
||||
self.layer_transfer_counter = None
|
||||
|
||||
# for disagg with nvlink
|
||||
self.enable_custom_mem_pool = get_bool_env_var(
|
||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||
)
|
||||
if self.enable_custom_mem_pool:
|
||||
# TODO(shangming): abstract custom allocator class for more backends
|
||||
from mooncake.allocator import NVLinkAllocator
|
||||
|
||||
allocator = NVLinkAllocator.get_allocator(self.device)
|
||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||
else:
|
||||
self.custom_mem_pool = None
|
||||
|
||||
def _finalize_allocation_log(self, num_tokens: int):
|
||||
"""Common logging and mem_usage computation for KV cache allocation.
|
||||
Supports both tuple (K, V) size returns and single KV size returns.
|
||||
@@ -426,6 +489,9 @@ class KVCache(abc.ABC):
|
||||
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||
raise NotImplementedError()
|
||||
|
||||
def maybe_get_custom_mem_pool(self):
|
||||
return self.custom_mem_pool
|
||||
|
||||
|
||||
class MHATokenToKVPool(KVCache):
|
||||
|
||||
@@ -456,19 +522,6 @@ class MHATokenToKVPool(KVCache):
|
||||
self.head_num = head_num
|
||||
self.head_dim = head_dim
|
||||
|
||||
# for disagg with nvlink
|
||||
self.enable_custom_mem_pool = get_bool_env_var(
|
||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||
)
|
||||
if self.enable_custom_mem_pool:
|
||||
# TODO(shangming): abstract custom allocator class for more backends
|
||||
from mooncake.allocator import NVLinkAllocator
|
||||
|
||||
allocator = NVLinkAllocator.get_allocator(self.device)
|
||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||
else:
|
||||
self.custom_mem_pool = None
|
||||
|
||||
self._create_buffers()
|
||||
|
||||
self.device_module = torch.get_device_module(self.device)
|
||||
@@ -611,9 +664,6 @@ class MHATokenToKVPool(KVCache):
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def maybe_get_custom_mem_pool(self):
|
||||
return self.custom_mem_pool
|
||||
|
||||
def get_cpu_copy(self, indices):
|
||||
torch.cuda.synchronize()
|
||||
kv_cache_cpu = []
|
||||
@@ -756,12 +806,18 @@ class HybridLinearKVPool(KVCache):
|
||||
full_attention_layer_ids: List[int],
|
||||
enable_kvcache_transpose: bool,
|
||||
device: str,
|
||||
mamba_pool: MambaPool,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.full_layer_nums = len(full_attention_layer_ids)
|
||||
self.page_size = page_size
|
||||
# TODO support pp?
|
||||
self.start_layer = 0
|
||||
self.head_num = head_num
|
||||
self.head_dim = head_dim
|
||||
self.mamba_pool = mamba_pool
|
||||
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
||||
assert not enable_kvcache_transpose
|
||||
if _is_npu:
|
||||
@@ -790,6 +846,15 @@ class HybridLinearKVPool(KVCache):
|
||||
def get_contiguous_buf_infos(self):
|
||||
return self.full_kv_pool.get_contiguous_buf_infos()
|
||||
|
||||
def get_state_buf_infos(self):
|
||||
mamba_data_ptrs, mamba_data_lens, mamba_item_lens = (
|
||||
self.mamba_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
return mamba_data_ptrs, mamba_data_lens, mamba_item_lens
|
||||
|
||||
def maybe_get_custom_mem_pool(self):
|
||||
return self.full_kv_pool.maybe_get_custom_mem_pool()
|
||||
|
||||
def _transfer_full_attention_id(self, layer_id: int):
|
||||
if layer_id not in self.full_attention_layer_id_mapping:
|
||||
raise ValueError(
|
||||
@@ -841,22 +906,47 @@ class SWAKVPool(KVCache):
|
||||
size: int,
|
||||
size_swa: int,
|
||||
dtype: torch.dtype,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
swa_attention_layer_ids: List[int],
|
||||
full_attention_layer_ids: List[int],
|
||||
enable_kvcache_transpose: bool,
|
||||
device: str,
|
||||
token_to_kv_pool_class: KVCache = MHATokenToKVPool,
|
||||
**kwargs,
|
||||
):
|
||||
self.size = size
|
||||
self.size_swa = size_swa
|
||||
self.dtype = dtype
|
||||
self.head_num = head_num
|
||||
self.head_dim = head_dim
|
||||
self.device = device
|
||||
self.swa_layer_nums = len(swa_attention_layer_ids)
|
||||
self.full_layer_nums = len(full_attention_layer_ids)
|
||||
self.start_layer = 0
|
||||
self.page_size = 1
|
||||
|
||||
kwargs["page_size"] = 1
|
||||
kwargs["enable_memory_saver"] = False
|
||||
kwargs["head_num"] = head_num
|
||||
kwargs["head_dim"] = head_dim
|
||||
kwargs["device"] = device
|
||||
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
||||
assert not enable_kvcache_transpose
|
||||
|
||||
# for disagg with nvlink
|
||||
self.enable_custom_mem_pool = get_bool_env_var(
|
||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||
)
|
||||
if self.enable_custom_mem_pool:
|
||||
# TODO(shangming): abstract custom allocator class for more backends
|
||||
from mooncake.allocator import NVLinkAllocator
|
||||
|
||||
allocator = NVLinkAllocator.get_allocator(self.device)
|
||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||
else:
|
||||
self.custom_mem_pool = None
|
||||
|
||||
self.swa_kv_pool = token_to_kv_pool_class(
|
||||
size=size_swa,
|
||||
dtype=dtype,
|
||||
@@ -878,6 +968,9 @@ class SWAKVPool(KVCache):
|
||||
|
||||
k_size, v_size = self.get_kv_size_bytes()
|
||||
self.mem_usage = (k_size + v_size) / GB
|
||||
logger.info(
|
||||
f"SWAKVPool mem usage: {self.mem_usage} GB, swa size: {self.size_swa}, full size: {self.size}"
|
||||
)
|
||||
|
||||
def get_kv_size_bytes(self):
|
||||
k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
|
||||
@@ -888,15 +981,19 @@ class SWAKVPool(KVCache):
|
||||
full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
|
||||
self.full_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
|
||||
kv_data_ptrs = full_kv_data_ptrs
|
||||
kv_data_lens = full_kv_data_lens
|
||||
kv_item_lens = full_kv_item_lens
|
||||
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def get_state_buf_infos(self):
|
||||
swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
|
||||
self.swa_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
|
||||
kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
|
||||
kv_data_lens = full_kv_data_lens + swa_kv_data_lens
|
||||
kv_item_lens = full_kv_item_lens + swa_kv_item_lens
|
||||
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
return swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
layer_id_pool, is_swa = self.layers_mapping[layer_id]
|
||||
@@ -1152,19 +1249,6 @@ class MLATokenToKVPool(KVCache):
|
||||
else (kv_lora_rank + qk_rope_head_dim)
|
||||
)
|
||||
|
||||
# for disagg with nvlink
|
||||
self.enable_custom_mem_pool = get_bool_env_var(
|
||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||
)
|
||||
if self.enable_custom_mem_pool:
|
||||
# TODO(shangming): abstract custom allocator class for more backends
|
||||
from mooncake.allocator import NVLinkAllocator
|
||||
|
||||
allocator = NVLinkAllocator.get_allocator(self.device)
|
||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||
else:
|
||||
self.custom_mem_pool = None
|
||||
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
@@ -1207,9 +1291,6 @@ class MLATokenToKVPool(KVCache):
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def maybe_get_custom_mem_pool(self):
|
||||
return self.custom_mem_pool
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
if self.layer_transfer_counter is not None:
|
||||
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||
@@ -1346,24 +1427,31 @@ class NSATokenToKVPool(MLATokenToKVPool):
|
||||
assert index_head_dim == 128
|
||||
|
||||
assert self.page_size == 64
|
||||
self.index_k_with_scale_buffer = [
|
||||
torch.zeros(
|
||||
# Layout:
|
||||
# ref: test_attention.py :: kv_cache_cast_to_fp8
|
||||
# shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
|
||||
# data: for page i,
|
||||
# * buf[i, :page_size * head_dim] for fp8 data
|
||||
# * buf[i, page_size * head_dim:].view(float32) for scale
|
||||
(
|
||||
(size + page_size + 1) // self.page_size,
|
||||
self.page_size
|
||||
* (index_head_dim + index_head_dim // self.quant_block_size * 4),
|
||||
),
|
||||
dtype=self.index_k_with_scale_buffer_dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.custom_mem_pool
|
||||
else nullcontext()
|
||||
):
|
||||
self.index_k_with_scale_buffer = [
|
||||
torch.zeros(
|
||||
# Layout:
|
||||
# ref: test_attention.py :: kv_cache_cast_to_fp8
|
||||
# shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
|
||||
# data: for page i,
|
||||
# * buf[i, :page_size * head_dim] for fp8 data
|
||||
# * buf[i, page_size * head_dim:].view(float32) for scale
|
||||
(
|
||||
(size + page_size + 1) // self.page_size,
|
||||
self.page_size
|
||||
* (
|
||||
index_head_dim + index_head_dim // self.quant_block_size * 4
|
||||
),
|
||||
),
|
||||
dtype=self.index_k_with_scale_buffer_dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self._finalize_allocation_log(size)
|
||||
|
||||
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
|
||||
@@ -1406,6 +1494,18 @@ class NSATokenToKVPool(MLATokenToKVPool):
|
||||
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
|
||||
)
|
||||
|
||||
def get_state_buf_infos(self):
|
||||
data_ptrs = [
|
||||
self.index_k_with_scale_buffer[i].data_ptr() for i in range(self.layer_num)
|
||||
]
|
||||
data_lens = [
|
||||
self.index_k_with_scale_buffer[i].nbytes for i in range(self.layer_num)
|
||||
]
|
||||
item_lens = [
|
||||
self.index_k_with_scale_buffer[i][0].nbytes for i in range(self.layer_num)
|
||||
]
|
||||
return data_ptrs, data_lens, item_lens
|
||||
|
||||
def get_kv_size_bytes(self):
|
||||
kv_size_bytes = super().get_kv_size_bytes()
|
||||
for index_k_cache in self.index_k_with_scale_buffer:
|
||||
@@ -1636,27 +1736,38 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
)
|
||||
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
(size + page_size, head_num, head_dim), dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.zeros(
|
||||
(size + page_size, head_num, head_dim), dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.enable_custom_mem_pool
|
||||
else nullcontext()
|
||||
):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
(size + page_size, head_num, head_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.zeros(
|
||||
(size + page_size, head_num, head_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
# [size, head_num, heavy_channel_num] for each layer
|
||||
self.label_buffer = [
|
||||
torch.zeros(
|
||||
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
# [size, head_num, heavy_channel_num] for each layer
|
||||
self.label_buffer = [
|
||||
torch.zeros(
|
||||
(size + 1, head_num, heavy_channel_num),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
return self.k_buffer[layer_id - self.start_layer]
|
||||
|
||||
@@ -1669,19 +1669,34 @@ class ModelRunner:
|
||||
extra_max_context_len += self.server_args.speculative_num_draft_tokens
|
||||
|
||||
if self.server_args.disaggregation_mode == "decode":
|
||||
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
||||
from sglang.srt.disaggregation.decode import (
|
||||
DecodeReqToTokenPool,
|
||||
HybridMambaDecodeReqToTokenPool,
|
||||
)
|
||||
|
||||
# subscribe memory for pre-allocated requests
|
||||
# if max_num_reqs <= 32, we pre-allocate 2x requests
|
||||
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
|
||||
self.req_to_token_pool = DecodeReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len
|
||||
+ extra_max_context_len,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
pre_alloc_size=pre_alloc_size,
|
||||
)
|
||||
if config := self.mambaish_config:
|
||||
self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len
|
||||
+ extra_max_context_len,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
cache_params=config.mamba2_cache_params,
|
||||
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
||||
pre_alloc_size=pre_alloc_size,
|
||||
)
|
||||
else:
|
||||
self.req_to_token_pool = DecodeReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len
|
||||
+ extra_max_context_len,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
pre_alloc_size=pre_alloc_size,
|
||||
)
|
||||
elif config := self.mambaish_config:
|
||||
self.req_to_token_pool = HybridReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
@@ -1807,6 +1822,7 @@ class ModelRunner:
|
||||
),
|
||||
enable_kvcache_transpose=False,
|
||||
device=self.device,
|
||||
mamba_pool=self.req_to_token_pool.mamba_pool,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
|
||||
Reference in New Issue
Block a user