Support page size > 1 (#4356)

This commit is contained in:
Lianmin Zheng
2025-03-12 22:22:39 -07:00
committed by GitHub
parent 2f6bacee03
commit c76040e31b
23 changed files with 877 additions and 284 deletions

View File

@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
@@ -430,7 +431,7 @@ class ModelRunner:
self.model_config.model_path = model_path
load_config = LoadConfig(load_format=load_format)
# Only support the DefaultModelLoader for now
# Only support DefaultModelLoader for now
loader = get_model_loader(load_config)
if not isinstance(loader, DefaultModelLoader):
message = f"Failed to get model loader: {loader}."
@@ -732,6 +733,7 @@ class ModelRunner:
):
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
@@ -742,6 +744,7 @@ class ModelRunner:
elif self.server_args.enable_double_sparsity:
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
@@ -753,6 +756,7 @@ class ModelRunner:
else:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
@@ -762,12 +766,21 @@ class ModelRunner:
)
if self.token_to_kv_pool_allocator is None:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
)
if self.page_size == 1:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
)
else:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
)
else:
assert self.is_draft_worker