diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index 3f5877ea3..341d33bcf 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -20,6 +20,10 @@ class KVArgs: aux_data_ptrs: List[int] aux_data_lens: List[int] aux_item_lens: List[int] + state_data_ptrs: List[int] + state_data_lens: List[int] + state_item_lens: List[int] + state_type: str # "none", "mamba", "swa" ib_device: str ib_traffic_class: str gpu_id: int @@ -76,9 +80,13 @@ class BaseKVSender(ABC): ... @abstractmethod - def send(self, kv_indices: npt.NDArray[np.int32]): + def send( + self, + kv_indices: npt.NDArray[np.int32], + state_indices: Optional[List[int]] = None, + ): """ - Send the kv cache at the given kv indices to the decoder server + Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server """ ... @@ -108,9 +116,14 @@ class BaseKVReceiver(ABC): ): ... @abstractmethod - def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): + def init( + self, + kv_indices: npt.NDArray[np.int32], + aux_index: Optional[int] = None, + state_indices: Optional[List[int]] = None, + ): """ - Notify the prefill server about the kv indices and aux index + Notify the prefill server about the kv indices, aux index, and state_indices. """ ... diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 82876066f..e34778a38 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -201,6 +201,7 @@ class CommonKVSender(BaseKVSender): def send( self, kv_indices: npt.NDArray[np.int32], + state_indices: Optional[List[int]] = None, ): pass diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 8f0e1d6b5..45589ec51 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -25,11 +25,12 @@ import time from collections import deque from dataclasses import dataclass from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union import torch from torch.distributed import ProcessGroup +from sglang.srt.configs.mamba_utils import Mamba2CacheParams from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll from sglang.srt.disaggregation.utils import ( @@ -47,9 +48,19 @@ from sglang.srt.disaggregation.utils import ( ) from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch -from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator +from sglang.srt.mem_cache.allocator import ( + BaseTokenToKVPoolAllocator, + SWATokenToKVPoolAllocator, +) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache -from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool +from sglang.srt.mem_cache.memory_pool import ( + HybridLinearKVPool, + HybridReqToTokenPool, + KVCache, + NSATokenToKVPool, + ReqToTokenPool, + SWAKVPool, +) from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.utils import get_int_env_var, require_mlp_sync from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter @@ -124,6 +135,35 @@ class DecodeReqToTokenPool: self.free_slots = list(range(self.size + self.pre_alloc_size)) +class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool): + + def __init__( + self, + size: int, + max_context_len: int, + device: str, + enable_memory_saver: bool, + cache_params: "Mamba2CacheParams", + speculative_num_draft_tokens: int, + pre_alloc_size: int, + ): + DecodeReqToTokenPool.__init__( + self, + size=size, + max_context_len=max_context_len, + device=device, + enable_memory_saver=enable_memory_saver, + pre_alloc_size=pre_alloc_size, + ) + self._init_mamba_pool( + size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens + ) + + def clear(self): + self.free_slots = list(range(self.size + self.pre_alloc_size)) + self.mamba_pool.clear() + + @dataclass class DecodeRequest: req: Req @@ -217,6 +257,28 @@ class DecodePreallocQueue: self.metadata_buffers.get_buf_infos() ) + if hasattr(self.token_to_kv_pool, "get_state_buf_infos"): + state_data_ptrs, state_data_lens, state_item_lens = ( + self.token_to_kv_pool.get_state_buf_infos() + ) + kv_args.state_data_ptrs = state_data_ptrs + kv_args.state_data_lens = state_data_lens + kv_args.state_item_lens = state_item_lens + + if isinstance(self.token_to_kv_pool, SWAKVPool): + kv_args.state_type = "swa" + elif isinstance(self.token_to_kv_pool, HybridLinearKVPool): + kv_args.state_type = "mamba" + elif isinstance(self.token_to_kv_pool, NSATokenToKVPool): + kv_args.state_type = "nsa" + else: + kv_args.state_type = "none" + else: + kv_args.state_data_ptrs = [] + kv_args.state_data_lens = [] + kv_args.state_item_lens = [] + kv_args.state_type = "none" + kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id kv_manager_class: Type[BaseKVManager] = get_kv_class( @@ -414,16 +476,56 @@ class DecodePreallocQueue: .cpu() .numpy() ) + page_size = self.token_to_kv_pool_allocator.page_size + + # Prepare extra pool indices for hybrid models + if isinstance(self.token_to_kv_pool, HybridLinearKVPool): + # Mamba hybrid model: single mamba state index + state_indices = [ + self.req_to_token_pool.req_index_to_mamba_index_mapping[ + decode_req.req.req_pool_idx + ] + .cpu() + .numpy() + ] + elif isinstance(self.token_to_kv_pool, SWAKVPool): + # SWA hybrid model: send decode-side SWA window indices + seq_len = len(decode_req.req.origin_input_ids) + window_size = self.scheduler.sliding_window_size + + window_start = max(0, seq_len - window_size) + window_start = (window_start // page_size) * page_size + window_kv_indices_full = self.req_to_token_pool.req_to_token[ + decode_req.req.req_pool_idx, window_start:seq_len + ] + + # Translate to SWA pool indices + window_kv_indices_swa = ( + self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa( + window_kv_indices_full + ) + ) + state_indices = window_kv_indices_swa.cpu().numpy() + state_indices = kv_to_page_indices(state_indices, page_size) + elif isinstance(self.token_to_kv_pool, NSATokenToKVPool): + seq_len = len(decode_req.req.origin_input_ids) + kv_indices_full = self.req_to_token_pool.req_to_token[ + decode_req.req.req_pool_idx, :seq_len + ] + state_indices = kv_indices_full.cpu().numpy() + state_indices = kv_to_page_indices(state_indices, page_size) + else: + state_indices = None decode_req.metadata_buffer_index = ( self.req_to_metadata_buffer_idx_allocator.alloc() ) assert decode_req.metadata_buffer_index is not None - page_indices = kv_to_page_indices( - kv_indices, self.token_to_kv_pool_allocator.page_size + page_indices = kv_to_page_indices(kv_indices, page_size) + decode_req.kv_receiver.init( + page_indices, decode_req.metadata_buffer_index, state_indices ) - decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index) - + decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP) preallocated_reqs.append(decode_req) indices_to_remove.add(i) decode_req.req.time_stats.decode_transfer_queue_entry_time = ( @@ -503,7 +605,10 @@ class DecodePreallocQueue: def _pre_alloc(self, req: Req) -> torch.Tensor: """Pre-allocate the memory for req_to_token and token_kv_pool""" - req_pool_indices = self.req_to_token_pool.alloc(1) + if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool): + req_pool_indices = self.req_to_token_pool.alloc(1, [req]) + else: + req_pool_indices = self.req_to_token_pool.alloc(1) assert ( req_pool_indices is not None diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 120633824..e759465e4 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender): def send( self, kv_indices: npt.NDArray[np.int32], + state_indices: Optional[List[int]] = None, ): self.has_sent = True - logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}") + logger.debug( + f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}" + ) def failure_exception(self): raise Exception("Fake KVSender Exception") @@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver): logger.debug("FakeKVReceiver poll success") return KVPoll.Success - def init(self, kv_indices: list[int], aux_index: Optional[int] = None): + def init( + self, + kv_indices: list[int], + aux_index: Optional[int] = None, + state_indices: Optional[List[int]] = None, + ): self.has_init = True logger.debug( - f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}" + f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}" ) def failure_exception(self): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index b6f12e46e..8013f0f0b 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -58,6 +58,7 @@ class TransferKVChunk: index_slice: slice is_last: bool prefill_aux_index: Optional[int] + state_indices: Optional[List[int]] # decode @@ -69,6 +70,7 @@ class TransferInfo: mooncake_session_id: str dst_kv_indices: npt.NDArray[np.int32] dst_aux_index: int + dst_state_indices: List[int] required_dst_info_num: int is_dummy: bool @@ -78,9 +80,14 @@ class TransferInfo: is_dummy = True dst_kv_indices = np.array([], dtype=np.int32) dst_aux_index = None + dst_state_indices = [] else: dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32) dst_aux_index = int(msg[5].decode("ascii")) + if msg[6] == b"": + dst_state_indices = [] + else: + dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32)) is_dummy = False return cls( room=int(msg[0].decode("ascii")), @@ -89,7 +96,8 @@ class TransferInfo: mooncake_session_id=msg[3].decode("ascii"), dst_kv_indices=dst_kv_indices, dst_aux_index=dst_aux_index, - required_dst_info_num=int(msg[6].decode("ascii")), + dst_state_indices=dst_state_indices, + required_dst_info_num=int(msg[7].decode("ascii")), is_dummy=is_dummy, ) @@ -103,6 +111,7 @@ class KVArgsRegisterInfo: mooncake_session_id: str dst_kv_ptrs: list[int] dst_aux_ptrs: list[int] + dst_state_data_ptrs: list[int] dst_tp_rank: int dst_attn_tp_size: int dst_kv_item_len: int @@ -116,9 +125,10 @@ class KVArgsRegisterInfo: mooncake_session_id=msg[3].decode("ascii"), dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])), dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), - dst_tp_rank=int(msg[6].decode("ascii")), - dst_attn_tp_size=int(msg[7].decode("ascii")), - dst_kv_item_len=int(msg[8].decode("ascii")), + dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), + dst_tp_rank=int(msg[7].decode("ascii")), + dst_attn_tp_size=int(msg[8].decode("ascii")), + dst_kv_item_len=int(msg[9].decode("ascii")), ) @@ -180,6 +190,9 @@ class MooncakeKVManager(CommonKVManager): ) for _ in range(transfer_queue_size) ] + self.state_executors = concurrent.futures.ThreadPoolExecutor( + transfer_thread_pool_size // transfer_queue_size + ) for queue, executor in zip(self.transfer_queues, self.executors): threading.Thread( target=self.transfer_worker, args=(queue, executor), daemon=True @@ -239,6 +252,12 @@ class MooncakeKVManager(CommonKVManager): self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens ) + # Batch register state/extra pool data buffers + if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens: + self.engine.batch_register( + self.kv_args.state_data_ptrs, self.kv_args.state_data_lens + ) + def _transfer_data(self, mooncake_session_id, transfer_blocks): if not transfer_blocks: return 0 @@ -248,17 +267,23 @@ class MooncakeKVManager(CommonKVManager): mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths) ) - def send_kvcache( + def _send_kvcache_generic( self, mooncake_session_id: str, - prefill_kv_indices: npt.NDArray[np.int32], - dst_kv_ptrs: list[int], - dst_kv_indices: npt.NDArray[np.int32], + src_data_ptrs: list[int], + dst_data_ptrs: list[int], + item_lens: list[int], + prefill_data_indices: npt.NDArray[np.int32], + dst_data_indices: npt.NDArray[np.int32], executor: concurrent.futures.ThreadPoolExecutor, - ): - # Group by indices + ) -> int: + """ + Generic KV cache transfer supporting both MHA and MLA architectures. + This method is used by both send_kvcache (full pool) and maybe_send_extra. + """ + # Group by indices for optimization prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous( - prefill_kv_indices, dst_kv_indices + prefill_data_indices, dst_data_indices ) layers_params = None @@ -266,9 +291,9 @@ class MooncakeKVManager(CommonKVManager): # pp is not supported on the decode side yet if self.is_mla_backend: src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = ( - self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs) ) - kv_item_len = self.kv_args.kv_item_lens[0] + kv_item_len = item_lens[0] layers_params = [ ( src_kv_ptrs[layer_id], @@ -279,9 +304,9 @@ class MooncakeKVManager(CommonKVManager): ] else: src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( - self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs) ) - kv_item_len = self.kv_args.kv_item_lens[0] + kv_item_len = item_lens[0] layers_params = [ ( src_k_ptrs[layer_id], @@ -345,6 +370,24 @@ class MooncakeKVManager(CommonKVManager): return 0 + def send_kvcache( + self, + mooncake_session_id: str, + prefill_kv_indices: npt.NDArray[np.int32], + dst_kv_ptrs: list[int], + dst_kv_indices: npt.NDArray[np.int32], + executor: concurrent.futures.ThreadPoolExecutor, + ): + return self._send_kvcache_generic( + mooncake_session_id=mooncake_session_id, + src_data_ptrs=self.kv_args.kv_data_ptrs, + dst_data_ptrs=dst_kv_ptrs, + item_lens=self.kv_args.kv_item_lens, + prefill_data_indices=prefill_kv_indices, + dst_data_indices=dst_kv_indices, + executor=executor, + ) + def send_kvcache_slice( self, mooncake_session_id: str, @@ -593,6 +636,58 @@ class MooncakeKVManager(CommonKVManager): f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}" ) + def maybe_send_extra( + self, + req: TransferInfo, + prefill_state_indices: list[int], + dst_state_data_ptrs: list[int], + ): + """Send state or extra pool data with type-specific handling.""" + state_type = getattr(self.kv_args, "state_type", "none") + + if state_type == "mamba": + return self._send_mamba_state( + req, + prefill_state_indices, + dst_state_data_ptrs, + ) + elif state_type in ["swa", "nsa"]: + # Reuse _send_kvcache_generic interface to send extra pool data + prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32) + dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32) + return self._send_kvcache_generic( + mooncake_session_id=req.mooncake_session_id, + src_data_ptrs=self.kv_args.state_data_ptrs, + dst_data_ptrs=dst_state_data_ptrs, + item_lens=self.kv_args.state_item_lens, + prefill_data_indices=prefill_state_indices, + dst_data_indices=dst_state_indices, + executor=self.state_executors, + ) + else: + return 0 + + def _send_mamba_state( + self, + req: TransferInfo, + prefill_mamba_index: list[int], + dst_state_data_ptrs: list[int], + ): + """Transfer Mamba states.""" + assert len(prefill_mamba_index) == 1, "Mamba should have single state index" + + transfer_blocks = [] + prefill_state_data_ptrs = self.kv_args.state_data_ptrs + prefill_state_item_lens = self.kv_args.state_item_lens + + for i, dst_state_ptr in enumerate(dst_state_data_ptrs): + length = prefill_state_item_lens[i] + src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0]) + dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0]) + transfer_blocks.append((src_addr, dst_addr, length)) + + return self._transfer_data(req.mooncake_session_id, transfer_blocks) + def sync_status_to_decode_endpoint( self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int ): @@ -702,6 +797,21 @@ class MooncakeKVManager(CommonKVManager): break if kv_chunk.is_last: + if kv_chunk.state_indices is not None: + if not self.is_mla_backend and ( + self.attn_tp_size + != target_rank_registration_info.dst_attn_tp_size + ): + raise RuntimeError( + f"PD Disaggregation does NOT support PD different TP sizes for non-MLA hybrid models yet." + ) + + self.maybe_send_extra( + req, + kv_chunk.state_indices, + target_rank_registration_info.dst_state_data_ptrs, + ) + if self.pp_group.is_last_rank: # Only the last chunk we need to send the aux data ret = self.send_aux( @@ -765,7 +875,7 @@ class MooncakeKVManager(CommonKVManager): ) continue else: - required_dst_info_num = int(waiting_req_bytes[6].decode("ascii")) + required_dst_info_num = int(waiting_req_bytes[7].decode("ascii")) room = int(room) if room not in self.transfer_infos: self.transfer_infos[room] = {} @@ -876,6 +986,7 @@ class MooncakeKVManager(CommonKVManager): index_slice: slice, is_last: bool, aux_index: Optional[int] = None, + state_indices: Optional[List[int]] = None, ): assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) @@ -909,6 +1020,7 @@ class MooncakeKVManager(CommonKVManager): index_slice=index_slice, is_last=is_last, prefill_aux_index=aux_index, + state_indices=state_indices, ) ) @@ -989,6 +1101,7 @@ class MooncakeKVSender(CommonKVSender): def send( self, kv_indices: npt.NDArray[np.int32], + state_indices: Optional[List[int]] = None, ): index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) self.curr_idx += len(kv_indices) @@ -1008,6 +1121,7 @@ class MooncakeKVSender(CommonKVSender): index_slice, True, aux_index=self.aux_index, + state_indices=state_indices, ) def poll(self) -> KVPoll: @@ -1110,6 +1224,9 @@ class MooncakeKVReceiver(CommonKVReceiver): packed_aux_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs ) + packed_state_data_ptrs = b"".join( + struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs + ) # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet tp_rank = self.kv_mgr.kv_args.engine_rank kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0] @@ -1127,13 +1244,19 @@ class MooncakeKVReceiver(CommonKVReceiver): self.session_id.encode("ascii"), packed_kv_data_ptrs, packed_aux_data_ptrs, + packed_state_data_ptrs, dst_tp_rank, dst_attn_tp_size, dst_kv_item_len, ] ) - def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): + def init( + self, + kv_indices: npt.NDArray[np.int32], + aux_index: Optional[int] = None, + state_indices: Optional[List[int]] = None, + ): for bootstrap_info in self.bootstrap_infos: sock, lock = self._connect_to_bootstrap_server(bootstrap_info) is_dummy = bootstrap_info["is_dummy"] @@ -1147,6 +1270,14 @@ class MooncakeKVReceiver(CommonKVReceiver): self.session_id.encode("ascii"), kv_indices.tobytes() if not is_dummy else b"", str(aux_index).encode("ascii") if not is_dummy else b"", + ( + np.array( + state_indices, + dtype=np.int32, + ).tobytes() + if not is_dummy and state_indices is not None + else b"" + ), str(self.required_dst_info_num).encode("ascii"), ] ) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index df5f9e49c..8d9bdffc6 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -704,6 +704,7 @@ class NixlKVSender(CommonKVSender): def send( self, kv_indices: npt.NDArray[np.int32], + state_indices: Optional[List[int]] = None, ): index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) self.curr_idx += len(kv_indices) @@ -755,7 +756,12 @@ class NixlKVReceiver(CommonKVReceiver): self.bootstrap_room ) - def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): + def init( + self, + kv_indices: npt.NDArray[np.int32], + aux_index: Optional[int] = None, + state_indices: Optional[List[int]] = None, + ): for bootstrap_info in self.bootstrap_infos: logger.debug( f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index b9884414c..23cd0dd17 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -49,6 +49,11 @@ from sglang.srt.managers.schedule_batch import ( RequestStage, ScheduleBatch, ) +from sglang.srt.mem_cache.memory_pool import ( + HybridLinearKVPool, + NSATokenToKVPool, + SWAKVPool, +) from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.utils import ( DynamicGradMode, @@ -146,6 +151,28 @@ class PrefillBootstrapQueue: kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id + if hasattr(self.token_to_kv_pool, "get_state_buf_infos"): + state_data_ptrs, state_data_lens, state_item_lens = ( + self.token_to_kv_pool.get_state_buf_infos() + ) + kv_args.state_data_ptrs = state_data_ptrs + kv_args.state_data_lens = state_data_lens + kv_args.state_item_lens = state_item_lens + + if isinstance(self.token_to_kv_pool, SWAKVPool): + kv_args.state_type = "swa" + elif isinstance(self.token_to_kv_pool, HybridLinearKVPool): + kv_args.state_type = "mamba" + elif isinstance(self.token_to_kv_pool, NSATokenToKVPool): + kv_args.state_type = "nsa" + else: + kv_args.state_type = "none" + else: + kv_args.state_data_ptrs = [] + kv_args.state_data_lens = [] + kv_args.state_item_lens = [] + kv_args.state_type = "none" + kv_manager_class: Type[BaseKVManager] = get_kv_class( self.transfer_backend, KVClassType.MANAGER ) @@ -618,15 +645,58 @@ class SchedulerDisaggregationPrefillMixin: .numpy() ) req.start_send_idx = end_idx + state_indices = None if last_chunk: self.disagg_metadata_buffers.set_buf(req) + + # Prepare extra pool indices for hybrid models + if isinstance( + self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool + ): + # Mamba hybrid model: send single mamba state index + state_indices = [ + self.req_to_token_pool.req_index_to_mamba_index_mapping[ + req.req_pool_idx + ] + .cpu() + .numpy() + ] + elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool): + # SWA hybrid model: send last window KV indices + seq_len = len(req.fill_ids) + window_size = self.sliding_window_size + window_start = max(0, seq_len - window_size) + window_start = (window_start // page_size) * page_size + + window_kv_indices_full = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, window_start:seq_len + ] + + # Translate to SWA pool indices + window_kv_indices_swa = ( + self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa( + window_kv_indices_full + ) + ) + state_indices = window_kv_indices_swa.cpu().numpy() + state_indices = kv_to_page_indices(state_indices, page_size) + elif isinstance( + self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool + ): + seq_len = len(req.fill_ids) + kv_indices_full = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, :seq_len + ] + state_indices = kv_indices_full.cpu().numpy() + state_indices = kv_to_page_indices(state_indices, page_size) + page_indices = kv_to_page_indices(kv_indices, page_size) if len(page_indices) == 0: logger.info( f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty" ) return - req.disagg_kv_sender.send(page_indices) + req.disagg_kv_sender.send(page_indices, state_indices) # PP @DynamicGradMode() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f9601c9ac..ce0148e98 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -807,9 +807,6 @@ class Scheduler( self.tree_cache.cache_controller.layer_done_counter ) elif self.is_hybrid: - assert ( - self.server_args.disaggregation_mode == "null" - ), "Hybrid mode does not support disaggregation yet" self.tree_cache = SWARadixCache( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, @@ -819,9 +816,6 @@ class Scheduler( is_eagle=self.spec_algorithm.is_eagle(), ) elif self.is_hybrid_gdn: - assert ( - self.server_args.disaggregation_mode == "null" - ), "Hybrid GDN mode does not support disaggregation yet" self.tree_cache = MambaRadixCache( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 15d48142c..c468269f3 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -142,72 +142,93 @@ class MambaPool: ssm_dtype = cache_params.dtype.temporal num_mamba_layers = len(cache_params.layers) - # assume conv_state = (dim, state_len) - assert conv_state_shape[0] > conv_state_shape[1] - conv_state = torch.zeros( - size=(num_mamba_layers, size + 1) + conv_state_shape, - dtype=conv_dtype, - device=device, + # for disagg with nvlink + self.enable_custom_mem_pool = get_bool_env_var( + "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" ) - temporal_state = torch.zeros( - size=(num_mamba_layers, size + 1) + temporal_state_shape, - dtype=ssm_dtype, - device=device, - ) - if speculative_num_draft_tokens is not None: - # Cache intermediate SSM states per draft token during target verify - # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V] - intermediate_ssm_state_cache = torch.zeros( - size=( - num_mamba_layers, - size + 1, - speculative_num_draft_tokens, - temporal_state_shape[0], - temporal_state_shape[1], - temporal_state_shape[2], - ), - dtype=ssm_dtype, - device="cuda", - ) - # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify - # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1] - intermediate_conv_window_cache = torch.zeros( - size=( - num_mamba_layers, - size + 1, - speculative_num_draft_tokens, - conv_state_shape[0], - conv_state_shape[1], - ), - dtype=conv_dtype, - device="cuda", - ) - self.mamba_cache = self.SpeculativeState( - conv=conv_state, - temporal=temporal_state, - intermediate_ssm=intermediate_ssm_state_cache, - intermediate_conv_window=intermediate_conv_window_cache, - ) - logger.info( - f"Mamba Cache is allocated. " - f"max_mamba_cache_size: {size}, " - f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " - f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " - f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB " - f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB " - ) + if self.enable_custom_mem_pool: + # TODO(shangming): abstract custom allocator class for more backends + from mooncake.allocator import NVLinkAllocator + + allocator = NVLinkAllocator.get_allocator(self.device) + self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator()) else: - self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state) - logger.info( - f"Mamba Cache is allocated. " - f"max_mamba_cache_size: {size}, " - f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " - f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " + self.custom_mem_pool = None + + with ( + torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.enable_custom_mem_pool + else nullcontext() + ): + # assume conv_state = (dim, state_len) + assert conv_state_shape[0] > conv_state_shape[1] + conv_state = torch.zeros( + size=(num_mamba_layers, size + 1) + conv_state_shape, + dtype=conv_dtype, + device=device, ) - self.size = size - self.device = device - self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) - self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB + temporal_state = torch.zeros( + size=(num_mamba_layers, size + 1) + temporal_state_shape, + dtype=ssm_dtype, + device=device, + ) + if speculative_num_draft_tokens is not None: + # Cache intermediate SSM states per draft token during target verify + # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V] + intermediate_ssm_state_cache = torch.zeros( + size=( + num_mamba_layers, + size + 1, + speculative_num_draft_tokens, + temporal_state_shape[0], + temporal_state_shape[1], + temporal_state_shape[2], + ), + dtype=ssm_dtype, + device="cuda", + ) + # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify + # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1] + intermediate_conv_window_cache = torch.zeros( + size=( + num_mamba_layers, + size + 1, + speculative_num_draft_tokens, + conv_state_shape[0], + conv_state_shape[1], + ), + dtype=conv_dtype, + device="cuda", + ) + self.mamba_cache = self.SpeculativeState( + conv=conv_state, + temporal=temporal_state, + intermediate_ssm=intermediate_ssm_state_cache, + intermediate_conv_window=intermediate_conv_window_cache, + ) + logger.info( + f"Mamba Cache is allocated. " + f"max_mamba_cache_size: {size}, " + f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " + f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " + f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB " + f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB " + ) + else: + self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state) + logger.info( + f"Mamba Cache is allocated. " + f"max_mamba_cache_size: {size}, " + f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " + f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " + ) + self.size = size + self.device = device + self.free_slots = torch.arange( + self.size, dtype=torch.int64, device=self.device + ) + self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB + self.num_mamba_layers = num_mamba_layers def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState: assert isinstance(self.mamba_cache, self.SpeculativeState) @@ -253,6 +274,22 @@ class MambaPool: self.copy_from(src_index, dst_index) return dst_index + def get_contiguous_buf_infos(self): + state_tensors = [ + getattr(self.mamba_cache, field) for field in vars(self.mamba_cache) + ] + data_ptrs, data_lens, item_lens = [], [], [] + + for _, state_tensor in enumerate(state_tensors): + data_ptrs += [ + state_tensor[i].data_ptr() for i in range(self.num_mamba_layers) + ] + data_lens += [state_tensor[i].nbytes for i in range(self.num_mamba_layers)] + item_lens += [ + state_tensor[i][0].nbytes for i in range(self.num_mamba_layers) + ] + return data_ptrs, data_lens, item_lens + class HybridReqToTokenPool(ReqToTokenPool): """A memory pool that maps a request to its token locations.""" @@ -274,13 +311,26 @@ class HybridReqToTokenPool(ReqToTokenPool): device=device, enable_memory_saver=enable_memory_saver, ) - - self.mamba_pool = MambaPool( + self._init_mamba_pool( size=mamba_size, cache_params=cache_params, device=device, speculative_num_draft_tokens=speculative_num_draft_tokens, ) + + def _init_mamba_pool( + self, + size: int, + cache_params: "Mamba2CacheParams", + device: str, + speculative_num_draft_tokens: int = None, + ): + self.mamba_pool = MambaPool( + size=size, + cache_params=cache_params, + device=device, + speculative_num_draft_tokens=speculative_num_draft_tokens, + ) self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)} self.device = device @@ -375,6 +425,19 @@ class KVCache(abc.ABC): # default state for optional layer-wise transfer control self.layer_transfer_counter = None + # for disagg with nvlink + self.enable_custom_mem_pool = get_bool_env_var( + "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" + ) + if self.enable_custom_mem_pool: + # TODO(shangming): abstract custom allocator class for more backends + from mooncake.allocator import NVLinkAllocator + + allocator = NVLinkAllocator.get_allocator(self.device) + self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator()) + else: + self.custom_mem_pool = None + def _finalize_allocation_log(self, num_tokens: int): """Common logging and mem_usage computation for KV cache allocation. Supports both tuple (K, V) size returns and single KV size returns. @@ -426,6 +489,9 @@ class KVCache(abc.ABC): def load_cpu_copy(self, kv_cache_cpu, indices): raise NotImplementedError() + def maybe_get_custom_mem_pool(self): + return self.custom_mem_pool + class MHATokenToKVPool(KVCache): @@ -456,19 +522,6 @@ class MHATokenToKVPool(KVCache): self.head_num = head_num self.head_dim = head_dim - # for disagg with nvlink - self.enable_custom_mem_pool = get_bool_env_var( - "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" - ) - if self.enable_custom_mem_pool: - # TODO(shangming): abstract custom allocator class for more backends - from mooncake.allocator import NVLinkAllocator - - allocator = NVLinkAllocator.get_allocator(self.device) - self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator()) - else: - self.custom_mem_pool = None - self._create_buffers() self.device_module = torch.get_device_module(self.device) @@ -611,9 +664,6 @@ class MHATokenToKVPool(KVCache): ] return kv_data_ptrs, kv_data_lens, kv_item_lens - def maybe_get_custom_mem_pool(self): - return self.custom_mem_pool - def get_cpu_copy(self, indices): torch.cuda.synchronize() kv_cache_cpu = [] @@ -756,12 +806,18 @@ class HybridLinearKVPool(KVCache): full_attention_layer_ids: List[int], enable_kvcache_transpose: bool, device: str, + mamba_pool: MambaPool, ): self.size = size self.dtype = dtype self.device = device self.full_layer_nums = len(full_attention_layer_ids) self.page_size = page_size + # TODO support pp? + self.start_layer = 0 + self.head_num = head_num + self.head_dim = head_dim + self.mamba_pool = mamba_pool # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True assert not enable_kvcache_transpose if _is_npu: @@ -790,6 +846,15 @@ class HybridLinearKVPool(KVCache): def get_contiguous_buf_infos(self): return self.full_kv_pool.get_contiguous_buf_infos() + def get_state_buf_infos(self): + mamba_data_ptrs, mamba_data_lens, mamba_item_lens = ( + self.mamba_pool.get_contiguous_buf_infos() + ) + return mamba_data_ptrs, mamba_data_lens, mamba_item_lens + + def maybe_get_custom_mem_pool(self): + return self.full_kv_pool.maybe_get_custom_mem_pool() + def _transfer_full_attention_id(self, layer_id: int): if layer_id not in self.full_attention_layer_id_mapping: raise ValueError( @@ -841,22 +906,47 @@ class SWAKVPool(KVCache): size: int, size_swa: int, dtype: torch.dtype, + head_num: int, + head_dim: int, swa_attention_layer_ids: List[int], full_attention_layer_ids: List[int], enable_kvcache_transpose: bool, + device: str, token_to_kv_pool_class: KVCache = MHATokenToKVPool, **kwargs, ): self.size = size self.size_swa = size_swa self.dtype = dtype + self.head_num = head_num + self.head_dim = head_dim + self.device = device self.swa_layer_nums = len(swa_attention_layer_ids) self.full_layer_nums = len(full_attention_layer_ids) + self.start_layer = 0 + self.page_size = 1 + kwargs["page_size"] = 1 kwargs["enable_memory_saver"] = False + kwargs["head_num"] = head_num + kwargs["head_dim"] = head_dim + kwargs["device"] = device # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True assert not enable_kvcache_transpose + # for disagg with nvlink + self.enable_custom_mem_pool = get_bool_env_var( + "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" + ) + if self.enable_custom_mem_pool: + # TODO(shangming): abstract custom allocator class for more backends + from mooncake.allocator import NVLinkAllocator + + allocator = NVLinkAllocator.get_allocator(self.device) + self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator()) + else: + self.custom_mem_pool = None + self.swa_kv_pool = token_to_kv_pool_class( size=size_swa, dtype=dtype, @@ -878,6 +968,9 @@ class SWAKVPool(KVCache): k_size, v_size = self.get_kv_size_bytes() self.mem_usage = (k_size + v_size) / GB + logger.info( + f"SWAKVPool mem usage: {self.mem_usage} GB, swa size: {self.size_swa}, full size: {self.size}" + ) def get_kv_size_bytes(self): k_size, v_size = self.full_kv_pool.get_kv_size_bytes() @@ -888,15 +981,19 @@ class SWAKVPool(KVCache): full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = ( self.full_kv_pool.get_contiguous_buf_infos() ) + + kv_data_ptrs = full_kv_data_ptrs + kv_data_lens = full_kv_data_lens + kv_item_lens = full_kv_item_lens + + return kv_data_ptrs, kv_data_lens, kv_item_lens + + def get_state_buf_infos(self): swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = ( self.swa_kv_pool.get_contiguous_buf_infos() ) - kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs - kv_data_lens = full_kv_data_lens + swa_kv_data_lens - kv_item_lens = full_kv_item_lens + swa_kv_item_lens - - return kv_data_ptrs, kv_data_lens, kv_item_lens + return swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens def get_key_buffer(self, layer_id: int): layer_id_pool, is_swa = self.layers_mapping[layer_id] @@ -1152,19 +1249,6 @@ class MLATokenToKVPool(KVCache): else (kv_lora_rank + qk_rope_head_dim) ) - # for disagg with nvlink - self.enable_custom_mem_pool = get_bool_env_var( - "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" - ) - if self.enable_custom_mem_pool: - # TODO(shangming): abstract custom allocator class for more backends - from mooncake.allocator import NVLinkAllocator - - allocator = NVLinkAllocator.get_allocator(self.device) - self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator()) - else: - self.custom_mem_pool = None - with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): with ( torch.cuda.use_mem_pool(self.custom_mem_pool) @@ -1207,9 +1291,6 @@ class MLATokenToKVPool(KVCache): ] return kv_data_ptrs, kv_data_lens, kv_item_lens - def maybe_get_custom_mem_pool(self): - return self.custom_mem_pool - def get_key_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) @@ -1346,24 +1427,31 @@ class NSATokenToKVPool(MLATokenToKVPool): assert index_head_dim == 128 assert self.page_size == 64 - self.index_k_with_scale_buffer = [ - torch.zeros( - # Layout: - # ref: test_attention.py :: kv_cache_cast_to_fp8 - # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4) - # data: for page i, - # * buf[i, :page_size * head_dim] for fp8 data - # * buf[i, page_size * head_dim:].view(float32) for scale - ( - (size + page_size + 1) // self.page_size, - self.page_size - * (index_head_dim + index_head_dim // self.quant_block_size * 4), - ), - dtype=self.index_k_with_scale_buffer_dtype, - device=device, - ) - for _ in range(layer_num) - ] + with ( + torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.custom_mem_pool + else nullcontext() + ): + self.index_k_with_scale_buffer = [ + torch.zeros( + # Layout: + # ref: test_attention.py :: kv_cache_cast_to_fp8 + # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4) + # data: for page i, + # * buf[i, :page_size * head_dim] for fp8 data + # * buf[i, page_size * head_dim:].view(float32) for scale + ( + (size + page_size + 1) // self.page_size, + self.page_size + * ( + index_head_dim + index_head_dim // self.quant_block_size * 4 + ), + ), + dtype=self.index_k_with_scale_buffer_dtype, + device=device, + ) + for _ in range(layer_num) + ] self._finalize_allocation_log(size) def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: @@ -1406,6 +1494,18 @@ class NSATokenToKVPool(MLATokenToKVPool): pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale ) + def get_state_buf_infos(self): + data_ptrs = [ + self.index_k_with_scale_buffer[i].data_ptr() for i in range(self.layer_num) + ] + data_lens = [ + self.index_k_with_scale_buffer[i].nbytes for i in range(self.layer_num) + ] + item_lens = [ + self.index_k_with_scale_buffer[i][0].nbytes for i in range(self.layer_num) + ] + return data_ptrs, data_lens, item_lens + def get_kv_size_bytes(self): kv_size_bytes = super().get_kv_size_bytes() for index_k_cache in self.index_k_with_scale_buffer: @@ -1636,27 +1736,38 @@ class DoubleSparseTokenToKVPool(KVCache): ) with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): - # [size, head_num, head_dim] for each layer - self.k_buffer = [ - torch.zeros( - (size + page_size, head_num, head_dim), dtype=dtype, device=device - ) - for _ in range(layer_num) - ] - self.v_buffer = [ - torch.zeros( - (size + page_size, head_num, head_dim), dtype=dtype, device=device - ) - for _ in range(layer_num) - ] + with ( + torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.enable_custom_mem_pool + else nullcontext() + ): + # [size, head_num, head_dim] for each layer + self.k_buffer = [ + torch.zeros( + (size + page_size, head_num, head_dim), + dtype=dtype, + device=device, + ) + for _ in range(layer_num) + ] + self.v_buffer = [ + torch.zeros( + (size + page_size, head_num, head_dim), + dtype=dtype, + device=device, + ) + for _ in range(layer_num) + ] - # [size, head_num, heavy_channel_num] for each layer - self.label_buffer = [ - torch.zeros( - (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device - ) - for _ in range(layer_num) - ] + # [size, head_num, heavy_channel_num] for each layer + self.label_buffer = [ + torch.zeros( + (size + 1, head_num, heavy_channel_num), + dtype=dtype, + device=device, + ) + for _ in range(layer_num) + ] def get_key_buffer(self, layer_id: int): return self.k_buffer[layer_id - self.start_layer] diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4575d2091..4ef8bc993 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1669,19 +1669,34 @@ class ModelRunner: extra_max_context_len += self.server_args.speculative_num_draft_tokens if self.server_args.disaggregation_mode == "decode": - from sglang.srt.disaggregation.decode import DecodeReqToTokenPool + from sglang.srt.disaggregation.decode import ( + DecodeReqToTokenPool, + HybridMambaDecodeReqToTokenPool, + ) # subscribe memory for pre-allocated requests # if max_num_reqs <= 32, we pre-allocate 2x requests pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0 - self.req_to_token_pool = DecodeReqToTokenPool( - size=max_num_reqs, - max_context_len=self.model_config.context_len - + extra_max_context_len, - device=self.device, - enable_memory_saver=self.server_args.enable_memory_saver, - pre_alloc_size=pre_alloc_size, - ) + if config := self.mambaish_config: + self.req_to_token_pool = HybridMambaDecodeReqToTokenPool( + size=max_num_reqs, + max_context_len=self.model_config.context_len + + extra_max_context_len, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + cache_params=config.mamba2_cache_params, + speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, + pre_alloc_size=pre_alloc_size, + ) + else: + self.req_to_token_pool = DecodeReqToTokenPool( + size=max_num_reqs, + max_context_len=self.model_config.context_len + + extra_max_context_len, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + pre_alloc_size=pre_alloc_size, + ) elif config := self.mambaish_config: self.req_to_token_pool = HybridReqToTokenPool( size=max_num_reqs, @@ -1807,6 +1822,7 @@ class ModelRunner: ), enable_kvcache_transpose=False, device=self.device, + mamba_pool=self.req_to_token_pool.mamba_pool, ) else: self.token_to_kv_pool = MHATokenToKVPool( diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 87903ef47..781803c1b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -163,6 +163,7 @@ suites = { TestFile("test_deepseek_v3_basic.py", 275), TestFile("test_deepseek_v3_mtp.py", 275), TestFile("test_disaggregation_different_tp.py", 600), + TestFile("test_disaggregation_hybrid_attention.py", 200), TestFile("test_disaggregation_pp.py", 140), ], "per-commit-4-gpu-b200": [ diff --git a/test/srt/test_disaggregation_hybrid_attention.py b/test/srt/test_disaggregation_hybrid_attention.py new file mode 100644 index 000000000..34ed29c72 --- /dev/null +++ b/test/srt/test_disaggregation_hybrid_attention.py @@ -0,0 +1,83 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.environ import envs +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_disaggregation_utils import TestDisaggregationBase +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + popen_launch_pd_server, +) + + +class TestDisaggregationHybridAttentionMamba(TestDisaggregationBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct" + + # Non blocking start servers + cls.start_prefill() + cls.start_decode() + + # Block until both + cls.wait_server_ready(cls.prefill_url + "/health") + cls.wait_server_ready(cls.decode_url + "/health") + + cls.launch_lb() + + @classmethod + def start_prefill(cls): + prefill_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "prefill", + "--tp", + "4", + ] + prefill_args += cls.transfer_backend + cls.rdma_devices + cls.process_prefill = popen_launch_pd_server( + cls.model, + cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + ) + + @classmethod + def start_decode(cls): + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--tp", + "4", + "--base-gpu-id", + "4", + ] + decode_args += cls.transfer_backend + cls.rdma_devices + cls.process_decode = popen_launch_pd_server( + cls.model, + cls.decode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + ) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host=f"http://{self.base_host}", + port=int(self.lb_port), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"Evaluation metrics: {metrics}") + + self.assertGreater(metrics["accuracy"], 0.93) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_mamba_unittest.py b/test/srt/test_mamba_unittest.py index 401eb584f..7bbca75e1 100644 --- a/test/srt/test_mamba_unittest.py +++ b/test/srt/test_mamba_unittest.py @@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase): full_attention_layer_ids=full_attention_layer_ids, enable_kvcache_transpose=False, device=device, + mamba_pool=None, ) assert pool._transfer_full_attention_id(global_interval - 1) == 0 assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1 @@ -173,6 +174,7 @@ class TestMamba(unittest.TestCase): full_attention_layer_ids=full_attention_layer_ids, enable_kvcache_transpose=False, device=device, + mamba_pool=req_to_token_pool.mamba_pool, ) # setup token to kv pool allocator