SWA Prefix Cache (#7367)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -711,7 +711,6 @@ def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int)
|
||||
i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
"get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
|
||||
)
|
||||
swa_attention_layer_ids = None
|
||||
full_attention_layer_ids = None
|
||||
return swa_attention_layer_ids, full_attention_layer_ids
|
||||
|
||||
@@ -439,7 +439,15 @@ class DecodePreallocQueue:
|
||||
else 0
|
||||
)
|
||||
|
||||
allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
|
||||
if self.scheduler.model_config.is_hybrid:
|
||||
available_size = min(
|
||||
self.token_to_kv_pool_allocator.full_available_size(),
|
||||
self.token_to_kv_pool_allocator.swa_available_size(),
|
||||
)
|
||||
else:
|
||||
available_size = self.token_to_kv_pool_allocator.available_size()
|
||||
|
||||
allocatable_tokens = available_size - max(
|
||||
# preserve some space for future decode
|
||||
self.num_reserved_decode_tokens
|
||||
* (
|
||||
|
||||
@@ -26,6 +26,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
|
||||
@@ -589,6 +590,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
self.kv_indptr = attn_backend.kv_indptr
|
||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
|
||||
|
||||
# Dispatch the update function
|
||||
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||
@@ -655,6 +657,10 @@ class FlashInferIndicesUpdaterDecode:
|
||||
paged_kernel_lens_sum_tmp = seq_lens_sum
|
||||
kv_start_idx_tmp = None
|
||||
|
||||
use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
|
||||
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
|
||||
)
|
||||
|
||||
self.call_begin_forward(
|
||||
decode_wrappers[wrapper_id],
|
||||
req_pool_indices,
|
||||
@@ -663,6 +669,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
self.kv_indptr[wrapper_id],
|
||||
kv_start_idx_tmp,
|
||||
spec_info,
|
||||
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
|
||||
)
|
||||
|
||||
def update_cross_attention(
|
||||
@@ -704,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
kv_indptr: torch.Tensor,
|
||||
kv_start_idx: torch.Tensor,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
):
|
||||
if spec_info is None:
|
||||
bs = len(req_pool_indices)
|
||||
@@ -731,6 +739,14 @@ class FlashInferIndicesUpdaterDecode:
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
bs = kv_indptr.shape[0] - 1
|
||||
|
||||
if use_sliding_window_kv_pool:
|
||||
kv_last_index = kv_indptr[-1]
|
||||
kv_indices[:kv_last_index] = (
|
||||
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||
kv_indices[:kv_last_index]
|
||||
)
|
||||
)
|
||||
|
||||
wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
@@ -765,6 +781,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||
self.qo_indptr = attn_backend.qo_indptr
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
|
||||
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
|
||||
|
||||
# Dispatch the update function
|
||||
@@ -848,6 +865,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
paged_kernel_lens_sum = seq_lens_sum
|
||||
|
||||
kv_start_idx = seq_lens - paged_kernel_lens
|
||||
use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
|
||||
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
|
||||
)
|
||||
|
||||
self.call_begin_forward(
|
||||
self.prefill_wrapper_ragged,
|
||||
@@ -862,6 +882,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.qo_indptr[wrapper_id],
|
||||
use_ragged,
|
||||
spec_info,
|
||||
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
|
||||
)
|
||||
|
||||
def update_cross_attention(
|
||||
@@ -916,6 +937,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
qo_indptr: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
):
|
||||
bs = len(seq_lens)
|
||||
if spec_info is None:
|
||||
@@ -964,6 +986,14 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
q_data_type=self.q_data_type,
|
||||
)
|
||||
|
||||
if use_sliding_window_kv_pool:
|
||||
kv_last_index = kv_indptr[-1]
|
||||
kv_indices[:kv_last_index] = (
|
||||
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||
kv_indices[:kv_last_index]
|
||||
)
|
||||
)
|
||||
|
||||
# cached part
|
||||
wrapper_paged.begin_forward(
|
||||
qo_indptr,
|
||||
|
||||
@@ -52,10 +52,14 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
||||
ScheduleBatchDisaggregationDecodeMixin,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.allocator import (
|
||||
BaseTokenToKVPoolAllocator,
|
||||
SWATokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.metrics.collector import TimeStats
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
@@ -527,6 +531,8 @@ class Req:
|
||||
self.last_node: Any = None
|
||||
self.last_host_node: Any = None
|
||||
self.host_hit_length = 0
|
||||
# The node to lock until for swa radix tree lock ref
|
||||
self.swa_uuid_for_lock: Optional[int] = None
|
||||
|
||||
# Whether or not if it is chunked. It increments whenever
|
||||
# it is chunked, and decrement whenever chunked request is
|
||||
@@ -745,6 +751,7 @@ class Req:
|
||||
def reset_for_retract(self):
|
||||
self.prefix_indices = []
|
||||
self.last_node = None
|
||||
self.swa_uuid_for_lock = None
|
||||
self.extend_input_len = 0
|
||||
self.is_retracted = True
|
||||
self.input_token_logprobs = None
|
||||
@@ -813,6 +820,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
|
||||
tree_cache: BasePrefixCache = None
|
||||
is_hybrid: bool = False
|
||||
|
||||
# Batch configs
|
||||
model_config: ModelConfig = None
|
||||
@@ -918,11 +926,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
):
|
||||
return_logprob = any(req.return_logprob for req in reqs)
|
||||
|
||||
is_hybrid = False
|
||||
if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
|
||||
assert isinstance(tree_cache, SWARadixCache) or isinstance(
|
||||
tree_cache, SWAChunkCache
|
||||
), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
|
||||
is_hybrid = True
|
||||
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
|
||||
tree_cache=tree_cache,
|
||||
is_hybrid=is_hybrid,
|
||||
model_config=model_config,
|
||||
enable_overlap=enable_overlap,
|
||||
return_logprob=return_logprob,
|
||||
@@ -953,9 +969,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
return req_pool_indices
|
||||
|
||||
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
|
||||
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(num_tokens)
|
||||
self._evict_tree_cache_if_needed(num_tokens)
|
||||
|
||||
if backup_state:
|
||||
state = self.token_to_kv_pool_allocator.backup_state()
|
||||
@@ -966,7 +980,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
error_msg = (
|
||||
f"{phase_str} out of memory. Try to lower your batch size.\n"
|
||||
f"Try to allocate {num_tokens} tokens.\n"
|
||||
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
f"{self._available_and_evictable_str()}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
if self.tree_cache is not None:
|
||||
@@ -986,16 +1000,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
extend_num_tokens: int,
|
||||
backup_state: bool = False,
|
||||
):
|
||||
if (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< extend_num_tokens
|
||||
num_tokens = (
|
||||
extend_num_tokens
|
||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||
):
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(
|
||||
extend_num_tokens
|
||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||
)
|
||||
)
|
||||
self._evict_tree_cache_if_needed(num_tokens)
|
||||
|
||||
if backup_state:
|
||||
state = self.token_to_kv_pool_allocator.backup_state()
|
||||
@@ -1007,9 +1016,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
error_msg = (
|
||||
f"Prefill out of memory. Try to lower your batch size.\n"
|
||||
f"Try to allocate {extend_num_tokens} tokens.\n"
|
||||
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
f"{self._available_and_evictable_str()}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
@@ -1025,14 +1032,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
last_loc: torch.Tensor,
|
||||
backup_state: bool = False,
|
||||
):
|
||||
if self.tree_cache is not None:
|
||||
if (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||
):
|
||||
self.tree_cache.evict(
|
||||
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||
)
|
||||
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||
|
||||
self._evict_tree_cache_if_needed(num_tokens)
|
||||
|
||||
if backup_state:
|
||||
state = self.token_to_kv_pool_allocator.backup_state()
|
||||
@@ -1042,9 +1044,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
error_msg = (
|
||||
f"Decode out of memory. Try to lower your batch size.\n"
|
||||
f"Try to allocate {len(seq_lens)} tokens.\n"
|
||||
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
f"{self._available_and_evictable_str()}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
@@ -1181,7 +1181,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
|
||||
)
|
||||
if isinstance(self.tree_cache, SWAChunkCache):
|
||||
self.tree_cache.evict(
|
||||
self.tree_cache.evict_swa(
|
||||
req, pre_len, self.model_config.attention_chunk_size
|
||||
)
|
||||
|
||||
@@ -1371,17 +1371,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
)
|
||||
|
||||
def check_decode_mem(self, buf_multiplier=1):
|
||||
tokens_required = (
|
||||
num_tokens = (
|
||||
self.new_page_count_next_decode()
|
||||
* buf_multiplier
|
||||
* self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
|
||||
return True
|
||||
|
||||
self.tree_cache.evict(tokens_required)
|
||||
|
||||
return self.token_to_kv_pool_allocator.available_size() >= tokens_required
|
||||
self._evict_tree_cache_if_needed(num_tokens)
|
||||
return self._is_available_size_sufficient(num_tokens)
|
||||
|
||||
def retract_decode(self, server_args: ServerArgs):
|
||||
"""Retract the decoding requests when there is not enough memory."""
|
||||
@@ -1414,19 +1411,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
||||
)
|
||||
|
||||
def _get_available_size():
|
||||
if self.is_hybrid:
|
||||
return min(
|
||||
self.token_to_kv_pool_allocator.full_available_size(),
|
||||
self.token_to_kv_pool_allocator.swa_available_size(),
|
||||
)
|
||||
else:
|
||||
return self.token_to_kv_pool_allocator.available_size()
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< get_required_tokens(len(sorted_indices))
|
||||
_get_available_size() < get_required_tokens(len(sorted_indices))
|
||||
or first_iter
|
||||
):
|
||||
if len(sorted_indices) == 1:
|
||||
# Corner case: only one request left
|
||||
assert (
|
||||
self.token_to_kv_pool_allocator.available_size() > 0
|
||||
), "No space left for only one request"
|
||||
if self.is_hybrid:
|
||||
full_available_size = (
|
||||
self.token_to_kv_pool_allocator.full_available_size()
|
||||
)
|
||||
swa_available_size = (
|
||||
self.token_to_kv_pool_allocator.swa_available_size()
|
||||
)
|
||||
assert (
|
||||
full_available_size > 0 and swa_available_size > 0
|
||||
), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
|
||||
else:
|
||||
assert (
|
||||
self.token_to_kv_pool_allocator.available_size() > 0
|
||||
), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
|
||||
break
|
||||
|
||||
first_iter = False
|
||||
@@ -1458,15 +1474,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
# release the last node
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
if self.is_hybrid:
|
||||
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
||||
else:
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||
residual_size = (
|
||||
len(sorted_indices) * global_config.retract_decode_steps
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
)
|
||||
residual_size = max(0, residual_size)
|
||||
self.tree_cache.evict(residual_size)
|
||||
num_tokens = len(sorted_indices) * global_config.retract_decode_steps
|
||||
self._evict_tree_cache_if_needed(num_tokens)
|
||||
|
||||
req.reset_for_retract()
|
||||
|
||||
@@ -1559,7 +1574,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# free memory
|
||||
if isinstance(self.tree_cache, SWAChunkCache):
|
||||
for req in self.reqs:
|
||||
self.tree_cache.evict(
|
||||
self.tree_cache.evict_swa(
|
||||
req, req.seqlen - 1, self.model_config.attention_chunk_size
|
||||
)
|
||||
|
||||
@@ -1778,6 +1793,53 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
is_extend_in_batch=self.is_extend_in_batch,
|
||||
)
|
||||
|
||||
def _evict_tree_cache_if_needed(
|
||||
self,
|
||||
num_tokens: int,
|
||||
) -> None:
|
||||
if isinstance(self.tree_cache, SWAChunkCache):
|
||||
return
|
||||
|
||||
if self.is_hybrid:
|
||||
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
|
||||
|
||||
if full_available_size < num_tokens or swa_available_size < num_tokens:
|
||||
if self.tree_cache is not None:
|
||||
full_num_tokens = max(0, num_tokens - full_available_size)
|
||||
swa_num_tokens = max(0, num_tokens - swa_available_size)
|
||||
self.tree_cache.evict(full_num_tokens, swa_num_tokens)
|
||||
else:
|
||||
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(num_tokens)
|
||||
|
||||
def _is_available_size_sufficient(self, num_tokens: int) -> bool:
|
||||
if self.is_hybrid:
|
||||
return (
|
||||
self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
|
||||
and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
|
||||
)
|
||||
else:
|
||||
return self.token_to_kv_pool_allocator.available_size() >= num_tokens
|
||||
|
||||
def _available_and_evictable_str(self) -> str:
|
||||
if self.is_hybrid:
|
||||
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
|
||||
full_evictable_size = self.tree_cache.full_evictable_size()
|
||||
swa_evictable_size = self.tree_cache.swa_evictable_size()
|
||||
return (
|
||||
f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
|
||||
f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
|
||||
f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
|
||||
f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
|
||||
)
|
||||
else:
|
||||
available_size = self.token_to_kv_pool_allocator.available_size()
|
||||
evictable_size = self.tree_cache.evictable_size()
|
||||
return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
|
||||
@@ -311,21 +312,43 @@ class PrefillAdder:
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def rem_total_tokens(self):
|
||||
return (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
- self.rem_total_token_offset
|
||||
self.is_hybrid = isinstance(
|
||||
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
|
||||
)
|
||||
|
||||
@property
|
||||
def rem_total_tokens(self):
|
||||
if self.is_hybrid:
|
||||
available_and_evictable = min(
|
||||
self.token_to_kv_pool_allocator.full_available_size()
|
||||
+ self.tree_cache.full_evictable_size(),
|
||||
self.token_to_kv_pool_allocator.swa_available_size()
|
||||
+ self.tree_cache.swa_evictable_size(),
|
||||
)
|
||||
else:
|
||||
available_and_evictable = (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
|
||||
return available_and_evictable - self.rem_total_token_offset
|
||||
|
||||
@property
|
||||
def cur_rem_tokens(self):
|
||||
return (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
- self.cur_rem_token_offset
|
||||
)
|
||||
if self.is_hybrid:
|
||||
available_and_evictable = min(
|
||||
self.token_to_kv_pool_allocator.full_available_size()
|
||||
+ self.tree_cache.full_evictable_size(),
|
||||
self.token_to_kv_pool_allocator.swa_available_size()
|
||||
+ self.tree_cache.swa_evictable_size(),
|
||||
)
|
||||
else:
|
||||
available_and_evictable = (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
|
||||
return available_and_evictable - self.cur_rem_token_offset
|
||||
|
||||
def ceil_paged_tokens(self, tokens: int) -> int:
|
||||
return -(-tokens // self.page_size) * self.page_size
|
||||
@@ -376,11 +399,18 @@ class PrefillAdder:
|
||||
|
||||
@contextmanager
|
||||
def _lock_node(self, last_node: TreeNode):
|
||||
try:
|
||||
self.tree_cache.inc_lock_ref(last_node)
|
||||
yield None
|
||||
finally:
|
||||
self.tree_cache.dec_lock_ref(last_node)
|
||||
if self.is_hybrid:
|
||||
try:
|
||||
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(last_node)
|
||||
yield None
|
||||
finally:
|
||||
self.tree_cache.dec_lock_ref(last_node, swa_uuid_for_lock)
|
||||
else:
|
||||
try:
|
||||
self.tree_cache.inc_lock_ref(last_node)
|
||||
yield None
|
||||
finally:
|
||||
self.tree_cache.dec_lock_ref(last_node)
|
||||
|
||||
def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
|
||||
# Early exit if no enough tokens for the input tokens
|
||||
@@ -422,16 +452,19 @@ class PrefillAdder:
|
||||
else:
|
||||
add_req_state(req, insert_sort=True)
|
||||
|
||||
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
|
||||
tokens_freed = 0
|
||||
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
||||
# tokens_left gives a reservative calculation as the last token is not stored
|
||||
bs = len(self.req_states) - i
|
||||
min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
|
||||
# reserve tokens for corner cases
|
||||
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
|
||||
return AddReqResult.NO_TOKEN
|
||||
tokens_freed += tokens_occupied
|
||||
if not self.is_hybrid:
|
||||
# Skip this logic for swa. The SWA has different memory management, and
|
||||
# this mechanism is underestimating the memory usage.
|
||||
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
|
||||
tokens_freed = 0
|
||||
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
||||
# tokens_left gives a reservative calculation as the last token is not stored
|
||||
bs = len(self.req_states) - i
|
||||
min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
|
||||
# reserve tokens for corner cases
|
||||
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
|
||||
return AddReqResult.NO_TOKEN
|
||||
tokens_freed += tokens_occupied
|
||||
|
||||
if (
|
||||
self.rem_chunk_tokens is None # chunked prefill is disabled
|
||||
@@ -499,7 +532,11 @@ class PrefillAdder:
|
||||
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
|
||||
# Non-chunked prefill
|
||||
self.can_run_list.append(req)
|
||||
self.tree_cache.inc_lock_ref(req.last_node)
|
||||
if self.is_hybrid:
|
||||
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
|
||||
req.swa_uuid_for_lock = swa_uuid_for_lock
|
||||
else:
|
||||
self.tree_cache.inc_lock_ref(req.last_node)
|
||||
self._update_prefill_budget(
|
||||
prefix_len,
|
||||
input_tokens,
|
||||
@@ -520,7 +557,11 @@ class PrefillAdder:
|
||||
|
||||
self.can_run_list.append(req)
|
||||
self.new_chunked_req = req
|
||||
self.tree_cache.inc_lock_ref(req.last_node)
|
||||
if self.is_hybrid:
|
||||
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
|
||||
req.swa_uuid_for_lock = swa_uuid_for_lock
|
||||
else:
|
||||
self.tree_cache.inc_lock_ref(req.last_node)
|
||||
self._update_prefill_budget(prefix_len, trunc_len, 0)
|
||||
|
||||
return self.budget_state()
|
||||
|
||||
@@ -129,10 +129,10 @@ from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
@@ -390,6 +390,14 @@ class Scheduler(
|
||||
global_server_args_dict.update(worker_global_server_args_dict)
|
||||
set_random_seed(self.random_seed)
|
||||
|
||||
# Hybrid
|
||||
self.is_hybrid = self.tp_worker.is_hybrid
|
||||
if self.is_hybrid:
|
||||
self.sliding_window_size = self.tp_worker.sliding_window_size
|
||||
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
|
||||
self.tp_worker.get_tokens_per_layer_info()
|
||||
)
|
||||
|
||||
# Print debug info
|
||||
if tp_rank == 0:
|
||||
avail_mem = get_available_gpu_memory(
|
||||
@@ -570,7 +578,7 @@ class Scheduler(
|
||||
server_args.chunked_prefill_size is not None
|
||||
and server_args.disable_radix_cache
|
||||
):
|
||||
if self.model_config.is_hybrid:
|
||||
if self.is_hybrid:
|
||||
ChunkCacheClass = SWAChunkCache
|
||||
else:
|
||||
ChunkCacheClass = ChunkCache
|
||||
@@ -603,6 +611,17 @@ class Scheduler(
|
||||
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||
self.tree_cache.cache_controller.layer_done_counter
|
||||
)
|
||||
elif self.is_hybrid:
|
||||
assert (
|
||||
self.server_args.disaggregation_mode == "null"
|
||||
), "Hybrid mode does not support disaggregation yet"
|
||||
self.tree_cache = SWARadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
sliding_window_size=self.sliding_window_size,
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
|
||||
else:
|
||||
self.tree_cache = RadixCache(
|
||||
@@ -774,6 +793,7 @@ class Scheduler(
|
||||
else:
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.check_tree_cache()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
self.maybe_sleep_on_idle()
|
||||
|
||||
@@ -819,6 +839,7 @@ class Scheduler(
|
||||
elif batch is None:
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.check_tree_cache()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
self.maybe_sleep_on_idle()
|
||||
|
||||
@@ -955,6 +976,7 @@ class Scheduler(
|
||||
# When the server is idle, self-check and re-init some states
|
||||
if server_is_idle:
|
||||
self.check_memory()
|
||||
self.check_tree_cache()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
self.maybe_sleep_on_idle()
|
||||
|
||||
@@ -1306,9 +1328,26 @@ class Scheduler(
|
||||
self.last_input_throughput = self.last_prefill_tokens / gap_latency
|
||||
self.last_prefill_tokens = adder.log_input_tokens
|
||||
|
||||
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
|
||||
self.tree_cache.evictable_size()
|
||||
)
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
full_token_usage,
|
||||
swa_token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_swa_token_info()
|
||||
num_used = max(full_num_used, swa_num_used)
|
||||
token_usage = max(full_token_usage, swa_token_usage)
|
||||
token_msg = (
|
||||
f"full token usage: {full_token_usage:.2f}, "
|
||||
f"swa token usage: {swa_token_usage:.2f}, "
|
||||
)
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
token_msg = f"token usage: {token_usage:.2f}, "
|
||||
|
||||
num_new_seq = len(can_run_list)
|
||||
f = (
|
||||
@@ -1316,7 +1355,7 @@ class Scheduler(
|
||||
f"#new-seq: {num_new_seq}, "
|
||||
f"#new-token: {adder.log_input_tokens}, "
|
||||
f"#cached-token: {adder.log_hit_tokens}, "
|
||||
f"{usage_msg}"
|
||||
f"{token_msg}"
|
||||
)
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
@@ -1338,7 +1377,7 @@ class Scheduler(
|
||||
)
|
||||
self.stats.num_running_reqs = running_bs
|
||||
self.stats.num_used_tokens = num_used
|
||||
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
|
||||
self.stats.token_usage = round(token_usage, 2)
|
||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||
self.stats.cache_hit_rate = cache_hit_rate
|
||||
|
||||
@@ -1361,16 +1400,35 @@ class Scheduler(
|
||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
||||
self.num_generated_tokens = 0
|
||||
num_running_reqs = len(batch.reqs)
|
||||
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
|
||||
self.tree_cache.evictable_size()
|
||||
)
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
full_token_usage,
|
||||
swa_token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_swa_token_info()
|
||||
num_used = max(full_num_used, swa_num_used)
|
||||
token_usage = max(full_token_usage, swa_token_usage)
|
||||
token_msg = (
|
||||
f"#full token: {full_num_used}, "
|
||||
f"full token usage: {full_token_usage:.2f}, "
|
||||
f"#swa token: {swa_num_used}, "
|
||||
f"swa token usage: {swa_token_usage:.2f}, "
|
||||
)
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
|
||||
|
||||
if RECORD_STEP_TIME:
|
||||
self.step_time_dict[num_running_reqs].append(
|
||||
gap_latency / self.server_args.decode_log_interval
|
||||
)
|
||||
|
||||
msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
|
||||
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
|
||||
|
||||
if self.spec_algorithm.is_none():
|
||||
spec_accept_length = 0
|
||||
@@ -1398,7 +1456,7 @@ class Scheduler(
|
||||
if self.enable_metrics:
|
||||
self.stats.num_running_reqs = num_running_reqs
|
||||
self.stats.num_used_tokens = num_used
|
||||
self.stats.token_usage = num_used / self.max_total_num_tokens
|
||||
self.stats.token_usage = round(token_usage, 2)
|
||||
self.stats.cache_hit_rate = 0.0
|
||||
self.stats.gen_throughput = self.last_gen_throughput
|
||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||
@@ -1409,24 +1467,34 @@ class Scheduler(
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_memory(self):
|
||||
if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
|
||||
available_token_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
else:
|
||||
available_token_size = self.token_to_kv_pool_allocator.available_size()
|
||||
available_size = available_token_size + self.tree_cache.evictable_size()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = available_size != (
|
||||
self.max_total_num_tokens
|
||||
if not self.enable_hierarchical_cache
|
||||
else self.max_total_num_tokens - protected_size
|
||||
)
|
||||
if memory_leak:
|
||||
msg = (
|
||||
"token_to_kv_pool_allocator memory leak detected! "
|
||||
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
||||
f"{available_token_size=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
_,
|
||||
_,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
swa_available_size,
|
||||
swa_evictable_size,
|
||||
) = self._get_swa_token_info()
|
||||
memory_leak = full_num_used != 0 or swa_num_used != 0
|
||||
token_msg = (
|
||||
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
|
||||
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
|
||||
)
|
||||
else:
|
||||
_, _, available_size, evictable_size = self._get_token_info()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = (available_size + evictable_size) != (
|
||||
self.max_total_num_tokens
|
||||
if not self.enable_hierarchical_cache
|
||||
else self.max_total_num_tokens - protected_size
|
||||
)
|
||||
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
|
||||
|
||||
if memory_leak:
|
||||
msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
@@ -1450,20 +1518,66 @@ class Scheduler(
|
||||
and time.perf_counter() > self.metrics_collector.last_log_time + 30
|
||||
):
|
||||
# During idle time, also collect metrics every 30 seconds.
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
if self.is_hybrid:
|
||||
(
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
full_token_usage,
|
||||
swa_token_usage,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self._get_swa_token_info()
|
||||
num_used = max(full_num_used, swa_num_used)
|
||||
token_usage = max(full_token_usage, swa_token_usage)
|
||||
else:
|
||||
num_used, token_usage, _, _ = self._get_token_info()
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
self.stats.num_running_reqs = num_running_reqs
|
||||
self.stats.num_used_tokens = num_used
|
||||
self.stats.token_usage = num_used / self.max_total_num_tokens
|
||||
self.stats.token_usage = round(token_usage, 2)
|
||||
self.stats.gen_throughput = 0
|
||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
||||
self.metrics_collector.log_stats(self.stats)
|
||||
self._publish_kv_events()
|
||||
|
||||
def check_tree_cache(self):
|
||||
if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
|
||||
self.tree_cache.sanity_check()
|
||||
|
||||
def _get_token_info(self):
|
||||
available_size = self.token_to_kv_pool_allocator.available_size()
|
||||
evictable_size = self.tree_cache.evictable_size()
|
||||
num_used = self.max_total_num_tokens - (available_size + evictable_size)
|
||||
token_usage = num_used / self.max_total_num_tokens
|
||||
return num_used, token_usage, available_size, evictable_size
|
||||
|
||||
def _get_swa_token_info(self):
|
||||
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
full_evictable_size = self.tree_cache.full_evictable_size()
|
||||
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
|
||||
swa_evictable_size = self.tree_cache.swa_evictable_size()
|
||||
full_num_used = self.full_tokens_per_layer - (
|
||||
full_available_size + full_evictable_size
|
||||
)
|
||||
swa_num_used = self.swa_tokens_per_layer - (
|
||||
swa_available_size + swa_evictable_size
|
||||
)
|
||||
full_token_usage = full_num_used / self.full_tokens_per_layer
|
||||
swa_token_usage = swa_num_used / self.swa_tokens_per_layer
|
||||
return (
|
||||
full_num_used,
|
||||
swa_num_used,
|
||||
full_token_usage,
|
||||
swa_token_usage,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
swa_available_size,
|
||||
swa_evictable_size,
|
||||
)
|
||||
|
||||
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
||||
# Merge the prefill batch into the running batch
|
||||
chunked_req_to_exclude = set()
|
||||
@@ -2042,11 +2156,30 @@ class Scheduler(
|
||||
|
||||
if not disable_request_logging():
|
||||
# Print batch size and memory pool info to check whether there are de-sync issues.
|
||||
if self.is_hybrid:
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
full_available_size,
|
||||
full_evictable_size,
|
||||
swa_available_size,
|
||||
swa_evictable_size,
|
||||
) = self._get_swa_token_info()
|
||||
info_msg = (
|
||||
f"{full_available_size=}, "
|
||||
f"{full_evictable_size=}, "
|
||||
f"{swa_available_size=}, "
|
||||
f"{swa_evictable_size=}, "
|
||||
)
|
||||
else:
|
||||
_, _, available_size, evictable_size = self._get_token_info()
|
||||
info_msg = f"{available_size=}, " f"{evictable_size=}, "
|
||||
logger.error(
|
||||
f"{self.cur_batch.batch_size()=}, "
|
||||
f"{self.cur_batch.reqs=}, "
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}, "
|
||||
f"{self.tree_cache.evictable_size()=}, "
|
||||
f"{info_msg}"
|
||||
)
|
||||
|
||||
pyspy_dump_schedulers()
|
||||
@@ -2101,11 +2234,24 @@ class Scheduler(
|
||||
|
||||
def get_load(self):
|
||||
# TODO(lsyin): use dynamically maintained num_waiting_tokens
|
||||
load = (
|
||||
self.max_total_num_tokens
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
- self.tree_cache.evictable_size()
|
||||
)
|
||||
if self.is_hybrid:
|
||||
load_full = (
|
||||
self.full_tokens_per_layer
|
||||
- self.token_to_kv_pool_allocator.full_available_size()
|
||||
- self.tree_cache.full_evictable_size()
|
||||
)
|
||||
load_swa = (
|
||||
self.swa_tokens_per_layer
|
||||
- self.token_to_kv_pool_allocator.swa_available_size()
|
||||
- self.tree_cache.swa_evictable_size()
|
||||
)
|
||||
load = max(load_full, load_swa)
|
||||
else:
|
||||
load = (
|
||||
self.max_total_num_tokens
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
- self.tree_cache.evictable_size()
|
||||
)
|
||||
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
load += sum(
|
||||
|
||||
@@ -174,6 +174,20 @@ class TpModelWorker:
|
||||
self.model_runner.token_to_kv_pool.size,
|
||||
)
|
||||
|
||||
@property
|
||||
def sliding_window_size(self) -> Optional[int]:
|
||||
return self.model_runner.sliding_window_size
|
||||
|
||||
@property
|
||||
def is_hybrid(self) -> bool:
|
||||
return self.model_runner.is_hybrid is not None
|
||||
|
||||
def get_tokens_per_layer_info(self):
|
||||
return (
|
||||
self.model_runner.full_max_total_num_tokens,
|
||||
self.model_runner.swa_max_total_num_tokens,
|
||||
)
|
||||
|
||||
def get_pad_input_ids_func(self):
|
||||
return getattr(self.model_runner.model, "pad_input_ids", None)
|
||||
|
||||
|
||||
@@ -102,6 +102,17 @@ class TpModelWorkerClient:
|
||||
def get_worker_info(self):
|
||||
return self.worker.get_worker_info()
|
||||
|
||||
def get_tokens_per_layer_info(self):
|
||||
return self.worker.get_tokens_per_layer_info()
|
||||
|
||||
@property
|
||||
def sliding_window_size(self) -> Optional[int]:
|
||||
return self.worker.sliding_window_size
|
||||
|
||||
@property
|
||||
def is_hybrid(self) -> bool:
|
||||
return self.worker.is_hybrid
|
||||
|
||||
def get_pad_input_ids_func(self):
|
||||
return self.worker.get_pad_input_ids_func()
|
||||
|
||||
|
||||
@@ -57,11 +57,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
||||
def debug_print(self) -> str:
|
||||
return ""
|
||||
|
||||
def log_usage(self, evictable_size: int = 0):
|
||||
num_used = self.size - (self.available_size() + evictable_size)
|
||||
msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
|
||||
return msg, num_used
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_pages) * self.page_size
|
||||
|
||||
@@ -190,7 +185,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping
|
||||
|
||||
def available_size(self):
|
||||
return min(self.full_available_size(), self.swa_available_size())
|
||||
raise NotImplementedError()
|
||||
|
||||
def full_available_size(self):
|
||||
return self.full_attn_allocator.available_size()
|
||||
@@ -214,16 +209,6 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
)
|
||||
return msg
|
||||
|
||||
def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
|
||||
used_full = self.size_full - (self.full_available_size() + full_evictable_size)
|
||||
used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
|
||||
msg = (
|
||||
f"#token: full={used_full}, swa={used_swa}, "
|
||||
f"token usage: full={used_full / self.size_full:.2f}, "
|
||||
f"swa={used_swa / self.size_swa:.2f}, "
|
||||
)
|
||||
return msg, used_full
|
||||
|
||||
def get_kvcache(self):
|
||||
return self._kvcache
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -56,15 +56,27 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def dec_lock_ref(self, node: Any):
|
||||
def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
|
||||
pass
|
||||
|
||||
def evictable_size(self):
|
||||
return 0
|
||||
|
||||
def full_evictable_size(self):
|
||||
return 0
|
||||
|
||||
def swa_evictable_size(self):
|
||||
return 0
|
||||
|
||||
def protected_size(self):
|
||||
return 0
|
||||
|
||||
def full_protected_size(self):
|
||||
return 0
|
||||
|
||||
def swa_protected_size(self):
|
||||
return 0
|
||||
|
||||
def total_size(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ class ChunkCache(BasePrefixCache):
|
||||
def inc_lock_ref(self, node: Any):
|
||||
return 0
|
||||
|
||||
def dec_lock_ref(self, node: Any):
|
||||
def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
|
||||
return 0
|
||||
|
||||
def pretty_print(self):
|
||||
@@ -80,7 +80,7 @@ class SWAChunkCache(ChunkCache):
|
||||
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
|
||||
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
|
||||
|
||||
def evict(
|
||||
def evict_swa(
|
||||
self,
|
||||
req: Req,
|
||||
prelen: int,
|
||||
@@ -95,3 +95,6 @@ class SWAChunkCache(ChunkCache):
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free_swa(free_slots)
|
||||
req.evicted_seqlen_local = new_evicted_seqlen_local
|
||||
|
||||
def evict(self, num_tokens: int):
|
||||
pass
|
||||
|
||||
1025
python/sglang/srt/mem_cache/swa_radix_cache.py
Normal file
1025
python/sglang/srt/mem_cache/swa_radix_cache.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -275,6 +275,15 @@ class ModelRunner:
|
||||
self.sampler = Sampler()
|
||||
self.load_model()
|
||||
|
||||
if (
|
||||
not self.server_args.disable_hybrid_swa_memory
|
||||
and self.sliding_window_size is not None
|
||||
and self.sliding_window_size > 0
|
||||
):
|
||||
architectures = self.model_config.hf_config.architectures
|
||||
if architectures and not any("Llama4" in arch for arch in architectures):
|
||||
self.is_hybrid = self.model_config.is_hybrid = True
|
||||
|
||||
self.start_layer = getattr(self.model, "start_layer", 0)
|
||||
self.end_layer = getattr(
|
||||
self.model, "end_layer", self.model_config.num_hidden_layers
|
||||
@@ -471,10 +480,6 @@ class ModelRunner:
|
||||
if self.model_config.context_len > 8192:
|
||||
self.mem_fraction_static *= 0.85
|
||||
|
||||
if self.is_hybrid and not server_args.disable_radix_cache:
|
||||
logger.info("Automatically disable radix cache for hybrid cache.")
|
||||
server_args.disable_radix_cache = True
|
||||
|
||||
def init_torch_distributed(self):
|
||||
logger.info("Init torch distributed begin.")
|
||||
|
||||
@@ -645,11 +650,15 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
# Parse other args
|
||||
self.sliding_window_size = (
|
||||
self.model.get_attention_sliding_window_size()
|
||||
if hasattr(self.model, "get_attention_sliding_window_size")
|
||||
else None
|
||||
)
|
||||
self.sliding_window_size = None
|
||||
if hasattr(self.model, "get_attention_sliding_window_size"):
|
||||
self.sliding_window_size = self.model.get_attention_sliding_window_size()
|
||||
elif self.model_config.attention_chunk_size is not None:
|
||||
self.sliding_window_size = self.model_config.attention_chunk_size
|
||||
print(
|
||||
f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
|
||||
)
|
||||
|
||||
self.dtype = self.model_config.dtype
|
||||
|
||||
after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
@@ -992,8 +1001,53 @@ class ModelRunner:
|
||||
)
|
||||
self.max_total_num_tokens = self.full_max_total_num_tokens
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}."
|
||||
assert self.sliding_window_size is not None and self.sliding_window_size > 0
|
||||
full_attention_layer_ids = []
|
||||
swa_attention_layer_ids = []
|
||||
|
||||
try:
|
||||
layers = self.model.model.layers
|
||||
except:
|
||||
try:
|
||||
layers = self.model.language_model.model.layers
|
||||
except:
|
||||
self.is_hybrid = False
|
||||
return
|
||||
|
||||
for layer in layers:
|
||||
if (
|
||||
layer.self_attn.attn.sliding_window_size is None
|
||||
or layer.self_attn.attn.sliding_window_size == -1
|
||||
):
|
||||
full_attention_layer_ids.append(layer.layer_id)
|
||||
else:
|
||||
swa_attention_layer_ids.append(layer.layer_id)
|
||||
self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
|
||||
self.model_config.full_attention_layer_ids = full_attention_layer_ids
|
||||
|
||||
# Algorithm:
|
||||
# Existing max_total_num_tokens is per layer and assume all layers have the same number of tokens.
|
||||
# - Find total # of tokens available across layers.
|
||||
# - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
|
||||
total_tokens = (
|
||||
self.max_total_num_tokens * self.model_config.num_hidden_layers
|
||||
)
|
||||
full_layers_num = len(full_attention_layer_ids)
|
||||
swa_layers_num = len(swa_attention_layer_ids)
|
||||
swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio
|
||||
|
||||
# Solve the equations:
|
||||
# 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
|
||||
# 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
|
||||
denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
|
||||
self.full_max_total_num_tokens = int(total_tokens / denominator)
|
||||
self.swa_max_total_num_tokens = int(
|
||||
self.full_max_total_num_tokens * swa_full_tokens_ratio
|
||||
)
|
||||
self.max_total_num_tokens = self.full_max_total_num_tokens
|
||||
|
||||
logger.info(
|
||||
f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
|
||||
)
|
||||
|
||||
def init_memory_pool(
|
||||
@@ -1072,7 +1126,6 @@ class ModelRunner:
|
||||
// self.server_args.page_size
|
||||
* self.server_args.page_size
|
||||
)
|
||||
|
||||
# create token size for hybrid cache
|
||||
if self.is_hybrid:
|
||||
self.set_num_token_hybrid()
|
||||
|
||||
@@ -190,6 +190,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Gemma2Attention(
|
||||
layer_id=layer_id,
|
||||
|
||||
@@ -63,6 +63,7 @@ class ServerArgs:
|
||||
enable_multimodal: Optional[bool] = None
|
||||
revision: Optional[str] = None
|
||||
hybrid_kvcache_ratio: Optional[float] = None
|
||||
swa_full_tokens_ratio: float = 0.8
|
||||
impl: str = "auto"
|
||||
|
||||
# Port for the HTTP server
|
||||
@@ -225,6 +226,7 @@ class ServerArgs:
|
||||
enable_return_hidden_states: bool = False
|
||||
enable_triton_kernel_moe: bool = False
|
||||
warmups: Optional[str] = None
|
||||
disable_hybrid_swa_memory: bool = False
|
||||
|
||||
# Debug tensor dumps
|
||||
debug_tensor_dump_output_folder: Optional[str] = None
|
||||
@@ -481,14 +483,22 @@ class ServerArgs:
|
||||
|
||||
model_arch = get_model_arch(self)
|
||||
|
||||
# Auto set draft_model_path DeepSeek-V3/R1
|
||||
if model_arch == "DeepseekV3ForCausalLM":
|
||||
# Auto set draft_model_path DeepSeek-V3/R1
|
||||
if self.speculative_draft_model_path is None:
|
||||
self.speculative_draft_model_path = self.model_path
|
||||
else:
|
||||
logger.warning(
|
||||
"DeepSeek MTP does not require setting speculative_draft_model_path."
|
||||
)
|
||||
elif "Llama4" in model_arch:
|
||||
# TODO: remove this after Llama4 supports in other backends
|
||||
if self.attention_backend != "fa3":
|
||||
self.attention_backend = "fa3"
|
||||
logger.warning(
|
||||
"Llama4 requires using fa3 attention backend. "
|
||||
"Attention backend is automatically set to fa3."
|
||||
)
|
||||
|
||||
# Auto choose parameters
|
||||
if self.speculative_num_steps is None:
|
||||
@@ -852,6 +862,18 @@ class ServerArgs:
|
||||
"(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--swa-full-tokens-ratio",
|
||||
type=float,
|
||||
default=ServerArgs.swa_full_tokens_ratio,
|
||||
help="The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. "
|
||||
"E.g. 0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-hybrid-swa-memory",
|
||||
action="store_true",
|
||||
help="Disable the hybrid SWA memory.",
|
||||
)
|
||||
|
||||
# Other runtime options
|
||||
parser.add_argument(
|
||||
@@ -1730,10 +1752,6 @@ class ServerArgs:
|
||||
else:
|
||||
self.lora_paths[lora_path] = lora_path
|
||||
|
||||
model_arch = get_model_arch(self)
|
||||
if "Llama4" in model_arch and self.hybrid_kvcache_ratio is not None:
|
||||
assert self.attention_backend == "fa3"
|
||||
|
||||
|
||||
def prepare_server_args(argv: List[str]) -> ServerArgs:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user