Hybrid kv cache for LLaMA4 (#6563)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: tarinkk <rt572@physics.rutger.edu>
Co-authored-by: tarinkk <rt572@rutgers.physics.edu>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
@@ -16,10 +16,11 @@ python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-In
### Configuration Tips

- **OOM Mitigation**: Adjust `--context-length` to avoid a GPU out-of-memory issue. For the Scout model, we recommend setting this value up to 1M on 8\*H100 and up to 2.5M on 8\*H200. For the Maverick model, we don't need to set context length on 8\*H200.
- **OOM Mitigation**: Adjust `--context-length` to avoid a GPU out-of-memory issue. For the Scout model, we recommend setting this value up to 1M on 8\*H100 and up to 2.5M on 8\*H200. For the Maverick model, we don't need to set context length on 8\*H200. When hybrid kv cache is enabled, `--context-length` can be set up to 5M on 8\*H100 and up to 10M on 8\*H200 for the Scout model.

- **Chat Template**: Add `--chat-template llama-4` for chat completion tasks.
- **Enable Multi-Modal**: Add `--enable-multimodal` for multi-modal capabilities.
- **Enable Hybrid-KVCache**: Add `--hybrid-kvcache-ratio` for hybrid kv cache. Details can be seen in [this PR](https://github.com/sgl-project/sglang/pull/6563)
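For example (an illustrative command, not part of this diff; the tensor-parallel size and context length are assumed values to be tuned per deployment), the hybrid KV cache can be combined with the flags above as shown below. Passing `--hybrid-kvcache-ratio` with no value defaults it to 0.5, per the server argument definition later in this commit.

python3 -m sglang.launch_server \
  --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --tp 8 \
  --context-length 5000000 \
  --hybrid-kvcache-ratio 0.5 \
  --chat-template llama-4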
## Benchmarking Results
@@ -59,6 +59,7 @@ class ModelConfig:
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
@@ -86,6 +87,18 @@ class ModelConfig:
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
)
self.is_hybrid = is_hybrid_model(
self.hf_config.architectures,
hybrid_kvcache_ratio=hybrid_kvcache_ratio,
context_length=context_length,
attention_chunk_size=self.attention_chunk_size,
)
if self.is_hybrid is not None:
self.swa_attention_layer_ids, self.full_attention_layer_ids = (
get_hybrid_layer_ids(
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
)
)

if enable_multimodal is None:
mm_disabled_models = [
@@ -264,6 +277,7 @@ class ModelConfig:
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
impl=server_args.impl,
**kwargs,
)
@@ -633,3 +647,36 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0


def is_hybrid_model(
model_architectures: List[str],
hybrid_kvcache_ratio: Optional[float],
context_length: Optional[int],
attention_chunk_size: Optional[int],
):
if hybrid_kvcache_ratio is None:
return None
elif (
hybrid_kvcache_ratio > 0
and model_architectures[0] == "Llama4ForConditionalGeneration"
and context_length > attention_chunk_size
):
return hybrid_kvcache_ratio
else:
return None


def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
if "Llama4ForConditionalGeneration" in model_architectures:
swa_attention_layer_ids = [
i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
]
full_attention_layer_ids = [
i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
]
else:
raise ValueError(
"get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
)
return swa_attention_layer_ids, full_attention_layer_ids
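As a concrete illustration (the 48-layer count is an assumed example, not taken from this commit), `get_hybrid_layer_ids` marks every fourth layer as a full-attention layer and the rest as sliding-window (SWA) layers:

swa_ids, full_ids = get_hybrid_layer_ids(
    ["Llama4ForConditionalGeneration"], num_hidden_layers=48
)
# swa_ids  -> [0, 1, 2, 4, 5, 6, 8, ...]  (layers where (i + 1) % 4 != 0)
# full_ids -> [3, 7, 11, ..., 47]         (every fourth layer, where (i + 1) % 4 == 0)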
@@ -433,9 +433,7 @@ class DecodePreallocQueue:
else 0
)

available_size = self.token_to_kv_pool_allocator.available_size()

allocatable_tokens = available_size - max(
allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
# preserve some space for future decode
self.num_reserved_decode_tokens
* (
@@ -9,6 +9,7 @@ import torch
|
||||
from sglang.srt.configs.model_config import AttentionArch
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
@@ -320,6 +321,11 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.page_size = model_runner.page_size
|
||||
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||
self.skip_prefill = skip_prefill
|
||||
self.is_hybrid = model_runner.is_hybrid
|
||||
if self.is_hybrid:
|
||||
self.full_to_swa_index_mapping = (
|
||||
model_runner.token_to_kv_pool.full_to_swa_index_mapping
|
||||
)
|
||||
self.topk = model_runner.server_args.speculative_eagle_topk or 0
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
self.speculative_num_draft_tokens = (
|
||||
@@ -428,7 +434,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
# TODO: we need to test this part for llama 4 eagle case
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
self._init_local_attn_metadata(forward_batch, metadata, device)
|
||||
elif forward_batch.forward_mode.is_target_verify():
|
||||
if self.topk <= 1:
|
||||
metadata.cache_seqlens_int32 = (
|
||||
@@ -456,7 +462,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
self._init_local_attn_metadata(forward_batch, metadata, device)
|
||||
else:
|
||||
metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
|
||||
metadata.max_seq_len_q = self.speculative_num_draft_tokens
|
||||
@@ -575,7 +581,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
# Setup local attention if enabled
|
||||
if forward_batch.forward_mode == ForwardMode.EXTEND:
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
self._init_local_attn_metadata(forward_batch, metadata, device)
|
||||
|
||||
# Encoder metadata for cross attention
|
||||
if forward_batch.encoder_lens is not None:
|
||||
@@ -1588,7 +1594,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
out_cache_loc: torch.Tensor = None,
|
||||
out_cache_loc: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Initialize forward metadata for replaying CUDA graph."""
|
||||
seq_lens = seq_lens[:bs]
|
||||
@@ -1673,7 +1679,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
self._update_local_attn_metadata_for_replay(metadata, bs)
|
||||
self._update_local_attn_metadata_for_replay(
|
||||
metadata,
|
||||
bs,
|
||||
)
|
||||
elif forward_mode.is_target_verify():
|
||||
if self.topk <= 1:
|
||||
metadata = self.target_verify_metadata[bs]
|
||||
@@ -1829,7 +1838,9 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
"""Get the fill value for sequence length in CUDA graph."""
|
||||
return 1
|
||||
|
||||
def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
|
||||
def _init_local_attn_metadata(
|
||||
self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device
|
||||
):
|
||||
"""Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
|
||||
if self.attention_chunk_size is None:
|
||||
metadata.local_attn_metadata = None
|
||||
@@ -1837,7 +1848,12 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
cu_seqlens_q = metadata.cu_seqlens_q
|
||||
cache_seqlens_int32 = metadata.cache_seqlens_int32
|
||||
page_table = metadata.page_table
|
||||
if self.is_hybrid:
|
||||
page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
|
||||
torch.int32
|
||||
)
|
||||
else:
|
||||
page_table = metadata.page_table
|
||||
if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
|
||||
metadata.local_attn_metadata = None
|
||||
return
|
||||
@@ -1923,7 +1939,9 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def _update_local_attn_metadata_for_replay(
|
||||
self, metadata: FlashAttentionMetadata, bs: int
|
||||
self,
|
||||
metadata: FlashAttentionMetadata,
|
||||
bs: int,
|
||||
):
|
||||
"""Update preallocated local attention metadata in-place before CUDA graph replay."""
|
||||
if self.attention_chunk_size is None:
|
||||
@@ -1954,7 +1972,12 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
# Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
|
||||
# beyond the actual sequence length, leading to incorrect attention calculations
|
||||
max_seq_len = int(seqlens.max().item())
|
||||
sliced_page_table = metadata.page_table[:bs, :max_seq_len]
|
||||
if self.is_hybrid:
|
||||
sliced_page_table = self.full_to_swa_index_mapping[
|
||||
metadata.page_table[:bs, :max_seq_len]
|
||||
].to(torch.int32)
|
||||
else:
|
||||
sliced_page_table = metadata.page_table[:bs, :max_seq_len]
|
||||
|
||||
cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
|
||||
seqlens_np = seqlens.cpu().numpy()
|
||||
|
||||
@@ -56,7 +56,7 @@ from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.metrics.collector import TimeStats
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
@@ -485,6 +485,9 @@ class Req:
# for cross-encoder model
|
||||
self.token_type_ids = token_type_ids
|
||||
|
||||
# The length of KV that have been removed in local attention chunked prefill
|
||||
self.evicted_seqlen_local = 0
|
||||
|
||||
# Sampling info
|
||||
if isinstance(sampling_params.custom_params, dict):
|
||||
sampling_params = copy.copy(sampling_params)
|
||||
@@ -1191,6 +1194,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
|
||||
)
|
||||
if isinstance(self.tree_cache, SWAChunkCache):
|
||||
self.tree_cache.evict(
|
||||
req, pre_len, self.model_config.attention_chunk_size
|
||||
)
|
||||
|
||||
# If input_embeds are available, store them
|
||||
if req.input_embeds is not None:
|
||||
@@ -1383,7 +1390,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
* buf_multiplier
|
||||
* self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
|
||||
if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
|
||||
return True
|
||||
|
||||
@@ -1564,6 +1570,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.seq_lens.add_(1)
|
||||
self.seq_lens_sum += bs
|
||||
|
||||
# free memory
|
||||
if isinstance(self.tree_cache, SWAChunkCache):
|
||||
for req in self.reqs:
|
||||
self.tree_cache.evict(
|
||||
req, req.seqlen - 1, self.model_config.attention_chunk_size
|
||||
)
|
||||
|
||||
# Allocate memory
|
||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
@@ -1798,7 +1811,6 @@ class ModelWorkerBatch:
|
||||
seq_lens: torch.Tensor
|
||||
# The indices of output tokens in the token_to_kv_pool_allocator
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
# The sequence length tensor on CPU
|
||||
seq_lens_cpu: Optional[torch.Tensor]
|
||||
seq_lens_sum: int
|
||||
|
||||
@@ -126,7 +126,8 @@ from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
@@ -570,7 +571,11 @@ class Scheduler(
|
||||
server_args.chunked_prefill_size is not None
|
||||
and server_args.disable_radix_cache
|
||||
):
|
||||
self.tree_cache = ChunkCache(
|
||||
if self.model_config.is_hybrid:
|
||||
ChunkCacheClass = SWAChunkCache
|
||||
else:
|
||||
ChunkCacheClass = ChunkCache
|
||||
self.tree_cache = ChunkCacheClass(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
page_size=self.page_size,
|
||||
@@ -1283,9 +1288,8 @@ class Scheduler(
|
||||
self.last_input_throughput = self.last_prefill_tokens / gap_latency
|
||||
self.last_prefill_tokens = adder.log_input_tokens
|
||||
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
|
||||
self.tree_cache.evictable_size()
|
||||
)
|
||||
|
||||
num_new_seq = len(can_run_list)
|
||||
@@ -1294,7 +1298,7 @@ class Scheduler(
|
||||
f"#new-seq: {num_new_seq}, "
|
||||
f"#new-token: {adder.log_input_tokens}, "
|
||||
f"#cached-token: {adder.log_hit_tokens}, "
|
||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||
f"{usage_msg}"
|
||||
)
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
@@ -1337,9 +1341,8 @@ class Scheduler(
|
||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
||||
self.num_generated_tokens = 0
|
||||
num_running_reqs = len(batch.reqs)
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
|
||||
self.tree_cache.evictable_size()
|
||||
)
|
||||
|
||||
if RECORD_STEP_TIME:
|
||||
@@ -1347,12 +1350,7 @@ class Scheduler(
|
||||
gap_latency / self.server_args.decode_log_interval
|
||||
)
|
||||
|
||||
msg = (
|
||||
f"Decode batch. "
|
||||
f"#running-req: {num_running_reqs}, "
|
||||
f"#token: {num_used}, "
|
||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||
)
|
||||
msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
|
||||
|
||||
if self.spec_algorithm.is_none():
|
||||
spec_accept_length = 0
|
||||
@@ -1390,10 +1388,11 @@ class Scheduler(
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_memory(self):
|
||||
available_size = (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
|
||||
available_token_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
else:
|
||||
available_token_size = self.token_to_kv_pool_allocator.available_size()
|
||||
available_size = available_token_size + self.tree_cache.evictable_size()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = available_size != (
|
||||
self.max_total_num_tokens
|
||||
@@ -1404,7 +1403,7 @@ class Scheduler(
|
||||
msg = (
|
||||
"token_to_kv_pool_allocator memory leak detected! "
|
||||
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{available_token_size=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -20,12 +20,14 @@ Page-aligned memory pool.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
||||
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -55,6 +57,11 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
||||
def debug_print(self) -> str:
|
||||
return ""
|
||||
|
||||
def log_usage(self, evictable_size: int = 0):
|
||||
num_used = self.size - (self.available_size() + evictable_size)
|
||||
msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
|
||||
return msg, num_used
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_pages) * self.page_size
|
||||
|
||||
@@ -146,6 +153,128 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
||||
|
||||
|
||||
class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
"""Allocator for SWA hybrid KV cache."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
size_swa: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
kvcache: SWAKVPool,
|
||||
):
|
||||
super().__init__(size, 1, dtype, device, kvcache)
|
||||
assert isinstance(kvcache, SWAKVPool)
|
||||
self._size_full = size
|
||||
self._size_swa = size_swa
|
||||
self.full_attn_allocator = TokenToKVPoolAllocator(
|
||||
size,
|
||||
dtype,
|
||||
device,
|
||||
kvcache.full_kv_pool,
|
||||
)
|
||||
self.swa_attn_allocator = TokenToKVPoolAllocator(
|
||||
size_swa,
|
||||
dtype,
|
||||
device,
|
||||
kvcache.swa_kv_pool,
|
||||
)
|
||||
self.full_to_swa_index_mapping = torch.empty(
|
||||
size + size_swa + 1,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
self.clear()
|
||||
|
||||
self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping
|
||||
|
||||
def available_size(self):
|
||||
return min(self.full_available_size(), self.swa_available_size())
|
||||
|
||||
def full_available_size(self):
|
||||
return self.full_attn_allocator.available_size()
|
||||
|
||||
def swa_available_size(self):
|
||||
return self.swa_attn_allocator.available_size()
|
||||
|
||||
@property
|
||||
def size_full(self):
|
||||
return self._size_full
|
||||
|
||||
@property
|
||||
def size_swa(self):
|
||||
return self._size_swa
|
||||
|
||||
def debug_print(self) -> str:
|
||||
msg = ""
|
||||
msg += f"#swa-available-size: {self.swa_attn_allocator.available_size()}, "
|
||||
msg += (
|
||||
f"#full-attn-available-size: {self.full_attn_allocator.available_size()}, "
|
||||
)
|
||||
return msg
|
||||
|
||||
def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
|
||||
used_full = self.size_full - (self.full_available_size() + full_evictable_size)
|
||||
used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
|
||||
msg = (
|
||||
f"#token: full={used_full}, swa={used_swa}, "
|
||||
f"token usage: full={used_full / self.size_full:.2f}, "
|
||||
f"swa={used_swa / self.size_swa:.2f}, "
|
||||
)
|
||||
return msg, used_full
|
||||
|
||||
def get_kvcache(self):
|
||||
return self._kvcache
|
||||
|
||||
def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
|
||||
assert self.full_to_swa_index_mapping is not None
|
||||
return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
if need_size > self.full_attn_allocator.available_size():
|
||||
return None
|
||||
if need_size > self.swa_attn_allocator.available_size():
|
||||
return None
|
||||
|
||||
alloc_full_indices = self.full_attn_allocator.alloc(need_size)
|
||||
alloc_swa_indices = self.swa_attn_allocator.alloc(need_size)
|
||||
self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices
|
||||
return alloc_full_indices
|
||||
|
||||
def free(self, free_index: torch.Tensor):
|
||||
if free_index.numel() == 0:
|
||||
return
|
||||
if self.is_not_in_free_group:
|
||||
self.full_attn_allocator.free(free_index)
|
||||
self.free_swa(free_index)
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
assert (
|
||||
self.full_attn_allocator.available_size() <= self.full_attn_allocator.size
|
||||
)
|
||||
assert self.swa_attn_allocator.available_size() <= self.swa_attn_allocator.size
|
||||
|
||||
def free_swa(self, free_index: torch.Tensor):
|
||||
swa_indices = self.full_to_swa_index_mapping[free_index]
|
||||
swa_indices = swa_indices[swa_indices > 0]
|
||||
self.swa_attn_allocator.free(swa_indices)
|
||||
self.full_to_swa_index_mapping[free_index] = 0
|
||||
|
||||
def backup_state(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def restore_state(self, state):
|
||||
raise NotImplementedError
|
||||
|
||||
def clear(self):
|
||||
self.swa_attn_allocator.clear()
|
||||
self.full_attn_allocator.clear()
|
||||
self.full_to_swa_index_mapping.fill_(0)
|
||||
self.is_in_free_group = False
|
||||
self.free_group = []
|
||||
|
||||
|
||||
@triton.jit
|
||||
def alloc_extend_kernel(
|
||||
pre_lens_ptr,
|
||||
|
||||
@@ -2,11 +2,14 @@ from __future__ import annotations
|
||||
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.allocator import (
|
||||
BaseTokenToKVPoolAllocator,
|
||||
SWATokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
|
||||
@@ -63,3 +66,32 @@ class ChunkCache(BasePrefixCache):
|
||||
|
||||
def pretty_print(self):
|
||||
return ""
|
||||
|
||||
|
||||
class SWAChunkCache(ChunkCache):
|
||||
"""ChunkCache with support for hybrid KV cache operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: SWATokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
):
|
||||
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
|
||||
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
|
||||
|
||||
def evict(
|
||||
self,
|
||||
req: Req,
|
||||
prelen: int,
|
||||
attention_chunk_size: int,
|
||||
):
|
||||
if prelen >= req.evicted_seqlen_local + attention_chunk_size:
|
||||
new_evicted_seqlen_local = attention_chunk_size * (
|
||||
prelen // attention_chunk_size
|
||||
)
|
||||
free_slots = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free_swa(free_slots)
|
||||
req.evicted_seqlen_local = new_evicted_seqlen_local
|
||||
|
||||
@@ -27,10 +27,11 @@ KVCache actually holds the physical kv cache.
|
||||
import abc
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@@ -66,6 +67,7 @@ class ReqToTokenPool:
|
||||
self.req_to_token = torch.zeros(
|
||||
(size, max_context_len), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
self.free_slots = list(range(size))
|
||||
|
||||
def write(self, indices, values):
|
||||
@@ -191,7 +193,6 @@ class MHATokenToKVPool(KVCache):
|
||||
start_layer,
|
||||
end_layer,
|
||||
)
|
||||
|
||||
self.head_num = head_num
|
||||
self.head_dim = head_dim
|
||||
|
||||
@@ -392,10 +393,14 @@ class MHATokenToKVPool(KVCache):
|
||||
cache_v: torch.Tensor,
|
||||
k_scale: Optional[float] = None,
|
||||
v_scale: Optional[float] = None,
|
||||
layer_id_override: Optional[int] = None,
|
||||
):
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
layer_id = layer.layer_id
|
||||
if layer_id_override is not None:
|
||||
layer_id = layer_id_override
|
||||
else:
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
if k_scale is not None:
|
||||
cache_k.div_(k_scale)
|
||||
@@ -431,6 +436,136 @@ class MHATokenToKVPool(KVCache):
|
||||
)
|
||||
|
||||
|
||||
class SWAKVPool(KVCache):
|
||||
"""KV cache with separate pools for full and SWA attention layers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
size_swa: int,
|
||||
dtype: torch.dtype,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
swa_attention_layer_ids: List[int],
|
||||
full_attention_layer_ids: List[int],
|
||||
enable_kvcache_transpose: bool,
|
||||
device: str,
|
||||
):
|
||||
self.size = size
|
||||
self.size_swa = size_swa
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.swa_layer_nums = len(swa_attention_layer_ids)
|
||||
self.full_layer_nums = len(full_attention_layer_ids)
|
||||
self.page_size = 1
|
||||
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
||||
assert not enable_kvcache_transpose
|
||||
TokenToKVPoolClass = MHATokenToKVPool
|
||||
self.swa_kv_pool = TokenToKVPoolClass(
|
||||
size=size_swa,
|
||||
page_size=self.page_size,
|
||||
dtype=dtype,
|
||||
head_num=head_num,
|
||||
head_dim=head_dim,
|
||||
layer_num=self.swa_layer_nums,
|
||||
device=device,
|
||||
enable_memory_saver=False,
|
||||
)
|
||||
self.full_kv_pool = TokenToKVPoolClass(
|
||||
size=size,
|
||||
page_size=self.page_size,
|
||||
dtype=dtype,
|
||||
head_num=head_num,
|
||||
head_dim=head_dim,
|
||||
layer_num=self.full_layer_nums,
|
||||
device=device,
|
||||
enable_memory_saver=False,
|
||||
)
|
||||
self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
|
||||
for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
|
||||
self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
|
||||
for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
|
||||
self.layers_mapping[global_layer_id] = (swa_layer_id, True)
|
||||
self.full_to_swa_index_mapping: Optional[torch.Tensor] = None
|
||||
|
||||
def get_kv_size_bytes(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_contiguous_buf_infos(self):
|
||||
full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
|
||||
self.full_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
|
||||
self.swa_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
|
||||
kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
|
||||
kv_data_lens = full_kv_data_lens + swa_kv_data_lens
|
||||
kv_item_lens = full_kv_item_lens + swa_kv_item_lens
|
||||
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
layer_id_pool, is_swa = self.layers_mapping[layer_id]
|
||||
if is_swa:
|
||||
return self.swa_kv_pool.get_key_buffer(layer_id_pool)
|
||||
else:
|
||||
return self.full_kv_pool.get_key_buffer(layer_id_pool)
|
||||
|
||||
def get_value_buffer(self, layer_id: int):
|
||||
layer_id_pool, is_swa = self.layers_mapping[layer_id]
|
||||
if is_swa:
|
||||
return self.swa_kv_pool.get_value_buffer(layer_id_pool)
|
||||
else:
|
||||
return self.full_kv_pool.get_value_buffer(layer_id_pool)
|
||||
|
||||
def get_kv_buffer(self, layer_id: int):
|
||||
layer_id_pool, is_swa = self.layers_mapping[layer_id]
|
||||
if is_swa:
|
||||
return self.swa_kv_pool.get_kv_buffer(layer_id_pool)
|
||||
else:
|
||||
return self.full_kv_pool.get_kv_buffer(layer_id_pool)
|
||||
|
||||
def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
|
||||
assert self.full_to_swa_index_mapping is not None
|
||||
return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer: RadixAttention,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
):
|
||||
|
||||
layer_id = layer.layer_id
|
||||
layer_id_pool, is_swa = self.layers_mapping[layer_id]
|
||||
if is_swa:
|
||||
if self.full_to_swa_index_mapping is not None:
|
||||
loc = self.translate_loc_from_full_to_swa(loc)
|
||||
self.swa_kv_pool.set_kv_buffer(
|
||||
None,
|
||||
loc,
|
||||
cache_k,
|
||||
cache_v,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer_id_override=layer_id_pool,
|
||||
)
|
||||
else:
|
||||
self.full_kv_pool.set_kv_buffer(
|
||||
None,
|
||||
loc,
|
||||
cache_k,
|
||||
cache_v,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer_id_override=layer_id_pool,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def set_mla_kv_buffer_kernel(
|
||||
kv_buffer_ptr,
|
||||
|
||||
@@ -74,6 +74,7 @@ from sglang.srt.managers.schedule_batch import (
|
||||
from sglang.srt.mem_cache.allocator import (
|
||||
BaseTokenToKVPoolAllocator,
|
||||
PagedTokenToKVPoolAllocator,
|
||||
SWATokenToKVPoolAllocator,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
@@ -81,6 +82,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
SWAKVPool,
|
||||
)
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
|
||||
@@ -185,6 +187,7 @@ class ModelRunner:
|
||||
self.page_size = server_args.page_size
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.is_hybrid = model_config.is_hybrid
|
||||
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
|
||||
self.attention_chunk_size = model_config.attention_chunk_size
|
||||
|
||||
@@ -437,6 +440,10 @@ class ModelRunner:
|
||||
if self.model_config.context_len > 8192:
|
||||
self.mem_fraction_static *= 0.85
|
||||
|
||||
if self.is_hybrid and not server_args.disable_radix_cache:
|
||||
logger.info("Automatically disable radix cache for hybrid cache.")
|
||||
server_args.disable_radix_cache = True
|
||||
|
||||
def init_torch_distributed(self):
|
||||
logger.info("Init torch distributed begin.")
|
||||
|
||||
@@ -852,6 +859,40 @@ class ModelRunner:
|
||||
max_num_token = int(rest_memory * (1 << 30) // cell_size)
|
||||
return max_num_token
|
||||
|
||||
def set_num_token_hybrid(self):
|
||||
if (
|
||||
"Llama4ForConditionalGeneration"
|
||||
in self.model_config.hf_config.architectures
|
||||
):
|
||||
temp_ratio = (
|
||||
(1 - self.is_hybrid)
|
||||
+ self.is_hybrid
|
||||
* self.attention_chunk_size
|
||||
/ self.model_config.context_len
|
||||
)
|
||||
self.swa_max_total_num_tokens = (
|
||||
4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
|
||||
)
|
||||
self.full_max_total_num_tokens = (
|
||||
4 * self.max_total_num_tokens
|
||||
- 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
|
||||
)
|
||||
self.swa_max_total_num_tokens = int(
|
||||
self.swa_max_total_num_tokens
|
||||
// self.server_args.page_size
|
||||
* self.server_args.page_size
|
||||
)
|
||||
self.full_max_total_num_tokens = int(
|
||||
self.full_max_total_num_tokens
|
||||
// self.server_args.page_size
|
||||
* self.server_args.page_size
|
||||
)
|
||||
self.max_total_num_tokens = self.full_max_total_num_tokens
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}."
|
||||
)
|
||||
|
||||
def init_memory_pool(
|
||||
self,
|
||||
total_gpu_memory: int,
|
||||
@@ -929,6 +970,10 @@ class ModelRunner:
|
||||
* self.server_args.page_size
|
||||
)
|
||||
|
||||
# create token size for hybrid cache
|
||||
if self.is_hybrid:
|
||||
self.set_num_token_hybrid()
|
||||
|
||||
if self.max_total_num_tokens <= 0:
|
||||
raise RuntimeError(
|
||||
"Not enough memory. Please try to increase --mem-fraction-static."
|
||||
@@ -991,27 +1036,53 @@ class ModelRunner:
|
||||
end_layer=self.end_layer,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.num_effective_layers,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
start_layer=self.start_layer,
|
||||
end_layer=self.end_layer,
|
||||
)
|
||||
if self.is_hybrid:
|
||||
self.token_to_kv_pool = SWAKVPool(
|
||||
size=self.full_max_total_num_tokens,
|
||||
size_swa=self.swa_max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(
|
||||
get_attention_tp_size()
|
||||
),
|
||||
head_dim=self.model_config.head_dim,
|
||||
swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
|
||||
full_attention_layer_ids=self.model_config.full_attention_layer_ids,
|
||||
enable_kvcache_transpose=False,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(
|
||||
get_attention_tp_size()
|
||||
),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.num_effective_layers,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
start_layer=self.start_layer,
|
||||
end_layer=self.end_layer,
|
||||
)
|
||||
|
||||
if self.token_to_kv_pool_allocator is None:
|
||||
if self.page_size == 1:
|
||||
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
)
|
||||
if self.is_hybrid:
|
||||
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
|
||||
self.full_max_total_num_tokens,
|
||||
self.swa_max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
|
||||
@@ -61,6 +61,7 @@ class ServerArgs:
is_embedding: bool = False
enable_multimodal: Optional[bool] = None
revision: Optional[str] = None
hybrid_kvcache_ratio: Optional[float] = None
impl: str = "auto"

# Port for the HTTP server
@@ -817,6 +818,18 @@ class ServerArgs:
default=ServerArgs.page_size,
help="The number of tokens in a page.",
)
parser.add_argument(
"--hybrid-kvcache-ratio",
nargs="?",
const=0.5,
type=float,
default=ServerArgs.hybrid_kvcache_ratio,
help=(
"Mix ratio in [0,1] between uniform and hybrid kv buffers "
"(0.0 = pure uniform: swa_size / full_size = 1)"
"(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
),
)
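The help strings above can be made concrete with a small sketch (illustrative only; the helper name is hypothetical and not part of this diff) of how the ratio interpolates the SWA pool size relative to the full-attention pool, mirroring `set_num_token_hybrid` in model_runner.py earlier in this commit:

def split_hybrid_kv_budget(max_total_num_tokens, ratio, attention_chunk_size, context_len):
    # ratio = 0.0 -> uniform: swa_size == full_size
    # ratio = 1.0 -> fully hybrid: swa_size / full_size == attention_chunk_size / context_len
    t = (1 - ratio) + ratio * attention_chunk_size / context_len
    # Llama4 interleaves 3 SWA layers per full-attention layer, so the uniform budget of
    # 4 * max_total_num_tokens is re-split as 3 * swa_size + 1 * full_size.
    swa_size = int(4 * max_total_num_tokens * t / (3 * t + 1))
    full_size = int(4 * max_total_num_tokens / (3 * t + 1))
    return swa_size, full_size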
# Other runtime options
parser.add_argument(