[PD] Add PD support for hybrid model (Qwen3-Next, DeepSeek V3.2 Exp) (#10912)
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: ZeldaHuang <hzm414167@alibaba-inc.com>
This commit is contained in:
@@ -20,6 +20,10 @@ class KVArgs:
|
|||||||
aux_data_ptrs: List[int]
|
aux_data_ptrs: List[int]
|
||||||
aux_data_lens: List[int]
|
aux_data_lens: List[int]
|
||||||
aux_item_lens: List[int]
|
aux_item_lens: List[int]
|
||||||
|
state_data_ptrs: List[int]
|
||||||
|
state_data_lens: List[int]
|
||||||
|
state_item_lens: List[int]
|
||||||
|
state_type: str # "none", "mamba", "swa", "nsa"
|
||||||
ib_device: str
|
ib_device: str
|
||||||
ib_traffic_class: str
|
ib_traffic_class: str
|
||||||
gpu_id: int
|
gpu_id: int
|
||||||
@@ -76,9 +80,13 @@ class BaseKVSender(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def send(self, kv_indices: npt.NDArray[np.int32]):
|
def send(
|
||||||
|
self,
|
||||||
|
kv_indices: npt.NDArray[np.int32],
|
||||||
|
state_indices: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Send the kv cache at the given kv indices to the decoder server
|
Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@@ -108,9 +116,14 @@ class BaseKVReceiver(ABC):
|
|||||||
): ...
|
): ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
def init(
|
||||||
|
self,
|
||||||
|
kv_indices: npt.NDArray[np.int32],
|
||||||
|
aux_index: Optional[int] = None,
|
||||||
|
state_indices: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Notify the prefill server about the kv indices and aux index
|
Notify the prefill server about the kv indices, aux index, and state_indices.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|||||||
@@ -201,6 +201,7 @@ class CommonKVSender(BaseKVSender):
|
|||||||
def send(
|
def send(
|
||||||
self,
|
self,
|
||||||
kv_indices: npt.NDArray[np.int32],
|
kv_indices: npt.NDArray[np.int32],
|
||||||
|
state_indices: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -25,11 +25,12 @@ import time
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
|
||||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
@@ -47,9 +48,19 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
|
||||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import (
|
||||||
|
BaseTokenToKVPoolAllocator,
|
||||||
|
SWATokenToKVPoolAllocator,
|
||||||
|
)
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
|
HybridLinearKVPool,
|
||||||
|
HybridReqToTokenPool,
|
||||||
|
KVCache,
|
||||||
|
NSATokenToKVPool,
|
||||||
|
ReqToTokenPool,
|
||||||
|
SWAKVPool,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.utils import get_int_env_var, require_mlp_sync
|
from sglang.srt.utils import get_int_env_var, require_mlp_sync
|
||||||
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
@@ -124,6 +135,35 @@ class DecodeReqToTokenPool:
|
|||||||
self.free_slots = list(range(self.size + self.pre_alloc_size))
|
self.free_slots = list(range(self.size + self.pre_alloc_size))
|
||||||
|
|
||||||
|
|
||||||
|
class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
size: int,
|
||||||
|
max_context_len: int,
|
||||||
|
device: str,
|
||||||
|
enable_memory_saver: bool,
|
||||||
|
cache_params: "Mamba2CacheParams",
|
||||||
|
speculative_num_draft_tokens: int,
|
||||||
|
pre_alloc_size: int,
|
||||||
|
):
|
||||||
|
DecodeReqToTokenPool.__init__(
|
||||||
|
self,
|
||||||
|
size=size,
|
||||||
|
max_context_len=max_context_len,
|
||||||
|
device=device,
|
||||||
|
enable_memory_saver=enable_memory_saver,
|
||||||
|
pre_alloc_size=pre_alloc_size,
|
||||||
|
)
|
||||||
|
self._init_mamba_pool(
|
||||||
|
size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.free_slots = list(range(self.size + self.pre_alloc_size))
|
||||||
|
self.mamba_pool.clear()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DecodeRequest:
|
class DecodeRequest:
|
||||||
req: Req
|
req: Req
|
||||||
@@ -217,6 +257,28 @@ class DecodePreallocQueue:
|
|||||||
self.metadata_buffers.get_buf_infos()
|
self.metadata_buffers.get_buf_infos()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
|
||||||
|
state_data_ptrs, state_data_lens, state_item_lens = (
|
||||||
|
self.token_to_kv_pool.get_state_buf_infos()
|
||||||
|
)
|
||||||
|
kv_args.state_data_ptrs = state_data_ptrs
|
||||||
|
kv_args.state_data_lens = state_data_lens
|
||||||
|
kv_args.state_item_lens = state_item_lens
|
||||||
|
|
||||||
|
if isinstance(self.token_to_kv_pool, SWAKVPool):
|
||||||
|
kv_args.state_type = "swa"
|
||||||
|
elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
|
||||||
|
kv_args.state_type = "mamba"
|
||||||
|
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
|
||||||
|
kv_args.state_type = "nsa"
|
||||||
|
else:
|
||||||
|
kv_args.state_type = "none"
|
||||||
|
else:
|
||||||
|
kv_args.state_data_ptrs = []
|
||||||
|
kv_args.state_data_lens = []
|
||||||
|
kv_args.state_item_lens = []
|
||||||
|
kv_args.state_type = "none"
|
||||||
|
|
||||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||||
kv_args.gpu_id = self.scheduler.gpu_id
|
kv_args.gpu_id = self.scheduler.gpu_id
|
||||||
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
||||||
@@ -414,16 +476,56 @@ class DecodePreallocQueue:
|
|||||||
.cpu()
|
.cpu()
|
||||||
.numpy()
|
.numpy()
|
||||||
)
|
)
|
||||||
|
page_size = self.token_to_kv_pool_allocator.page_size
|
||||||
|
|
||||||
|
# Prepare extra pool indices for hybrid models
|
||||||
|
if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
|
||||||
|
# Mamba hybrid model: single mamba state index
|
||||||
|
state_indices = [
|
||||||
|
self.req_to_token_pool.req_index_to_mamba_index_mapping[
|
||||||
|
decode_req.req.req_pool_idx
|
||||||
|
]
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
]
|
||||||
|
elif isinstance(self.token_to_kv_pool, SWAKVPool):
|
||||||
|
# SWA hybrid model: send decode-side SWA window indices
|
||||||
|
seq_len = len(decode_req.req.origin_input_ids)
|
||||||
|
window_size = self.scheduler.sliding_window_size
|
||||||
|
|
||||||
|
window_start = max(0, seq_len - window_size)
|
||||||
|
window_start = (window_start // page_size) * page_size
|
||||||
|
window_kv_indices_full = self.req_to_token_pool.req_to_token[
|
||||||
|
decode_req.req.req_pool_idx, window_start:seq_len
|
||||||
|
]
|
||||||
|
|
||||||
|
# Translate to SWA pool indices
|
||||||
|
window_kv_indices_swa = (
|
||||||
|
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||||
|
window_kv_indices_full
|
||||||
|
)
|
||||||
|
)
|
||||||
|
state_indices = window_kv_indices_swa.cpu().numpy()
|
||||||
|
state_indices = kv_to_page_indices(state_indices, page_size)
|
||||||
|
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
|
||||||
|
seq_len = len(decode_req.req.origin_input_ids)
|
||||||
|
kv_indices_full = self.req_to_token_pool.req_to_token[
|
||||||
|
decode_req.req.req_pool_idx, :seq_len
|
||||||
|
]
|
||||||
|
state_indices = kv_indices_full.cpu().numpy()
|
||||||
|
state_indices = kv_to_page_indices(state_indices, page_size)
|
||||||
|
else:
|
||||||
|
state_indices = None
|
||||||
|
|
||||||
decode_req.metadata_buffer_index = (
|
decode_req.metadata_buffer_index = (
|
||||||
self.req_to_metadata_buffer_idx_allocator.alloc()
|
self.req_to_metadata_buffer_idx_allocator.alloc()
|
||||||
)
|
)
|
||||||
assert decode_req.metadata_buffer_index is not None
|
assert decode_req.metadata_buffer_index is not None
|
||||||
page_indices = kv_to_page_indices(
|
page_indices = kv_to_page_indices(kv_indices, page_size)
|
||||||
kv_indices, self.token_to_kv_pool_allocator.page_size
|
decode_req.kv_receiver.init(
|
||||||
|
page_indices, decode_req.metadata_buffer_index, state_indices
|
||||||
)
|
)
|
||||||
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
|
||||||
|
|
||||||
preallocated_reqs.append(decode_req)
|
preallocated_reqs.append(decode_req)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
decode_req.req.time_stats.decode_transfer_queue_entry_time = (
|
decode_req.req.time_stats.decode_transfer_queue_entry_time = (
|
||||||
@@ -503,7 +605,10 @@ class DecodePreallocQueue:
|
|||||||
|
|
||||||
def _pre_alloc(self, req: Req) -> torch.Tensor:
|
def _pre_alloc(self, req: Req) -> torch.Tensor:
|
||||||
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
|
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
|
||||||
req_pool_indices = self.req_to_token_pool.alloc(1)
|
if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
|
||||||
|
req_pool_indices = self.req_to_token_pool.alloc(1, [req])
|
||||||
|
else:
|
||||||
|
req_pool_indices = self.req_to_token_pool.alloc(1)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
req_pool_indices is not None
|
req_pool_indices is not None
|
||||||
|
|||||||
@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender):
|
|||||||
def send(
|
def send(
|
||||||
self,
|
self,
|
||||||
kv_indices: npt.NDArray[np.int32],
|
kv_indices: npt.NDArray[np.int32],
|
||||||
|
state_indices: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
self.has_sent = True
|
self.has_sent = True
|
||||||
logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
|
logger.debug(
|
||||||
|
f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
|
||||||
|
)
|
||||||
|
|
||||||
def failure_exception(self):
|
def failure_exception(self):
|
||||||
raise Exception("Fake KVSender Exception")
|
raise Exception("Fake KVSender Exception")
|
||||||
@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver):
|
|||||||
logger.debug("FakeKVReceiver poll success")
|
logger.debug("FakeKVReceiver poll success")
|
||||||
return KVPoll.Success
|
return KVPoll.Success
|
||||||
|
|
||||||
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
|
def init(
|
||||||
|
self,
|
||||||
|
kv_indices: list[int],
|
||||||
|
aux_index: Optional[int] = None,
|
||||||
|
state_indices: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
self.has_init = True
|
self.has_init = True
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def failure_exception(self):
|
def failure_exception(self):
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class TransferKVChunk:
|
|||||||
index_slice: slice
|
index_slice: slice
|
||||||
is_last: bool
|
is_last: bool
|
||||||
prefill_aux_index: Optional[int]
|
prefill_aux_index: Optional[int]
|
||||||
|
state_indices: Optional[List[int]]
|
||||||
|
|
||||||
|
|
||||||
# decode
|
# decode
|
||||||
@@ -69,6 +70,7 @@ class TransferInfo:
|
|||||||
mooncake_session_id: str
|
mooncake_session_id: str
|
||||||
dst_kv_indices: npt.NDArray[np.int32]
|
dst_kv_indices: npt.NDArray[np.int32]
|
||||||
dst_aux_index: int
|
dst_aux_index: int
|
||||||
|
dst_state_indices: List[int]
|
||||||
required_dst_info_num: int
|
required_dst_info_num: int
|
||||||
is_dummy: bool
|
is_dummy: bool
|
||||||
|
|
||||||
@@ -78,9 +80,14 @@ class TransferInfo:
|
|||||||
is_dummy = True
|
is_dummy = True
|
||||||
dst_kv_indices = np.array([], dtype=np.int32)
|
dst_kv_indices = np.array([], dtype=np.int32)
|
||||||
dst_aux_index = None
|
dst_aux_index = None
|
||||||
|
dst_state_indices = []
|
||||||
else:
|
else:
|
||||||
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
|
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
|
||||||
dst_aux_index = int(msg[5].decode("ascii"))
|
dst_aux_index = int(msg[5].decode("ascii"))
|
||||||
|
if msg[6] == b"":
|
||||||
|
dst_state_indices = []
|
||||||
|
else:
|
||||||
|
dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
|
||||||
is_dummy = False
|
is_dummy = False
|
||||||
return cls(
|
return cls(
|
||||||
room=int(msg[0].decode("ascii")),
|
room=int(msg[0].decode("ascii")),
|
||||||
@@ -89,7 +96,8 @@ class TransferInfo:
|
|||||||
mooncake_session_id=msg[3].decode("ascii"),
|
mooncake_session_id=msg[3].decode("ascii"),
|
||||||
dst_kv_indices=dst_kv_indices,
|
dst_kv_indices=dst_kv_indices,
|
||||||
dst_aux_index=dst_aux_index,
|
dst_aux_index=dst_aux_index,
|
||||||
required_dst_info_num=int(msg[6].decode("ascii")),
|
dst_state_indices=dst_state_indices,
|
||||||
|
required_dst_info_num=int(msg[7].decode("ascii")),
|
||||||
is_dummy=is_dummy,
|
is_dummy=is_dummy,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -103,6 +111,7 @@ class KVArgsRegisterInfo:
|
|||||||
mooncake_session_id: str
|
mooncake_session_id: str
|
||||||
dst_kv_ptrs: list[int]
|
dst_kv_ptrs: list[int]
|
||||||
dst_aux_ptrs: list[int]
|
dst_aux_ptrs: list[int]
|
||||||
|
dst_state_data_ptrs: list[int]
|
||||||
dst_tp_rank: int
|
dst_tp_rank: int
|
||||||
dst_attn_tp_size: int
|
dst_attn_tp_size: int
|
||||||
dst_kv_item_len: int
|
dst_kv_item_len: int
|
||||||
@@ -116,9 +125,10 @@ class KVArgsRegisterInfo:
|
|||||||
mooncake_session_id=msg[3].decode("ascii"),
|
mooncake_session_id=msg[3].decode("ascii"),
|
||||||
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
|
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
|
||||||
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
||||||
dst_tp_rank=int(msg[6].decode("ascii")),
|
dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
|
||||||
dst_attn_tp_size=int(msg[7].decode("ascii")),
|
dst_tp_rank=int(msg[7].decode("ascii")),
|
||||||
dst_kv_item_len=int(msg[8].decode("ascii")),
|
dst_attn_tp_size=int(msg[8].decode("ascii")),
|
||||||
|
dst_kv_item_len=int(msg[9].decode("ascii")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -180,6 +190,9 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
)
|
)
|
||||||
for _ in range(transfer_queue_size)
|
for _ in range(transfer_queue_size)
|
||||||
]
|
]
|
||||||
|
self.state_executors = concurrent.futures.ThreadPoolExecutor(
|
||||||
|
transfer_thread_pool_size // transfer_queue_size
|
||||||
|
)
|
||||||
for queue, executor in zip(self.transfer_queues, self.executors):
|
for queue, executor in zip(self.transfer_queues, self.executors):
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=self.transfer_worker, args=(queue, executor), daemon=True
|
target=self.transfer_worker, args=(queue, executor), daemon=True
|
||||||
@@ -239,6 +252,12 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Batch register state/extra pool data buffers
|
||||||
|
if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
|
||||||
|
self.engine.batch_register(
|
||||||
|
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
|
||||||
|
)
|
||||||
|
|
||||||
def _transfer_data(self, mooncake_session_id, transfer_blocks):
|
def _transfer_data(self, mooncake_session_id, transfer_blocks):
|
||||||
if not transfer_blocks:
|
if not transfer_blocks:
|
||||||
return 0
|
return 0
|
||||||
@@ -248,17 +267,23 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_kvcache(
|
def _send_kvcache_generic(
|
||||||
self,
|
self,
|
||||||
mooncake_session_id: str,
|
mooncake_session_id: str,
|
||||||
prefill_kv_indices: npt.NDArray[np.int32],
|
src_data_ptrs: list[int],
|
||||||
dst_kv_ptrs: list[int],
|
dst_data_ptrs: list[int],
|
||||||
dst_kv_indices: npt.NDArray[np.int32],
|
item_lens: list[int],
|
||||||
|
prefill_data_indices: npt.NDArray[np.int32],
|
||||||
|
dst_data_indices: npt.NDArray[np.int32],
|
||||||
executor: concurrent.futures.ThreadPoolExecutor,
|
executor: concurrent.futures.ThreadPoolExecutor,
|
||||||
):
|
) -> int:
|
||||||
# Group by indices
|
"""
|
||||||
|
Generic KV cache transfer supporting both MHA and MLA architectures.
|
||||||
|
This method is used by both send_kvcache (full pool) and maybe_send_extra.
|
||||||
|
"""
|
||||||
|
# Group by indices for optimization
|
||||||
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
|
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
|
||||||
prefill_kv_indices, dst_kv_indices
|
prefill_data_indices, dst_data_indices
|
||||||
)
|
)
|
||||||
|
|
||||||
layers_params = None
|
layers_params = None
|
||||||
@@ -266,9 +291,9 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
# pp is not supported on the decode side yet
|
# pp is not supported on the decode side yet
|
||||||
if self.is_mla_backend:
|
if self.is_mla_backend:
|
||||||
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
||||||
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
|
||||||
)
|
)
|
||||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
kv_item_len = item_lens[0]
|
||||||
layers_params = [
|
layers_params = [
|
||||||
(
|
(
|
||||||
src_kv_ptrs[layer_id],
|
src_kv_ptrs[layer_id],
|
||||||
@@ -279,9 +304,9 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
||||||
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
|
||||||
)
|
)
|
||||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
kv_item_len = item_lens[0]
|
||||||
layers_params = [
|
layers_params = [
|
||||||
(
|
(
|
||||||
src_k_ptrs[layer_id],
|
src_k_ptrs[layer_id],
|
||||||
@@ -345,6 +370,24 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def send_kvcache(
|
||||||
|
self,
|
||||||
|
mooncake_session_id: str,
|
||||||
|
prefill_kv_indices: npt.NDArray[np.int32],
|
||||||
|
dst_kv_ptrs: list[int],
|
||||||
|
dst_kv_indices: npt.NDArray[np.int32],
|
||||||
|
executor: concurrent.futures.ThreadPoolExecutor,
|
||||||
|
):
|
||||||
|
return self._send_kvcache_generic(
|
||||||
|
mooncake_session_id=mooncake_session_id,
|
||||||
|
src_data_ptrs=self.kv_args.kv_data_ptrs,
|
||||||
|
dst_data_ptrs=dst_kv_ptrs,
|
||||||
|
item_lens=self.kv_args.kv_item_lens,
|
||||||
|
prefill_data_indices=prefill_kv_indices,
|
||||||
|
dst_data_indices=dst_kv_indices,
|
||||||
|
executor=executor,
|
||||||
|
)
|
||||||
|
|
||||||
def send_kvcache_slice(
|
def send_kvcache_slice(
|
||||||
self,
|
self,
|
||||||
mooncake_session_id: str,
|
mooncake_session_id: str,
|
||||||
@@ -593,6 +636,58 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
|
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def maybe_send_extra(
|
||||||
|
self,
|
||||||
|
req: TransferInfo,
|
||||||
|
prefill_state_indices: list[int],
|
||||||
|
dst_state_data_ptrs: list[int],
|
||||||
|
):
|
||||||
|
"""Send state or extra pool data with type-specific handling."""
|
||||||
|
state_type = getattr(self.kv_args, "state_type", "none")
|
||||||
|
|
||||||
|
if state_type == "mamba":
|
||||||
|
return self._send_mamba_state(
|
||||||
|
req,
|
||||||
|
prefill_state_indices,
|
||||||
|
dst_state_data_ptrs,
|
||||||
|
)
|
||||||
|
elif state_type in ["swa", "nsa"]:
|
||||||
|
# Reuse _send_kvcache_generic interface to send extra pool data
|
||||||
|
prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
|
||||||
|
dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
|
||||||
|
return self._send_kvcache_generic(
|
||||||
|
mooncake_session_id=req.mooncake_session_id,
|
||||||
|
src_data_ptrs=self.kv_args.state_data_ptrs,
|
||||||
|
dst_data_ptrs=dst_state_data_ptrs,
|
||||||
|
item_lens=self.kv_args.state_item_lens,
|
||||||
|
prefill_data_indices=prefill_state_indices,
|
||||||
|
dst_data_indices=dst_state_indices,
|
||||||
|
executor=self.state_executors,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _send_mamba_state(
|
||||||
|
self,
|
||||||
|
req: TransferInfo,
|
||||||
|
prefill_mamba_index: list[int],
|
||||||
|
dst_state_data_ptrs: list[int],
|
||||||
|
):
|
||||||
|
"""Transfer Mamba states."""
|
||||||
|
assert len(prefill_mamba_index) == 1, "Mamba should have single state index"
|
||||||
|
|
||||||
|
transfer_blocks = []
|
||||||
|
prefill_state_data_ptrs = self.kv_args.state_data_ptrs
|
||||||
|
prefill_state_item_lens = self.kv_args.state_item_lens
|
||||||
|
|
||||||
|
for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
|
||||||
|
length = prefill_state_item_lens[i]
|
||||||
|
src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])
|
||||||
|
dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])
|
||||||
|
transfer_blocks.append((src_addr, dst_addr, length))
|
||||||
|
|
||||||
|
return self._transfer_data(req.mooncake_session_id, transfer_blocks)
|
||||||
|
|
||||||
def sync_status_to_decode_endpoint(
|
def sync_status_to_decode_endpoint(
|
||||||
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
||||||
):
|
):
|
||||||
@@ -702,6 +797,21 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if kv_chunk.is_last:
|
if kv_chunk.is_last:
|
||||||
|
if kv_chunk.state_indices is not None:
|
||||||
|
if not self.is_mla_backend and (
|
||||||
|
self.attn_tp_size
|
||||||
|
!= target_rank_registration_info.dst_attn_tp_size
|
||||||
|
):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"PD Disaggregation does NOT support PD different TP sizes for non-MLA hybrid models yet."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.maybe_send_extra(
|
||||||
|
req,
|
||||||
|
kv_chunk.state_indices,
|
||||||
|
target_rank_registration_info.dst_state_data_ptrs,
|
||||||
|
)
|
||||||
|
|
||||||
if self.pp_group.is_last_rank:
|
if self.pp_group.is_last_rank:
|
||||||
# The aux data only needs to be sent with the last chunk
|
# The aux data only needs to be sent with the last chunk
|
||||||
ret = self.send_aux(
|
ret = self.send_aux(
|
||||||
@@ -765,7 +875,7 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
|
required_dst_info_num = int(waiting_req_bytes[7].decode("ascii"))
|
||||||
room = int(room)
|
room = int(room)
|
||||||
if room not in self.transfer_infos:
|
if room not in self.transfer_infos:
|
||||||
self.transfer_infos[room] = {}
|
self.transfer_infos[room] = {}
|
||||||
@@ -876,6 +986,7 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
index_slice: slice,
|
index_slice: slice,
|
||||||
is_last: bool,
|
is_last: bool,
|
||||||
aux_index: Optional[int] = None,
|
aux_index: Optional[int] = None,
|
||||||
|
state_indices: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
assert self.disaggregation_mode == DisaggregationMode.PREFILL
|
assert self.disaggregation_mode == DisaggregationMode.PREFILL
|
||||||
assert not is_last or (is_last and aux_index is not None)
|
assert not is_last or (is_last and aux_index is not None)
|
||||||
@@ -909,6 +1020,7 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
index_slice=index_slice,
|
index_slice=index_slice,
|
||||||
is_last=is_last,
|
is_last=is_last,
|
||||||
prefill_aux_index=aux_index,
|
prefill_aux_index=aux_index,
|
||||||
|
state_indices=state_indices,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -989,6 +1101,7 @@ class MooncakeKVSender(CommonKVSender):
|
|||||||
def send(
|
def send(
|
||||||
self,
|
self,
|
||||||
kv_indices: npt.NDArray[np.int32],
|
kv_indices: npt.NDArray[np.int32],
|
||||||
|
state_indices: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
||||||
self.curr_idx += len(kv_indices)
|
self.curr_idx += len(kv_indices)
|
||||||
@@ -1008,6 +1121,7 @@ class MooncakeKVSender(CommonKVSender):
|
|||||||
index_slice,
|
index_slice,
|
||||||
True,
|
True,
|
||||||
aux_index=self.aux_index,
|
aux_index=self.aux_index,
|
||||||
|
state_indices=state_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
def poll(self) -> KVPoll:
|
def poll(self) -> KVPoll:
|
||||||
@@ -1110,6 +1224,9 @@ class MooncakeKVReceiver(CommonKVReceiver):
|
|||||||
packed_aux_data_ptrs = b"".join(
|
packed_aux_data_ptrs = b"".join(
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
||||||
)
|
)
|
||||||
|
packed_state_data_ptrs = b"".join(
|
||||||
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
|
||||||
|
)
|
||||||
# Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
|
# Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
|
||||||
tp_rank = self.kv_mgr.kv_args.engine_rank
|
tp_rank = self.kv_mgr.kv_args.engine_rank
|
||||||
kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
|
kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
|
||||||
@@ -1127,13 +1244,19 @@ class MooncakeKVReceiver(CommonKVReceiver):
|
|||||||
self.session_id.encode("ascii"),
|
self.session_id.encode("ascii"),
|
||||||
packed_kv_data_ptrs,
|
packed_kv_data_ptrs,
|
||||||
packed_aux_data_ptrs,
|
packed_aux_data_ptrs,
|
||||||
|
packed_state_data_ptrs,
|
||||||
dst_tp_rank,
|
dst_tp_rank,
|
||||||
dst_attn_tp_size,
|
dst_attn_tp_size,
|
||||||
dst_kv_item_len,
|
dst_kv_item_len,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
def init(
|
||||||
|
self,
|
||||||
|
kv_indices: npt.NDArray[np.int32],
|
||||||
|
aux_index: Optional[int] = None,
|
||||||
|
state_indices: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
for bootstrap_info in self.bootstrap_infos:
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
|
||||||
is_dummy = bootstrap_info["is_dummy"]
|
is_dummy = bootstrap_info["is_dummy"]
|
||||||
@@ -1147,6 +1270,14 @@ class MooncakeKVReceiver(CommonKVReceiver):
|
|||||||
self.session_id.encode("ascii"),
|
self.session_id.encode("ascii"),
|
||||||
kv_indices.tobytes() if not is_dummy else b"",
|
kv_indices.tobytes() if not is_dummy else b"",
|
||||||
str(aux_index).encode("ascii") if not is_dummy else b"",
|
str(aux_index).encode("ascii") if not is_dummy else b"",
|
||||||
|
(
|
||||||
|
np.array(
|
||||||
|
state_indices,
|
||||||
|
dtype=np.int32,
|
||||||
|
).tobytes()
|
||||||
|
if not is_dummy and state_indices is not None
|
||||||
|
else b""
|
||||||
|
),
|
||||||
str(self.required_dst_info_num).encode("ascii"),
|
str(self.required_dst_info_num).encode("ascii"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -704,6 +704,7 @@ class NixlKVSender(CommonKVSender):
|
|||||||
def send(
|
def send(
|
||||||
self,
|
self,
|
||||||
kv_indices: npt.NDArray[np.int32],
|
kv_indices: npt.NDArray[np.int32],
|
||||||
|
state_indices: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
||||||
self.curr_idx += len(kv_indices)
|
self.curr_idx += len(kv_indices)
|
||||||
@@ -755,7 +756,12 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
self.bootstrap_room
|
self.bootstrap_room
|
||||||
)
|
)
|
||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
def init(
|
||||||
|
self,
|
||||||
|
kv_indices: npt.NDArray[np.int32],
|
||||||
|
aux_index: Optional[int] = None,
|
||||||
|
state_indices: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
for bootstrap_info in self.bootstrap_infos:
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||||
|
|||||||
@@ -49,6 +49,11 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
RequestStage,
|
RequestStage,
|
||||||
ScheduleBatch,
|
ScheduleBatch,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
|
HybridLinearKVPool,
|
||||||
|
NSATokenToKVPool,
|
||||||
|
SWAKVPool,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
DynamicGradMode,
|
DynamicGradMode,
|
||||||
@@ -146,6 +151,28 @@ class PrefillBootstrapQueue:
|
|||||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||||
kv_args.gpu_id = self.scheduler.gpu_id
|
kv_args.gpu_id = self.scheduler.gpu_id
|
||||||
|
|
||||||
|
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
|
||||||
|
state_data_ptrs, state_data_lens, state_item_lens = (
|
||||||
|
self.token_to_kv_pool.get_state_buf_infos()
|
||||||
|
)
|
||||||
|
kv_args.state_data_ptrs = state_data_ptrs
|
||||||
|
kv_args.state_data_lens = state_data_lens
|
||||||
|
kv_args.state_item_lens = state_item_lens
|
||||||
|
|
||||||
|
if isinstance(self.token_to_kv_pool, SWAKVPool):
|
||||||
|
kv_args.state_type = "swa"
|
||||||
|
elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
|
||||||
|
kv_args.state_type = "mamba"
|
||||||
|
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
|
||||||
|
kv_args.state_type = "nsa"
|
||||||
|
else:
|
||||||
|
kv_args.state_type = "none"
|
||||||
|
else:
|
||||||
|
kv_args.state_data_ptrs = []
|
||||||
|
kv_args.state_data_lens = []
|
||||||
|
kv_args.state_item_lens = []
|
||||||
|
kv_args.state_type = "none"
|
||||||
|
|
||||||
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
||||||
self.transfer_backend, KVClassType.MANAGER
|
self.transfer_backend, KVClassType.MANAGER
|
||||||
)
|
)
|
||||||
@@ -618,15 +645,58 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
.numpy()
|
.numpy()
|
||||||
)
|
)
|
||||||
req.start_send_idx = end_idx
|
req.start_send_idx = end_idx
|
||||||
|
state_indices = None
|
||||||
if last_chunk:
|
if last_chunk:
|
||||||
self.disagg_metadata_buffers.set_buf(req)
|
self.disagg_metadata_buffers.set_buf(req)
|
||||||
|
|
||||||
|
# Prepare extra pool indices for hybrid models
|
||||||
|
if isinstance(
|
||||||
|
self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool
|
||||||
|
):
|
||||||
|
# Mamba hybrid model: send single mamba state index
|
||||||
|
state_indices = [
|
||||||
|
self.req_to_token_pool.req_index_to_mamba_index_mapping[
|
||||||
|
req.req_pool_idx
|
||||||
|
]
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
]
|
||||||
|
elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool):
|
||||||
|
# SWA hybrid model: send last window KV indices
|
||||||
|
seq_len = len(req.fill_ids)
|
||||||
|
window_size = self.sliding_window_size
|
||||||
|
window_start = max(0, seq_len - window_size)
|
||||||
|
window_start = (window_start // page_size) * page_size
|
||||||
|
|
||||||
|
window_kv_indices_full = self.req_to_token_pool.req_to_token[
|
||||||
|
req.req_pool_idx, window_start:seq_len
|
||||||
|
]
|
||||||
|
|
||||||
|
# Translate to SWA pool indices
|
||||||
|
window_kv_indices_swa = (
|
||||||
|
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||||
|
window_kv_indices_full
|
||||||
|
)
|
||||||
|
)
|
||||||
|
state_indices = window_kv_indices_swa.cpu().numpy()
|
||||||
|
state_indices = kv_to_page_indices(state_indices, page_size)
|
||||||
|
elif isinstance(
|
||||||
|
self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool
|
||||||
|
):
|
||||||
|
seq_len = len(req.fill_ids)
|
||||||
|
kv_indices_full = self.req_to_token_pool.req_to_token[
|
||||||
|
req.req_pool_idx, :seq_len
|
||||||
|
]
|
||||||
|
state_indices = kv_indices_full.cpu().numpy()
|
||||||
|
state_indices = kv_to_page_indices(state_indices, page_size)
|
||||||
|
|
||||||
page_indices = kv_to_page_indices(kv_indices, page_size)
|
page_indices = kv_to_page_indices(kv_indices, page_size)
|
||||||
if len(page_indices) == 0:
|
if len(page_indices) == 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
|
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
req.disagg_kv_sender.send(page_indices)
|
req.disagg_kv_sender.send(page_indices, state_indices)
|
||||||
|
|
||||||
# PP
|
# PP
|
||||||
@DynamicGradMode()
|
@DynamicGradMode()
|
||||||
|
|||||||
@@ -807,9 +807,6 @@ class Scheduler(
|
|||||||
self.tree_cache.cache_controller.layer_done_counter
|
self.tree_cache.cache_controller.layer_done_counter
|
||||||
)
|
)
|
||||||
elif self.is_hybrid:
|
elif self.is_hybrid:
|
||||||
assert (
|
|
||||||
self.server_args.disaggregation_mode == "null"
|
|
||||||
), "Hybrid mode does not support disaggregation yet"
|
|
||||||
self.tree_cache = SWARadixCache(
|
self.tree_cache = SWARadixCache(
|
||||||
req_to_token_pool=self.req_to_token_pool,
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
@@ -819,9 +816,6 @@ class Scheduler(
|
|||||||
is_eagle=self.spec_algorithm.is_eagle(),
|
is_eagle=self.spec_algorithm.is_eagle(),
|
||||||
)
|
)
|
||||||
elif self.is_hybrid_gdn:
|
elif self.is_hybrid_gdn:
|
||||||
assert (
|
|
||||||
self.server_args.disaggregation_mode == "null"
|
|
||||||
), "Hybrid GDN mode does not support disaggregation yet"
|
|
||||||
self.tree_cache = MambaRadixCache(
|
self.tree_cache = MambaRadixCache(
|
||||||
req_to_token_pool=self.req_to_token_pool,
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
|
|||||||
@@ -142,72 +142,93 @@ class MambaPool:
|
|||||||
ssm_dtype = cache_params.dtype.temporal
|
ssm_dtype = cache_params.dtype.temporal
|
||||||
num_mamba_layers = len(cache_params.layers)
|
num_mamba_layers = len(cache_params.layers)
|
||||||
|
|
||||||
# assume conv_state = (dim, state_len)
|
# for disagg with nvlink
|
||||||
assert conv_state_shape[0] > conv_state_shape[1]
|
self.enable_custom_mem_pool = get_bool_env_var(
|
||||||
conv_state = torch.zeros(
|
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||||
size=(num_mamba_layers, size + 1) + conv_state_shape,
|
|
||||||
dtype=conv_dtype,
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
temporal_state = torch.zeros(
|
if self.enable_custom_mem_pool:
|
||||||
size=(num_mamba_layers, size + 1) + temporal_state_shape,
|
# TODO(shangming): abstract custom allocator class for more backends
|
||||||
dtype=ssm_dtype,
|
from mooncake.allocator import NVLinkAllocator
|
||||||
device=device,
|
|
||||||
)
|
allocator = NVLinkAllocator.get_allocator(self.device)
|
||||||
if speculative_num_draft_tokens is not None:
|
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||||
# Cache intermediate SSM states per draft token during target verify
|
|
||||||
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
|
|
||||||
intermediate_ssm_state_cache = torch.zeros(
|
|
||||||
size=(
|
|
||||||
num_mamba_layers,
|
|
||||||
size + 1,
|
|
||||||
speculative_num_draft_tokens,
|
|
||||||
temporal_state_shape[0],
|
|
||||||
temporal_state_shape[1],
|
|
||||||
temporal_state_shape[2],
|
|
||||||
),
|
|
||||||
dtype=ssm_dtype,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
|
|
||||||
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
|
|
||||||
intermediate_conv_window_cache = torch.zeros(
|
|
||||||
size=(
|
|
||||||
num_mamba_layers,
|
|
||||||
size + 1,
|
|
||||||
speculative_num_draft_tokens,
|
|
||||||
conv_state_shape[0],
|
|
||||||
conv_state_shape[1],
|
|
||||||
),
|
|
||||||
dtype=conv_dtype,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
self.mamba_cache = self.SpeculativeState(
|
|
||||||
conv=conv_state,
|
|
||||||
temporal=temporal_state,
|
|
||||||
intermediate_ssm=intermediate_ssm_state_cache,
|
|
||||||
intermediate_conv_window=intermediate_conv_window_cache,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Mamba Cache is allocated. "
|
|
||||||
f"max_mamba_cache_size: {size}, "
|
|
||||||
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
|
||||||
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
|
||||||
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
|
|
||||||
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
|
self.custom_mem_pool = None
|
||||||
logger.info(
|
|
||||||
f"Mamba Cache is allocated. "
|
with (
|
||||||
f"max_mamba_cache_size: {size}, "
|
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||||
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
if self.enable_custom_mem_pool
|
||||||
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
else nullcontext()
|
||||||
|
):
|
||||||
|
# assume conv_state = (dim, state_len)
|
||||||
|
assert conv_state_shape[0] > conv_state_shape[1]
|
||||||
|
conv_state = torch.zeros(
|
||||||
|
size=(num_mamba_layers, size + 1) + conv_state_shape,
|
||||||
|
dtype=conv_dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
self.size = size
|
temporal_state = torch.zeros(
|
||||||
self.device = device
|
size=(num_mamba_layers, size + 1) + temporal_state_shape,
|
||||||
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
|
dtype=ssm_dtype,
|
||||||
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
|
device=device,
|
||||||
|
)
|
||||||
|
if speculative_num_draft_tokens is not None:
|
||||||
|
# Cache intermediate SSM states per draft token during target verify
|
||||||
|
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
|
||||||
|
intermediate_ssm_state_cache = torch.zeros(
|
||||||
|
size=(
|
||||||
|
num_mamba_layers,
|
||||||
|
size + 1,
|
||||||
|
speculative_num_draft_tokens,
|
||||||
|
temporal_state_shape[0],
|
||||||
|
temporal_state_shape[1],
|
||||||
|
temporal_state_shape[2],
|
||||||
|
),
|
||||||
|
dtype=ssm_dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
|
||||||
|
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
|
||||||
|
intermediate_conv_window_cache = torch.zeros(
|
||||||
|
size=(
|
||||||
|
num_mamba_layers,
|
||||||
|
size + 1,
|
||||||
|
speculative_num_draft_tokens,
|
||||||
|
conv_state_shape[0],
|
||||||
|
conv_state_shape[1],
|
||||||
|
),
|
||||||
|
dtype=conv_dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
self.mamba_cache = self.SpeculativeState(
|
||||||
|
conv=conv_state,
|
||||||
|
temporal=temporal_state,
|
||||||
|
intermediate_ssm=intermediate_ssm_state_cache,
|
||||||
|
intermediate_conv_window=intermediate_conv_window_cache,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Mamba Cache is allocated. "
|
||||||
|
f"max_mamba_cache_size: {size}, "
|
||||||
|
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||||
|
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
||||||
|
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
|
||||||
|
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
|
||||||
|
logger.info(
|
||||||
|
f"Mamba Cache is allocated. "
|
||||||
|
f"max_mamba_cache_size: {size}, "
|
||||||
|
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||||
|
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
||||||
|
)
|
||||||
|
self.size = size
|
||||||
|
self.device = device
|
||||||
|
self.free_slots = torch.arange(
|
||||||
|
self.size, dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
|
||||||
|
self.num_mamba_layers = num_mamba_layers
|
||||||
|
|
||||||
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
|
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
|
||||||
assert isinstance(self.mamba_cache, self.SpeculativeState)
|
assert isinstance(self.mamba_cache, self.SpeculativeState)
|
||||||
@@ -253,6 +274,22 @@ class MambaPool:
|
|||||||
self.copy_from(src_index, dst_index)
|
self.copy_from(src_index, dst_index)
|
||||||
return dst_index
|
return dst_index
|
||||||
|
|
||||||
|
def get_contiguous_buf_infos(self):
|
||||||
|
state_tensors = [
|
||||||
|
getattr(self.mamba_cache, field) for field in vars(self.mamba_cache)
|
||||||
|
]
|
||||||
|
data_ptrs, data_lens, item_lens = [], [], []
|
||||||
|
|
||||||
|
for _, state_tensor in enumerate(state_tensors):
|
||||||
|
data_ptrs += [
|
||||||
|
state_tensor[i].data_ptr() for i in range(self.num_mamba_layers)
|
||||||
|
]
|
||||||
|
data_lens += [state_tensor[i].nbytes for i in range(self.num_mamba_layers)]
|
||||||
|
item_lens += [
|
||||||
|
state_tensor[i][0].nbytes for i in range(self.num_mamba_layers)
|
||||||
|
]
|
||||||
|
return data_ptrs, data_lens, item_lens
|
||||||
|
|
||||||
|
|
||||||
class HybridReqToTokenPool(ReqToTokenPool):
|
class HybridReqToTokenPool(ReqToTokenPool):
|
||||||
"""A memory pool that maps a request to its token locations."""
|
"""A memory pool that maps a request to its token locations."""
|
||||||
@@ -274,13 +311,26 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
|||||||
device=device,
|
device=device,
|
||||||
enable_memory_saver=enable_memory_saver,
|
enable_memory_saver=enable_memory_saver,
|
||||||
)
|
)
|
||||||
|
self._init_mamba_pool(
|
||||||
self.mamba_pool = MambaPool(
|
|
||||||
size=mamba_size,
|
size=mamba_size,
|
||||||
cache_params=cache_params,
|
cache_params=cache_params,
|
||||||
device=device,
|
device=device,
|
||||||
speculative_num_draft_tokens=speculative_num_draft_tokens,
|
speculative_num_draft_tokens=speculative_num_draft_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _init_mamba_pool(
|
||||||
|
self,
|
||||||
|
size: int,
|
||||||
|
cache_params: "Mamba2CacheParams",
|
||||||
|
device: str,
|
||||||
|
speculative_num_draft_tokens: int = None,
|
||||||
|
):
|
||||||
|
self.mamba_pool = MambaPool(
|
||||||
|
size=size,
|
||||||
|
cache_params=cache_params,
|
||||||
|
device=device,
|
||||||
|
speculative_num_draft_tokens=speculative_num_draft_tokens,
|
||||||
|
)
|
||||||
self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
|
self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
@@ -375,6 +425,19 @@ class KVCache(abc.ABC):
|
|||||||
# default state for optional layer-wise transfer control
|
# default state for optional layer-wise transfer control
|
||||||
self.layer_transfer_counter = None
|
self.layer_transfer_counter = None
|
||||||
|
|
||||||
|
# for disagg with nvlink
|
||||||
|
self.enable_custom_mem_pool = get_bool_env_var(
|
||||||
|
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||||
|
)
|
||||||
|
if self.enable_custom_mem_pool:
|
||||||
|
# TODO(shangming): abstract custom allocator class for more backends
|
||||||
|
from mooncake.allocator import NVLinkAllocator
|
||||||
|
|
||||||
|
allocator = NVLinkAllocator.get_allocator(self.device)
|
||||||
|
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||||
|
else:
|
||||||
|
self.custom_mem_pool = None
|
||||||
|
|
||||||
def _finalize_allocation_log(self, num_tokens: int):
|
def _finalize_allocation_log(self, num_tokens: int):
|
||||||
"""Common logging and mem_usage computation for KV cache allocation.
|
"""Common logging and mem_usage computation for KV cache allocation.
|
||||||
Supports both tuple (K, V) size returns and single KV size returns.
|
Supports both tuple (K, V) size returns and single KV size returns.
|
||||||
@@ -426,6 +489,9 @@ class KVCache(abc.ABC):
|
|||||||
def load_cpu_copy(self, kv_cache_cpu, indices):
|
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def maybe_get_custom_mem_pool(self):
|
||||||
|
return self.custom_mem_pool
|
||||||
|
|
||||||
|
|
||||||
class MHATokenToKVPool(KVCache):
|
class MHATokenToKVPool(KVCache):
|
||||||
|
|
||||||
@@ -456,19 +522,6 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.head_num = head_num
|
self.head_num = head_num
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
|
|
||||||
# for disagg with nvlink
|
|
||||||
self.enable_custom_mem_pool = get_bool_env_var(
|
|
||||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
|
||||||
)
|
|
||||||
if self.enable_custom_mem_pool:
|
|
||||||
# TODO(shangming): abstract custom allocator class for more backends
|
|
||||||
from mooncake.allocator import NVLinkAllocator
|
|
||||||
|
|
||||||
allocator = NVLinkAllocator.get_allocator(self.device)
|
|
||||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
|
||||||
else:
|
|
||||||
self.custom_mem_pool = None
|
|
||||||
|
|
||||||
self._create_buffers()
|
self._create_buffers()
|
||||||
|
|
||||||
self.device_module = torch.get_device_module(self.device)
|
self.device_module = torch.get_device_module(self.device)
|
||||||
@@ -611,9 +664,6 @@ class MHATokenToKVPool(KVCache):
|
|||||||
]
|
]
|
||||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
|
|
||||||
def maybe_get_custom_mem_pool(self):
|
|
||||||
return self.custom_mem_pool
|
|
||||||
|
|
||||||
def get_cpu_copy(self, indices):
|
def get_cpu_copy(self, indices):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
kv_cache_cpu = []
|
kv_cache_cpu = []
|
||||||
@@ -756,12 +806,18 @@ class HybridLinearKVPool(KVCache):
|
|||||||
full_attention_layer_ids: List[int],
|
full_attention_layer_ids: List[int],
|
||||||
enable_kvcache_transpose: bool,
|
enable_kvcache_transpose: bool,
|
||||||
device: str,
|
device: str,
|
||||||
|
mamba_pool: MambaPool,
|
||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
self.full_layer_nums = len(full_attention_layer_ids)
|
self.full_layer_nums = len(full_attention_layer_ids)
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
|
# TODO support pp?
|
||||||
|
self.start_layer = 0
|
||||||
|
self.head_num = head_num
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.mamba_pool = mamba_pool
|
||||||
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
||||||
assert not enable_kvcache_transpose
|
assert not enable_kvcache_transpose
|
||||||
if _is_npu:
|
if _is_npu:
|
||||||
@@ -790,6 +846,15 @@ class HybridLinearKVPool(KVCache):
|
|||||||
def get_contiguous_buf_infos(self):
|
def get_contiguous_buf_infos(self):
|
||||||
return self.full_kv_pool.get_contiguous_buf_infos()
|
return self.full_kv_pool.get_contiguous_buf_infos()
|
||||||
|
|
||||||
|
def get_state_buf_infos(self):
|
||||||
|
mamba_data_ptrs, mamba_data_lens, mamba_item_lens = (
|
||||||
|
self.mamba_pool.get_contiguous_buf_infos()
|
||||||
|
)
|
||||||
|
return mamba_data_ptrs, mamba_data_lens, mamba_item_lens
|
||||||
|
|
||||||
|
def maybe_get_custom_mem_pool(self):
|
||||||
|
return self.full_kv_pool.maybe_get_custom_mem_pool()
|
||||||
|
|
||||||
def _transfer_full_attention_id(self, layer_id: int):
|
def _transfer_full_attention_id(self, layer_id: int):
|
||||||
if layer_id not in self.full_attention_layer_id_mapping:
|
if layer_id not in self.full_attention_layer_id_mapping:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -841,22 +906,47 @@ class SWAKVPool(KVCache):
|
|||||||
size: int,
|
size: int,
|
||||||
size_swa: int,
|
size_swa: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
head_num: int,
|
||||||
|
head_dim: int,
|
||||||
swa_attention_layer_ids: List[int],
|
swa_attention_layer_ids: List[int],
|
||||||
full_attention_layer_ids: List[int],
|
full_attention_layer_ids: List[int],
|
||||||
enable_kvcache_transpose: bool,
|
enable_kvcache_transpose: bool,
|
||||||
|
device: str,
|
||||||
token_to_kv_pool_class: KVCache = MHATokenToKVPool,
|
token_to_kv_pool_class: KVCache = MHATokenToKVPool,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.size_swa = size_swa
|
self.size_swa = size_swa
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
self.head_num = head_num
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.device = device
|
||||||
self.swa_layer_nums = len(swa_attention_layer_ids)
|
self.swa_layer_nums = len(swa_attention_layer_ids)
|
||||||
self.full_layer_nums = len(full_attention_layer_ids)
|
self.full_layer_nums = len(full_attention_layer_ids)
|
||||||
|
self.start_layer = 0
|
||||||
|
self.page_size = 1
|
||||||
|
|
||||||
kwargs["page_size"] = 1
|
kwargs["page_size"] = 1
|
||||||
kwargs["enable_memory_saver"] = False
|
kwargs["enable_memory_saver"] = False
|
||||||
|
kwargs["head_num"] = head_num
|
||||||
|
kwargs["head_dim"] = head_dim
|
||||||
|
kwargs["device"] = device
|
||||||
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
||||||
assert not enable_kvcache_transpose
|
assert not enable_kvcache_transpose
|
||||||
|
|
||||||
|
# for disagg with nvlink
|
||||||
|
self.enable_custom_mem_pool = get_bool_env_var(
|
||||||
|
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
||||||
|
)
|
||||||
|
if self.enable_custom_mem_pool:
|
||||||
|
# TODO(shangming): abstract custom allocator class for more backends
|
||||||
|
from mooncake.allocator import NVLinkAllocator
|
||||||
|
|
||||||
|
allocator = NVLinkAllocator.get_allocator(self.device)
|
||||||
|
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
||||||
|
else:
|
||||||
|
self.custom_mem_pool = None
|
||||||
|
|
||||||
self.swa_kv_pool = token_to_kv_pool_class(
|
self.swa_kv_pool = token_to_kv_pool_class(
|
||||||
size=size_swa,
|
size=size_swa,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
@@ -878,6 +968,9 @@ class SWAKVPool(KVCache):
|
|||||||
|
|
||||||
k_size, v_size = self.get_kv_size_bytes()
|
k_size, v_size = self.get_kv_size_bytes()
|
||||||
self.mem_usage = (k_size + v_size) / GB
|
self.mem_usage = (k_size + v_size) / GB
|
||||||
|
logger.info(
|
||||||
|
f"SWAKVPool mem usage: {self.mem_usage} GB, swa size: {self.size_swa}, full size: {self.size}"
|
||||||
|
)
|
||||||
|
|
||||||
def get_kv_size_bytes(self):
|
def get_kv_size_bytes(self):
|
||||||
k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
|
k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
|
||||||
@@ -888,15 +981,19 @@ class SWAKVPool(KVCache):
|
|||||||
full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
|
full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
|
||||||
self.full_kv_pool.get_contiguous_buf_infos()
|
self.full_kv_pool.get_contiguous_buf_infos()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
kv_data_ptrs = full_kv_data_ptrs
|
||||||
|
kv_data_lens = full_kv_data_lens
|
||||||
|
kv_item_lens = full_kv_item_lens
|
||||||
|
|
||||||
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
|
|
||||||
|
def get_state_buf_infos(self):
|
||||||
swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
|
swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
|
||||||
self.swa_kv_pool.get_contiguous_buf_infos()
|
self.swa_kv_pool.get_contiguous_buf_infos()
|
||||||
)
|
)
|
||||||
|
|
||||||
kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
|
return swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens
|
||||||
kv_data_lens = full_kv_data_lens + swa_kv_data_lens
|
|
||||||
kv_item_lens = full_kv_item_lens + swa_kv_item_lens
|
|
||||||
|
|
||||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
layer_id_pool, is_swa = self.layers_mapping[layer_id]
|
layer_id_pool, is_swa = self.layers_mapping[layer_id]
|
||||||
@@ -1152,19 +1249,6 @@ class MLATokenToKVPool(KVCache):
|
|||||||
else (kv_lora_rank + qk_rope_head_dim)
|
else (kv_lora_rank + qk_rope_head_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
# for disagg with nvlink
|
|
||||||
self.enable_custom_mem_pool = get_bool_env_var(
|
|
||||||
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
|
||||||
)
|
|
||||||
if self.enable_custom_mem_pool:
|
|
||||||
# TODO(shangming): abstract custom allocator class for more backends
|
|
||||||
from mooncake.allocator import NVLinkAllocator
|
|
||||||
|
|
||||||
allocator = NVLinkAllocator.get_allocator(self.device)
|
|
||||||
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
|
|
||||||
else:
|
|
||||||
self.custom_mem_pool = None
|
|
||||||
|
|
||||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||||
with (
|
with (
|
||||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||||
@@ -1207,9 +1291,6 @@ class MLATokenToKVPool(KVCache):
|
|||||||
]
|
]
|
||||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
|
|
||||||
def maybe_get_custom_mem_pool(self):
|
|
||||||
return self.custom_mem_pool
|
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
if self.layer_transfer_counter is not None:
|
if self.layer_transfer_counter is not None:
|
||||||
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||||
@@ -1346,24 +1427,31 @@ class NSATokenToKVPool(MLATokenToKVPool):
|
|||||||
assert index_head_dim == 128
|
assert index_head_dim == 128
|
||||||
|
|
||||||
assert self.page_size == 64
|
assert self.page_size == 64
|
||||||
self.index_k_with_scale_buffer = [
|
with (
|
||||||
torch.zeros(
|
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||||
# Layout:
|
if self.custom_mem_pool
|
||||||
# ref: test_attention.py :: kv_cache_cast_to_fp8
|
else nullcontext()
|
||||||
# shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
|
):
|
||||||
# data: for page i,
|
self.index_k_with_scale_buffer = [
|
||||||
# * buf[i, :page_size * head_dim] for fp8 data
|
torch.zeros(
|
||||||
# * buf[i, page_size * head_dim:].view(float32) for scale
|
# Layout:
|
||||||
(
|
# ref: test_attention.py :: kv_cache_cast_to_fp8
|
||||||
(size + page_size + 1) // self.page_size,
|
# shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
|
||||||
self.page_size
|
# data: for page i,
|
||||||
* (index_head_dim + index_head_dim // self.quant_block_size * 4),
|
# * buf[i, :page_size * head_dim] for fp8 data
|
||||||
),
|
# * buf[i, page_size * head_dim:].view(float32) for scale
|
||||||
dtype=self.index_k_with_scale_buffer_dtype,
|
(
|
||||||
device=device,
|
(size + page_size + 1) // self.page_size,
|
||||||
)
|
self.page_size
|
||||||
for _ in range(layer_num)
|
* (
|
||||||
]
|
index_head_dim + index_head_dim // self.quant_block_size * 4
|
||||||
|
),
|
||||||
|
),
|
||||||
|
dtype=self.index_k_with_scale_buffer_dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
for _ in range(layer_num)
|
||||||
|
]
|
||||||
self._finalize_allocation_log(size)
|
self._finalize_allocation_log(size)
|
||||||
|
|
||||||
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
|
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
@@ -1406,6 +1494,18 @@ class NSATokenToKVPool(MLATokenToKVPool):
|
|||||||
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
|
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_state_buf_infos(self):
|
||||||
|
data_ptrs = [
|
||||||
|
self.index_k_with_scale_buffer[i].data_ptr() for i in range(self.layer_num)
|
||||||
|
]
|
||||||
|
data_lens = [
|
||||||
|
self.index_k_with_scale_buffer[i].nbytes for i in range(self.layer_num)
|
||||||
|
]
|
||||||
|
item_lens = [
|
||||||
|
self.index_k_with_scale_buffer[i][0].nbytes for i in range(self.layer_num)
|
||||||
|
]
|
||||||
|
return data_ptrs, data_lens, item_lens
|
||||||
|
|
||||||
def get_kv_size_bytes(self):
|
def get_kv_size_bytes(self):
|
||||||
kv_size_bytes = super().get_kv_size_bytes()
|
kv_size_bytes = super().get_kv_size_bytes()
|
||||||
for index_k_cache in self.index_k_with_scale_buffer:
|
for index_k_cache in self.index_k_with_scale_buffer:
|
||||||
@@ -1636,27 +1736,38 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||||
# [size, head_num, head_dim] for each layer
|
with (
|
||||||
self.k_buffer = [
|
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||||
torch.zeros(
|
if self.enable_custom_mem_pool
|
||||||
(size + page_size, head_num, head_dim), dtype=dtype, device=device
|
else nullcontext()
|
||||||
)
|
):
|
||||||
for _ in range(layer_num)
|
# [size, head_num, head_dim] for each layer
|
||||||
]
|
self.k_buffer = [
|
||||||
self.v_buffer = [
|
torch.zeros(
|
||||||
torch.zeros(
|
(size + page_size, head_num, head_dim),
|
||||||
(size + page_size, head_num, head_dim), dtype=dtype, device=device
|
dtype=dtype,
|
||||||
)
|
device=device,
|
||||||
for _ in range(layer_num)
|
)
|
||||||
]
|
for _ in range(layer_num)
|
||||||
|
]
|
||||||
|
self.v_buffer = [
|
||||||
|
torch.zeros(
|
||||||
|
(size + page_size, head_num, head_dim),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
for _ in range(layer_num)
|
||||||
|
]
|
||||||
|
|
||||||
# [size, head_num, heavy_channel_num] for each layer
|
# [size, head_num, heavy_channel_num] for each layer
|
||||||
self.label_buffer = [
|
self.label_buffer = [
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
|
(size + 1, head_num, heavy_channel_num),
|
||||||
)
|
dtype=dtype,
|
||||||
for _ in range(layer_num)
|
device=device,
|
||||||
]
|
)
|
||||||
|
for _ in range(layer_num)
|
||||||
|
]
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
return self.k_buffer[layer_id - self.start_layer]
|
return self.k_buffer[layer_id - self.start_layer]
|
||||||
|
|||||||
@@ -1669,19 +1669,34 @@ class ModelRunner:
|
|||||||
extra_max_context_len += self.server_args.speculative_num_draft_tokens
|
extra_max_context_len += self.server_args.speculative_num_draft_tokens
|
||||||
|
|
||||||
if self.server_args.disaggregation_mode == "decode":
|
if self.server_args.disaggregation_mode == "decode":
|
||||||
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
from sglang.srt.disaggregation.decode import (
|
||||||
|
DecodeReqToTokenPool,
|
||||||
|
HybridMambaDecodeReqToTokenPool,
|
||||||
|
)
|
||||||
|
|
||||||
# subscribe memory for pre-allocated requests
|
# subscribe memory for pre-allocated requests
|
||||||
# if max_num_reqs <= 32, we pre-allocate 2x requests
|
# if max_num_reqs <= 32, we pre-allocate 2x requests
|
||||||
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
|
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
|
||||||
self.req_to_token_pool = DecodeReqToTokenPool(
|
if config := self.mambaish_config:
|
||||||
size=max_num_reqs,
|
self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
|
||||||
max_context_len=self.model_config.context_len
|
size=max_num_reqs,
|
||||||
+ extra_max_context_len,
|
max_context_len=self.model_config.context_len
|
||||||
device=self.device,
|
+ extra_max_context_len,
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
device=self.device,
|
||||||
pre_alloc_size=pre_alloc_size,
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
)
|
cache_params=config.mamba2_cache_params,
|
||||||
|
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
||||||
|
pre_alloc_size=pre_alloc_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.req_to_token_pool = DecodeReqToTokenPool(
|
||||||
|
size=max_num_reqs,
|
||||||
|
max_context_len=self.model_config.context_len
|
||||||
|
+ extra_max_context_len,
|
||||||
|
device=self.device,
|
||||||
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
|
pre_alloc_size=pre_alloc_size,
|
||||||
|
)
|
||||||
elif config := self.mambaish_config:
|
elif config := self.mambaish_config:
|
||||||
self.req_to_token_pool = HybridReqToTokenPool(
|
self.req_to_token_pool = HybridReqToTokenPool(
|
||||||
size=max_num_reqs,
|
size=max_num_reqs,
|
||||||
@@ -1807,6 +1822,7 @@ class ModelRunner:
|
|||||||
),
|
),
|
||||||
enable_kvcache_transpose=False,
|
enable_kvcache_transpose=False,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
mamba_pool=self.req_to_token_pool.mamba_pool,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.token_to_kv_pool = MHATokenToKVPool(
|
self.token_to_kv_pool = MHATokenToKVPool(
|
||||||
|
|||||||
@@ -163,6 +163,7 @@ suites = {
|
|||||||
TestFile("test_deepseek_v3_basic.py", 275),
|
TestFile("test_deepseek_v3_basic.py", 275),
|
||||||
TestFile("test_deepseek_v3_mtp.py", 275),
|
TestFile("test_deepseek_v3_mtp.py", 275),
|
||||||
TestFile("test_disaggregation_different_tp.py", 600),
|
TestFile("test_disaggregation_different_tp.py", 600),
|
||||||
|
TestFile("test_disaggregation_hybrid_attention.py", 200),
|
||||||
TestFile("test_disaggregation_pp.py", 140),
|
TestFile("test_disaggregation_pp.py", 140),
|
||||||
],
|
],
|
||||||
"per-commit-4-gpu-b200": [
|
"per-commit-4-gpu-b200": [
|
||||||
|
|||||||
83
test/srt/test_disaggregation_hybrid_attention.py
Normal file
83
test/srt/test_disaggregation_hybrid_attention.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from sglang.srt.environ import envs
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
|
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
popen_launch_pd_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisaggregationHybridAttentionMamba(TestDisaggregationBase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
super().setUpClass()
|
||||||
|
cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct"
|
||||||
|
|
||||||
|
# Non blocking start servers
|
||||||
|
cls.start_prefill()
|
||||||
|
cls.start_decode()
|
||||||
|
|
||||||
|
# Block until both
|
||||||
|
cls.wait_server_ready(cls.prefill_url + "/health")
|
||||||
|
cls.wait_server_ready(cls.decode_url + "/health")
|
||||||
|
|
||||||
|
cls.launch_lb()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_prefill(cls):
|
||||||
|
prefill_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"prefill",
|
||||||
|
"--tp",
|
||||||
|
"4",
|
||||||
|
]
|
||||||
|
prefill_args += cls.transfer_backend + cls.rdma_devices
|
||||||
|
cls.process_prefill = popen_launch_pd_server(
|
||||||
|
cls.model,
|
||||||
|
cls.prefill_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=prefill_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_decode(cls):
|
||||||
|
decode_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"decode",
|
||||||
|
"--tp",
|
||||||
|
"4",
|
||||||
|
"--base-gpu-id",
|
||||||
|
"4",
|
||||||
|
]
|
||||||
|
decode_args += cls.transfer_backend + cls.rdma_devices
|
||||||
|
cls.process_decode = popen_launch_pd_server(
|
||||||
|
cls.model,
|
||||||
|
cls.decode_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=decode_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host=f"http://{self.base_host}",
|
||||||
|
port=int(self.lb_port),
|
||||||
|
)
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
print(f"Evaluation metrics: {metrics}")
|
||||||
|
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.93)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase):
|
|||||||
full_attention_layer_ids=full_attention_layer_ids,
|
full_attention_layer_ids=full_attention_layer_ids,
|
||||||
enable_kvcache_transpose=False,
|
enable_kvcache_transpose=False,
|
||||||
device=device,
|
device=device,
|
||||||
|
mamba_pool=None,
|
||||||
)
|
)
|
||||||
assert pool._transfer_full_attention_id(global_interval - 1) == 0
|
assert pool._transfer_full_attention_id(global_interval - 1) == 0
|
||||||
assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1
|
assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1
|
||||||
@@ -173,6 +174,7 @@ class TestMamba(unittest.TestCase):
|
|||||||
full_attention_layer_ids=full_attention_layer_ids,
|
full_attention_layer_ids=full_attention_layer_ids,
|
||||||
enable_kvcache_transpose=False,
|
enable_kvcache_transpose=False,
|
||||||
device=device,
|
device=device,
|
||||||
|
mamba_pool=req_to_token_pool.mamba_pool,
|
||||||
)
|
)
|
||||||
|
|
||||||
# setup token to kv pool allocator
|
# setup token to kv pool allocator
|
||||||
|
|||||||
Reference in New Issue
Block a user