diff --git a/docs/references/llama4.md b/docs/references/llama4.md index 1380510b8..b09a6e240 100644 --- a/docs/references/llama4.md +++ b/docs/references/llama4.md @@ -16,10 +16,11 @@ python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-In ### Configuration Tips -- **OOM Mitigation**: Adjust `--context-length` to avoid a GPU out-of-memory issue. For the Scout model, we recommend setting this value up to 1M on 8\*H100 and up to 2.5M on 8\*H200. For the Maverick model, we don't need to set context length on 8\*H200. +- **OOM Mitigation**: Adjust `--context-length` to avoid a GPU out-of-memory issue. For the Scout model, we recommend setting this value up to 1M on 8\*H100 and up to 2.5M on 8\*H200. For the Maverick model, we don't need to set context length on 8\*H200. When hybrid kv cache is enabled, `--context-length` can be set up to 5M on 8\*H100 and up to 10M on 8\*H200 for the Scout model. - **Chat Template**: Add `--chat-template llama-4` for chat completion tasks. - **Enable Multi-Modal**: Add `--enable-multimodal` for multi-modal capabilities. +- **Enable Hybrid-KVCache**: Add `--hybrid-kvcache-ratio` for hybrid kv cache. Details can be seen in [this PR](https://github.com/sgl-project/sglang/pull/6563) ## Benchmarking Results diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 63412791d..6f202db6f 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -59,6 +59,7 @@ class ModelConfig: quantization: Optional[str] = None, override_config_file: Optional[str] = None, is_draft_model: bool = False, + hybrid_kvcache_ratio: Optional[float] = None, impl: Union[str, ModelImpl] = ModelImpl.AUTO, ) -> None: @@ -86,6 +87,18 @@ class ModelConfig: self.attention_chunk_size = getattr( self.hf_text_config, "attention_chunk_size", None ) + self.is_hybrid = is_hybrid_model( + self.hf_config.architectures, + hybrid_kvcache_ratio=hybrid_kvcache_ratio, + context_length=context_length, + attention_chunk_size=self.attention_chunk_size, + ) + if self.is_hybrid is not None: + self.swa_attention_layer_ids, self.full_attention_layer_ids = ( + get_hybrid_layer_ids( + self.hf_config.architectures, self.hf_text_config.num_hidden_layers + ) + ) if enable_multimodal is None: mm_disabled_models = [ @@ -264,6 +277,7 @@ class ModelConfig: enable_multimodal=server_args.enable_multimodal, dtype=server_args.dtype, quantization=server_args.quantization, + hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio, impl=server_args.impl, **kwargs, ) @@ -633,3 +647,36 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 + + +def is_hybrid_model( + model_architectures: List[str], + hybrid_kvcache_ratio: Optional[float], + context_length: Optional[int], + attention_chunk_size: Optional[int], +): + if hybrid_kvcache_ratio is None: + return None + elif ( + hybrid_kvcache_ratio > 0 + and model_architectures[0] == "Llama4ForConditionalGeneration" + and context_length > attention_chunk_size + ): + return hybrid_kvcache_ratio + else: + return None + + +def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int): + if "Llama4ForConditionalGeneration" in model_architectures: + swa_attention_layer_ids = [ + i for i in range(num_hidden_layers) if (i + 1) % 4 != 0 + ] + full_attention_layer_ids = [ + i for i in range(num_hidden_layers) if (i + 1) % 4 == 0 + ] + else: + raise ValueError( + 
"get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration" + ) + return swa_attention_layer_ids, full_attention_layer_ids diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index a71631596..576780ebf 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -433,9 +433,7 @@ class DecodePreallocQueue: else 0 ) - available_size = self.token_to_kv_pool_allocator.available_size() - - allocatable_tokens = available_size - max( + allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max( # preserve some space for future decode self.num_reserved_decode_tokens * ( diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 85899636e..b0615be3c 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -9,6 +9,7 @@ import torch from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.mem_cache.memory_pool import SWAKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput @@ -320,6 +321,11 @@ class FlashAttentionBackend(AttentionBackend): self.page_size = model_runner.page_size self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA self.skip_prefill = skip_prefill + self.is_hybrid = model_runner.is_hybrid + if self.is_hybrid: + self.full_to_swa_index_mapping = ( + model_runner.token_to_kv_pool.full_to_swa_index_mapping + ) self.topk = model_runner.server_args.speculative_eagle_topk or 0 self.speculative_num_steps = speculative_num_steps self.speculative_num_draft_tokens = ( @@ -428,7 +434,7 @@ class FlashAttentionBackend(AttentionBackend): forward_batch.req_pool_indices, : metadata.max_seq_len_k ] # TODO: we need to test this part for llama 4 eagle case - self._init_local_attn_metadata(metadata, device) + self._init_local_attn_metadata(forward_batch, metadata, device) elif forward_batch.forward_mode.is_target_verify(): if self.topk <= 1: metadata.cache_seqlens_int32 = ( @@ -456,7 +462,7 @@ class FlashAttentionBackend(AttentionBackend): forward_batch.req_pool_indices, : metadata.max_seq_len_k ] - self._init_local_attn_metadata(metadata, device) + self._init_local_attn_metadata(forward_batch, metadata, device) else: metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32) metadata.max_seq_len_q = self.speculative_num_draft_tokens @@ -575,7 +581,7 @@ class FlashAttentionBackend(AttentionBackend): # Setup local attention if enabled if forward_batch.forward_mode == ForwardMode.EXTEND: - self._init_local_attn_metadata(metadata, device) + self._init_local_attn_metadata(forward_batch, metadata, device) # Encoder metadata for cross attention if forward_batch.encoder_lens is not None: @@ -1588,7 +1594,7 @@ class FlashAttentionBackend(AttentionBackend): forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], seq_lens_cpu: Optional[torch.Tensor], - out_cache_loc: torch.Tensor = None, + out_cache_loc: Optional[torch.Tensor] = None, ): """Initialize forward metadata for replaying CUDA graph.""" seq_lens = seq_lens[:bs] @@ -1673,7 +1679,10 @@ class 
FlashAttentionBackend(AttentionBackend): self.page_size, ) - self._update_local_attn_metadata_for_replay(metadata, bs) + self._update_local_attn_metadata_for_replay( + metadata, + bs, + ) elif forward_mode.is_target_verify(): if self.topk <= 1: metadata = self.target_verify_metadata[bs] @@ -1829,7 +1838,9 @@ class FlashAttentionBackend(AttentionBackend): """Get the fill value for sequence length in CUDA graph.""" return 1 - def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device): + def _init_local_attn_metadata( + self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device + ): """Centralized utility to initialize local_attn_metadata if chunked attention is enabled.""" if self.attention_chunk_size is None: metadata.local_attn_metadata = None @@ -1837,7 +1848,12 @@ class FlashAttentionBackend(AttentionBackend): cu_seqlens_q = metadata.cu_seqlens_q cache_seqlens_int32 = metadata.cache_seqlens_int32 - page_table = metadata.page_table + if self.is_hybrid: + page_table = self.full_to_swa_index_mapping[metadata.page_table].to( + torch.int32 + ) + else: + page_table = metadata.page_table if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None: metadata.local_attn_metadata = None return @@ -1923,7 +1939,9 @@ class FlashAttentionBackend(AttentionBackend): ) def _update_local_attn_metadata_for_replay( - self, metadata: FlashAttentionMetadata, bs: int + self, + metadata: FlashAttentionMetadata, + bs: int, ): """Update preallocated local attention metadata in-place before CUDA graph replay.""" if self.attention_chunk_size is None: @@ -1954,7 +1972,12 @@ class FlashAttentionBackend(AttentionBackend): # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices # beyond the actual sequence length, leading to incorrect attention calculations max_seq_len = int(seqlens.max().item()) - sliced_page_table = metadata.page_table[:bs, :max_seq_len] + if self.is_hybrid: + sliced_page_table = self.full_to_swa_index_mapping[ + metadata.page_table[:bs, :max_seq_len] + ].to(torch.int32) + else: + sliced_page_table = metadata.page_table[:bs, :max_seq_len] cu_seqlens_q_np = cu_seqlens_q.cpu().numpy() seqlens_np = seqlens.cpu().numpy() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 85e8500dc..6728f8852 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -56,7 +56,7 @@ from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank from sglang.srt.layers.multimodal import gpu_tensor_hash from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache -from sglang.srt.mem_cache.chunk_cache import ChunkCache +from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.metrics.collector import TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode @@ -485,6 +485,9 @@ class Req: # for corss-endoder model self.token_type_ids = token_type_ids + # The length of KV that have been removed in local attention chunked prefill + self.evicted_seqlen_local = 0 + # Sampling info if isinstance(sampling_params.custom_params, dict): sampling_params = copy.copy(sampling_params) @@ -1191,6 +1194,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.req_to_token_pool.write( (req.req_pool_idx, slice(0, 
pre_len)), req.prefix_indices ) + if isinstance(self.tree_cache, SWAChunkCache): + self.tree_cache.evict( + req, pre_len, self.model_config.attention_chunk_size + ) # If input_embeds are available, store them if req.input_embeds is not None: @@ -1383,7 +1390,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): * buf_multiplier * self.token_to_kv_pool_allocator.page_size ) - if self.token_to_kv_pool_allocator.available_size() >= tokens_required: return True @@ -1564,6 +1570,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.seq_lens.add_(1) self.seq_lens_sum += bs + # free memory + if isinstance(self.tree_cache, SWAChunkCache): + for req in self.reqs: + self.tree_cache.evict( + req, req.seqlen - 1, self.model_config.attention_chunk_size + ) + # Allocate memory if self.token_to_kv_pool_allocator.page_size == 1: self.out_cache_loc = self.alloc_token_slots(bs) @@ -1798,7 +1811,6 @@ class ModelWorkerBatch: seq_lens: torch.Tensor # The indices of output tokens in the token_to_kv_pool_allocator out_cache_loc: torch.Tensor - # The sequence length tensor on CPU seq_lens_cpu: Optional[torch.Tensor] seq_lens_sum: int diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3377df6b1..692d4673d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -126,7 +126,8 @@ from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.managers.utils import validate_input_length -from sglang.srt.mem_cache.chunk_cache import ChunkCache +from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator +from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats @@ -570,7 +571,11 @@ class Scheduler( server_args.chunked_prefill_size is not None and server_args.disable_radix_cache ): - self.tree_cache = ChunkCache( + if self.model_config.is_hybrid: + ChunkCacheClass = SWAChunkCache + else: + ChunkCacheClass = ChunkCache + self.tree_cache = ChunkCacheClass( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, page_size=self.page_size, @@ -1283,9 +1288,8 @@ class Scheduler( self.last_input_throughput = self.last_prefill_tokens / gap_latency self.last_prefill_tokens = adder.log_input_tokens - num_used = self.max_total_num_tokens - ( - self.token_to_kv_pool_allocator.available_size() - + self.tree_cache.evictable_size() + usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage( + self.tree_cache.evictable_size() ) num_new_seq = len(can_run_list) @@ -1294,7 +1298,7 @@ class Scheduler( f"#new-seq: {num_new_seq}, " f"#new-token: {adder.log_input_tokens}, " f"#cached-token: {adder.log_hit_tokens}, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"{usage_msg}" ) if self.disaggregation_mode == DisaggregationMode.PREFILL: @@ -1337,9 +1341,8 @@ class Scheduler( self.last_gen_throughput = self.num_generated_tokens / gap_latency self.num_generated_tokens = 0 num_running_reqs = len(batch.reqs) - num_used = self.max_total_num_tokens - ( - self.token_to_kv_pool_allocator.available_size() - + self.tree_cache.evictable_size() + usage_msg, num_used = 
self.token_to_kv_pool_allocator.log_usage( + self.tree_cache.evictable_size() ) if RECORD_STEP_TIME: @@ -1347,12 +1350,7 @@ class Scheduler( gap_latency / self.server_args.decode_log_interval ) - msg = ( - f"Decode batch. " - f"#running-req: {num_running_reqs}, " - f"#token: {num_used}, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - ) + msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}" if self.spec_algorithm.is_none(): spec_accept_length = 0 @@ -1390,10 +1388,11 @@ class Scheduler( self._publish_kv_events() def check_memory(self): - available_size = ( - self.token_to_kv_pool_allocator.available_size() - + self.tree_cache.evictable_size() - ) + if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator): + available_token_size = self.token_to_kv_pool_allocator.full_available_size() + else: + available_token_size = self.token_to_kv_pool_allocator.available_size() + available_size = available_token_size + self.tree_cache.evictable_size() protected_size = self.tree_cache.protected_size() memory_leak = available_size != ( self.max_total_num_tokens @@ -1404,7 +1403,7 @@ class Scheduler( msg = ( "token_to_kv_pool_allocator memory leak detected! " f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n" - f"{self.token_to_kv_pool_allocator.available_size()=}\n" + f"{available_token_size=}\n" f"{self.tree_cache.evictable_size()=}\n" ) raise ValueError(msg) diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index 61f1c842b..6bcabf648 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -20,12 +20,14 @@ Page-aligned memory pool. """ import abc +import weakref from typing import TYPE_CHECKING import torch import triton import triton.language as tl +from sglang.srt.mem_cache.memory_pool import SWAKVPool from sglang.srt.utils import get_bool_env_var, next_power_of_2 if TYPE_CHECKING: @@ -55,6 +57,11 @@ class BaseTokenToKVPoolAllocator(abc.ABC): def debug_print(self) -> str: return "" + def log_usage(self, evictable_size: int = 0): + num_used = self.size - (self.available_size() + evictable_size) + msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, " + return msg, num_used + def available_size(self): return len(self.free_pages) * self.page_size @@ -146,6 +153,128 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): return self._kvcache.load_cpu_copy(kv_cache_cpu, indices) +class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): + """Allocator for SWA hybrid KV cache.""" + + def __init__( + self, + size: int, + size_swa: int, + dtype: torch.dtype, + device: str, + kvcache: SWAKVPool, + ): + super().__init__(size, 1, dtype, device, kvcache) + assert isinstance(kvcache, SWAKVPool) + self._size_full = size + self._size_swa = size_swa + self.full_attn_allocator = TokenToKVPoolAllocator( + size, + dtype, + device, + kvcache.full_kv_pool, + ) + self.swa_attn_allocator = TokenToKVPoolAllocator( + size_swa, + dtype, + device, + kvcache.swa_kv_pool, + ) + self.full_to_swa_index_mapping = torch.empty( + size + size_swa + 1, + dtype=torch.int64, + device=device, + ) + self.clear() + + self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping + + def available_size(self): + return min(self.full_available_size(), self.swa_available_size()) + + def full_available_size(self): + return self.full_attn_allocator.available_size() + + def swa_available_size(self): + return 
self.swa_attn_allocator.available_size() + + @property + def size_full(self): + return self._size_full + + @property + def size_swa(self): + return self._size_swa + + def debug_print(self) -> str: + msg = "" + msg += f"#swa-available-size: {self.swa_attn_allocator.available_size()}, " + msg += ( + f"#full-attn-available-size: {self.full_attn_allocator.available_size()}, " + ) + return msg + + def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0): + used_full = self.size_full - (self.full_available_size() + full_evictable_size) + used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size) + msg = ( + f"#token: full={used_full}, swa={used_swa}, " + f"token usage: full={used_full / self.size_full:.2f}, " + f"swa={used_swa / self.size_swa:.2f}, " + ) + return msg, used_full + + def get_kvcache(self): + return self._kvcache + + def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor): + assert self.full_to_swa_index_mapping is not None + return self.full_to_swa_index_mapping[kv_indices].to(torch.int32) + + def alloc(self, need_size: int): + if need_size > self.full_attn_allocator.available_size(): + return None + if need_size > self.swa_attn_allocator.available_size(): + return None + + alloc_full_indices = self.full_attn_allocator.alloc(need_size) + alloc_swa_indices = self.swa_attn_allocator.alloc(need_size) + self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices + return alloc_full_indices + + def free(self, free_index: torch.Tensor): + if free_index.numel() == 0: + return + if self.is_not_in_free_group: + self.full_attn_allocator.free(free_index) + self.free_swa(free_index) + else: + self.free_group.append(free_index) + assert ( + self.full_attn_allocator.available_size() <= self.full_attn_allocator.size + ) + assert self.swa_attn_allocator.available_size() <= self.swa_attn_allocator.size + + def free_swa(self, free_index: torch.Tensor): + swa_indices = self.full_to_swa_index_mapping[free_index] + swa_indices = swa_indices[swa_indices > 0] + self.swa_attn_allocator.free(swa_indices) + self.full_to_swa_index_mapping[free_index] = 0 + + def backup_state(self): + raise NotImplementedError + + def restore_state(self, state): + raise NotImplementedError + + def clear(self): + self.swa_attn_allocator.clear() + self.full_attn_allocator.clear() + self.full_to_swa_index_mapping.fill_(0) + self.is_in_free_group = False + self.free_group = [] + + @triton.jit def alloc_extend_kernel( pre_lens_ptr, diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 68a993b51..a1e58aa3a 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -2,11 +2,14 @@ from __future__ import annotations """Cache for chunked prefill, used when RadixCache is disabled.""" -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple import torch -from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator +from sglang.srt.mem_cache.allocator import ( + BaseTokenToKVPoolAllocator, + SWATokenToKVPoolAllocator, +) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.memory_pool import ReqToTokenPool @@ -63,3 +66,32 @@ class ChunkCache(BasePrefixCache): def pretty_print(self): return "" + + +class SWAChunkCache(ChunkCache): + """ChunkCache with support for hybrid KV cache operations.""" + + def __init__( + self, + req_to_token_pool: ReqToTokenPool, + 
token_to_kv_pool_allocator: SWATokenToKVPoolAllocator, + page_size: int, + ): + super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size) + assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) + + def evict( + self, + req: Req, + prelen: int, + attention_chunk_size: int, + ): + if prelen >= req.evicted_seqlen_local + attention_chunk_size: + new_evicted_seqlen_local = attention_chunk_size * ( + prelen // attention_chunk_size + ) + free_slots = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local + ] + self.token_to_kv_pool_allocator.free_swa(free_slots) + req.evicted_seqlen_local = new_evicted_seqlen_local diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index c7580c622..14eef2043 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -27,10 +27,11 @@ KVCache actually holds the physical kv cache. import abc import logging from contextlib import nullcontext -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch +import torch.distributed as dist import triton import triton.language as tl @@ -66,6 +67,7 @@ class ReqToTokenPool: self.req_to_token = torch.zeros( (size, max_context_len), dtype=torch.int32, device=device ) + self.free_slots = list(range(size)) def write(self, indices, values): @@ -191,7 +193,6 @@ class MHATokenToKVPool(KVCache): start_layer, end_layer, ) - self.head_num = head_num self.head_dim = head_dim @@ -392,10 +393,14 @@ class MHATokenToKVPool(KVCache): cache_v: torch.Tensor, k_scale: Optional[float] = None, v_scale: Optional[float] = None, + layer_id_override: Optional[int] = None, ): from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode - layer_id = layer.layer_id + if layer_id_override is not None: + layer_id = layer_id_override + else: + layer_id = layer.layer_id if cache_k.dtype != self.dtype: if k_scale is not None: cache_k.div_(k_scale) @@ -431,6 +436,136 @@ class MHATokenToKVPool(KVCache): ) +class SWAKVPool(KVCache): + """KV cache with separate pools for full and SWA attention layers.""" + + def __init__( + self, + size: int, + size_swa: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + swa_attention_layer_ids: List[int], + full_attention_layer_ids: List[int], + enable_kvcache_transpose: bool, + device: str, + ): + self.size = size + self.size_swa = size_swa + self.dtype = dtype + self.device = device + self.swa_layer_nums = len(swa_attention_layer_ids) + self.full_layer_nums = len(full_attention_layer_ids) + self.page_size = 1 + # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True + assert not enable_kvcache_transpose + TokenToKVPoolClass = MHATokenToKVPool + self.swa_kv_pool = TokenToKVPoolClass( + size=size_swa, + page_size=self.page_size, + dtype=dtype, + head_num=head_num, + head_dim=head_dim, + layer_num=self.swa_layer_nums, + device=device, + enable_memory_saver=False, + ) + self.full_kv_pool = TokenToKVPoolClass( + size=size, + page_size=self.page_size, + dtype=dtype, + head_num=head_num, + head_dim=head_dim, + layer_num=self.full_layer_nums, + device=device, + enable_memory_saver=False, + ) + self.layers_mapping: Dict[int, Tuple[int, bool]] = {} + for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids): + self.layers_mapping[global_layer_id] = (full_attn_layer_id, False) + for swa_layer_id, global_layer_id in 
enumerate(swa_attention_layer_ids): + self.layers_mapping[global_layer_id] = (swa_layer_id, True) + self.full_to_swa_index_mapping: Optional[torch.Tensor] = None + + def get_kv_size_bytes(self): + raise NotImplementedError + + def get_contiguous_buf_infos(self): + full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = ( + self.full_kv_pool.get_contiguous_buf_infos() + ) + swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = ( + self.swa_kv_pool.get_contiguous_buf_infos() + ) + + kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs + kv_data_lens = full_kv_data_lens + swa_kv_data_lens + kv_item_lens = full_kv_item_lens + swa_kv_item_lens + + return kv_data_ptrs, kv_data_lens, kv_item_lens + + def get_key_buffer(self, layer_id: int): + layer_id_pool, is_swa = self.layers_mapping[layer_id] + if is_swa: + return self.swa_kv_pool.get_key_buffer(layer_id_pool) + else: + return self.full_kv_pool.get_key_buffer(layer_id_pool) + + def get_value_buffer(self, layer_id: int): + layer_id_pool, is_swa = self.layers_mapping[layer_id] + if is_swa: + return self.swa_kv_pool.get_value_buffer(layer_id_pool) + else: + return self.full_kv_pool.get_value_buffer(layer_id_pool) + + def get_kv_buffer(self, layer_id: int): + layer_id_pool, is_swa = self.layers_mapping[layer_id] + if is_swa: + return self.swa_kv_pool.get_kv_buffer(layer_id_pool) + else: + return self.full_kv_pool.get_kv_buffer(layer_id_pool) + + def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor): + assert self.full_to_swa_index_mapping is not None + return self.full_to_swa_index_mapping[kv_indices].to(torch.int32) + + def set_kv_buffer( + self, + layer: RadixAttention, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + k_scale: float = 1.0, + v_scale: float = 1.0, + ): + + layer_id = layer.layer_id + layer_id_pool, is_swa = self.layers_mapping[layer_id] + if is_swa: + if self.full_to_swa_index_mapping is not None: + loc = self.translate_loc_from_full_to_swa(loc) + self.swa_kv_pool.set_kv_buffer( + None, + loc, + cache_k, + cache_v, + k_scale, + v_scale, + layer_id_override=layer_id_pool, + ) + else: + self.full_kv_pool.set_kv_buffer( + None, + loc, + cache_k, + cache_v, + k_scale, + v_scale, + layer_id_override=layer_id_pool, + ) + + @triton.jit def set_mla_kv_buffer_kernel( kv_buffer_ptr, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0c9185bb9..9ac26810e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -74,6 +74,7 @@ from sglang.srt.managers.schedule_batch import ( from sglang.srt.mem_cache.allocator import ( BaseTokenToKVPoolAllocator, PagedTokenToKVPoolAllocator, + SWATokenToKVPoolAllocator, TokenToKVPoolAllocator, ) from sglang.srt.mem_cache.memory_pool import ( @@ -81,6 +82,7 @@ from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, MLATokenToKVPool, ReqToTokenPool, + SWAKVPool, ) from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater @@ -185,6 +187,7 @@ class ModelRunner: self.page_size = server_args.page_size self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + self.is_hybrid = model_config.is_hybrid self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA self.attention_chunk_size = model_config.attention_chunk_size @@ -437,6 +440,10 @@ class ModelRunner: if 
self.model_config.context_len > 8192: self.mem_fraction_static *= 0.85 + if self.is_hybrid and not server_args.disable_radix_cache: + logger.info("Automatically disable radix cache for hybrid cache.") + server_args.disable_radix_cache = True + def init_torch_distributed(self): logger.info("Init torch distributed begin.") @@ -852,6 +859,40 @@ class ModelRunner: max_num_token = int(rest_memory * (1 << 30) // cell_size) return max_num_token + def set_num_token_hybrid(self): + if ( + "Llama4ForConditionalGeneration" + in self.model_config.hf_config.architectures + ): + temp_ratio = ( + (1 - self.is_hybrid) + + self.is_hybrid + * self.attention_chunk_size + / self.model_config.context_len + ) + self.swa_max_total_num_tokens = ( + 4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1) + ) + self.full_max_total_num_tokens = ( + 4 * self.max_total_num_tokens + - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1) + ) + self.swa_max_total_num_tokens = int( + self.swa_max_total_num_tokens + // self.server_args.page_size + * self.server_args.page_size + ) + self.full_max_total_num_tokens = int( + self.full_max_total_num_tokens + // self.server_args.page_size + * self.server_args.page_size + ) + self.max_total_num_tokens = self.full_max_total_num_tokens + else: + raise ValueError( + f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}." + ) + def init_memory_pool( self, total_gpu_memory: int, @@ -929,6 +970,10 @@ class ModelRunner: * self.server_args.page_size ) + # create token size for hybrid cache + if self.is_hybrid: + self.set_num_token_hybrid() + if self.max_total_num_tokens <= 0: raise RuntimeError( "Not enough memory. Please try to increase --mem-fraction-static." @@ -991,27 +1036,53 @@ class ModelRunner: end_layer=self.end_layer, ) else: - self.token_to_kv_pool = MHATokenToKVPool( - self.max_total_num_tokens, - page_size=self.page_size, - dtype=self.kv_cache_dtype, - head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), - head_dim=self.model_config.head_dim, - layer_num=self.num_effective_layers, - device=self.device, - enable_memory_saver=self.server_args.enable_memory_saver, - start_layer=self.start_layer, - end_layer=self.end_layer, - ) + if self.is_hybrid: + self.token_to_kv_pool = SWAKVPool( + size=self.full_max_total_num_tokens, + size_swa=self.swa_max_total_num_tokens, + dtype=self.kv_cache_dtype, + head_num=self.model_config.get_num_kv_heads( + get_attention_tp_size() + ), + head_dim=self.model_config.head_dim, + swa_attention_layer_ids=self.model_config.swa_attention_layer_ids, + full_attention_layer_ids=self.model_config.full_attention_layer_ids, + enable_kvcache_transpose=False, + device=self.device, + ) + else: + self.token_to_kv_pool = MHATokenToKVPool( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + head_num=self.model_config.get_num_kv_heads( + get_attention_tp_size() + ), + head_dim=self.model_config.head_dim, + layer_num=self.num_effective_layers, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + start_layer=self.start_layer, + end_layer=self.end_layer, + ) if self.token_to_kv_pool_allocator is None: if self.page_size == 1: - self.token_to_kv_pool_allocator = TokenToKVPoolAllocator( - self.max_total_num_tokens, - dtype=self.kv_cache_dtype, - device=self.device, - kvcache=self.token_to_kv_pool, - ) + if self.is_hybrid: + self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator( + self.full_max_total_num_tokens, + 
self.swa_max_total_num_tokens, + dtype=self.kv_cache_dtype, + device=self.device, + kvcache=self.token_to_kv_pool, + ) + else: + self.token_to_kv_pool_allocator = TokenToKVPoolAllocator( + self.max_total_num_tokens, + dtype=self.kv_cache_dtype, + device=self.device, + kvcache=self.token_to_kv_pool, + ) else: self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator( self.max_total_num_tokens, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 978044daa..835e6e888 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -61,6 +61,7 @@ class ServerArgs: is_embedding: bool = False enable_multimodal: Optional[bool] = None revision: Optional[str] = None + hybrid_kvcache_ratio: Optional[float] = None impl: str = "auto" # Port for the HTTP server @@ -817,6 +818,18 @@ class ServerArgs: default=ServerArgs.page_size, help="The number of tokens in a page.", ) + parser.add_argument( + "--hybrid-kvcache-ratio", + nargs="?", + const=0.5, + type=float, + default=ServerArgs.hybrid_kvcache_ratio, + help=( + "Mix ratio in [0,1] between uniform and hybrid kv buffers " + "(0.0 = pure uniform: swa_size / full_size = 1)" + "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)" + ), + ) # Other runtime options parser.add_argument(
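
---

The sketches below are standalone illustrations of the main pieces of this change; the helper names used in them (`split_layers`, `build_layers_mapping`, `next_evicted_len`, `split_token_budget`) are hypothetical and are not part of the PR.

First, the behavior of the new `--hybrid-kvcache-ratio` server flag added in `server_args.py`: it is off by default, takes `0.5` when passed without a value, and otherwise accepts a float in `[0, 1]` that blends the uniform and hybrid KV layouts. A minimal argparse reproduction of that behavior:

```python
import argparse

# Mirrors the parser.add_argument call in server_args.py: omitting the flag keeps
# the hybrid KV cache off, the bare flag selects 0.5, and an explicit float in
# [0, 1] blends the uniform and hybrid layouts.
parser = argparse.ArgumentParser()
parser.add_argument("--hybrid-kvcache-ratio", nargs="?", const=0.5, type=float, default=None)

print(parser.parse_args([]))                                 # Namespace(hybrid_kvcache_ratio=None)
print(parser.parse_args(["--hybrid-kvcache-ratio"]))         # Namespace(hybrid_kvcache_ratio=0.5)
print(parser.parse_args(["--hybrid-kvcache-ratio", "1.0"]))  # Namespace(hybrid_kvcache_ratio=1.0)
```

In practice the flag is simply appended to the existing launch command from `docs/references/llama4.md` (model path and parallelism flags unchanged), e.g. `python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --hybrid-kvcache-ratio 0.5 ...`, which per the updated docs allows `--context-length` up to roughly 5M on 8xH100 and 10M on 8xH200 for the Scout model.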
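The hybrid cache only activates for `Llama4ForConditionalGeneration`, where `get_hybrid_layer_ids` assigns every fourth layer (1-indexed) to the full-attention pool and the remaining layers to the sliding-window (SWA) pool. A minimal sketch of that split, with `split_layers` as a hypothetical stand-in:

```python
# Standalone sketch of the SWA / full-attention layer split used by
# get_hybrid_layer_ids: every 4th layer (1-indexed) keeps full attention.
def split_layers(num_hidden_layers: int):
    swa_layers = [i for i in range(num_hidden_layers) if (i + 1) % 4 != 0]
    full_layers = [i for i in range(num_hidden_layers) if (i + 1) % 4 == 0]
    return swa_layers, full_layers

if __name__ == "__main__":
    swa, full = split_layers(8)
    print(swa)   # [0, 1, 2, 4, 5, 6]
    print(full)  # [3, 7]
```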
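`SWAKVPool` then wraps two `MHATokenToKVPool` instances and routes each global layer id to a `(layer id within its sub-pool, is_swa)` pair via `layers_mapping`, which is what `get_key_buffer` / `set_kv_buffer` use to dispatch. A sketch of that routing table, reusing the 8-layer example above (`build_layers_mapping` is a hypothetical helper):

```python
# Sketch of the layer routing table built in SWAKVPool.__init__:
# global layer id -> (layer id inside its pool, is_swa).
def build_layers_mapping(swa_layer_ids, full_layer_ids):
    mapping = {}
    for full_local, global_id in enumerate(full_layer_ids):
        mapping[global_id] = (full_local, False)
    for swa_local, global_id in enumerate(swa_layer_ids):
        mapping[global_id] = (swa_local, True)
    return mapping

print(build_layers_mapping([0, 1, 2, 4, 5, 6], [3, 7]))
# {3: (0, False), 7: (1, False), 0: (0, True), 1: (1, True),
#  2: (2, True), 4: (3, True), 5: (4, True), 6: (5, True)}
```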
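`SWATokenToKVPoolAllocator` gives every token a slot in both pools and records the pairing in `full_to_swa_index_mapping`; a zero entry marks a token whose SWA copy has been freed (or never allocated). The FlashAttention backend uses this tensor to translate page tables holding full-pool indices into SWA-pool indices for the local-attention path. A small sketch of that translation (pool sizes and indices are illustrative):

```python
import torch

# Sketch of the full -> SWA slot translation used by SWATokenToKVPoolAllocator
# and the FlashAttention local-attention metadata. Entry 0 means "no SWA slot".
size_full, size_swa = 16, 8
full_to_swa = torch.zeros(size_full + size_swa + 1, dtype=torch.int64)

# Pretend three tokens were allocated in both pools (paired allocation).
full_idx = torch.tensor([3, 4, 5])
swa_idx = torch.tensor([1, 2, 3])
full_to_swa[full_idx] = swa_idx

# Translate a page table of full-pool indices into SWA-pool indices.
page_table = torch.tensor([[3, 4, 5, 0]])
swa_page_table = full_to_swa[page_table].to(torch.int32)
print(swa_page_table)  # tensor([[1, 2, 3, 0]], dtype=torch.int32)
```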
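`SWAChunkCache.evict` releases SWA slots that have fallen out of the local attention window: once a request's prefix length reaches the next multiple of `attention_chunk_size` beyond what has already been evicted, the slots up to that multiple are freed with `free_swa` while the full-attention copies stay resident. A sketch of the arithmetic (`next_evicted_len` is a hypothetical helper, and the 8192 chunk size is only for illustration):

```python
# Sketch of the eviction rule in SWAChunkCache.evict.
def next_evicted_len(prelen: int, evicted_seqlen_local: int, attention_chunk_size: int) -> int:
    if prelen >= evicted_seqlen_local + attention_chunk_size:
        # Evict everything up to the last full chunk boundary.
        return attention_chunk_size * (prelen // attention_chunk_size)
    return evicted_seqlen_local

print(next_evicted_len(20000, 8192, 8192))  # 16384 -> SWA slots for tokens [8192, 16384) are freed
```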
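Finally, `set_num_token_hybrid` splits the uniform token budget into separate full and SWA budgets. Reading the arithmetic: with three of every four Llama 4 layers on SWA, the formulas are consistent with keeping the KV memory of a 4-layer group unchanged (`full + 3 * swa == 4 * max_total_num_tokens`) while the blended `temp_ratio` fixes `swa / full`; both budgets are then rounded down to whole pages. A sketch with hypothetical numbers (`split_token_budget` is not part of the PR):

```python
# Sketch of the token-budget split in set_num_token_hybrid.
def split_token_budget(max_total_num_tokens, hybrid_kvcache_ratio,
                       attention_chunk_size, context_len, page_size=1):
    # Blend between uniform (ratio=0 -> r=1, swa == full) and fully hybrid
    # (ratio=1 -> r = attention_chunk_size / context_len).
    r = (1 - hybrid_kvcache_ratio) + hybrid_kvcache_ratio * attention_chunk_size / context_len
    swa = 4 * max_total_num_tokens * r // (3 * r + 1)
    full = 4 * max_total_num_tokens - 12 * max_total_num_tokens * r // (3 * r + 1)
    # Round both budgets down to a whole number of pages.
    swa = int(swa // page_size * page_size)
    full = int(full // page_size * page_size)
    return full, swa

# Hypothetical numbers: 1M-token uniform budget, 8K local window, 1M context, ratio 1.0.
full, swa = split_token_budget(1_000_000, 1.0, 8192, 1_000_000)
print(full, swa)  # roughly full ~ 3.9M, swa ~ 32K
```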