Support page size > 1 (#4356)

Author: Lianmin Zheng
Date: 2025-03-12 22:22:39 -07:00 (committed by GitHub)
Parent: 2f6bacee03
Commit: c76040e31b
23 changed files with 877 additions and 284 deletions

View File

@@ -264,11 +264,15 @@ class CudaGraphRunner:
     def model_capture_mode(self):
         if hasattr(self.model_runner.model, "capture_mode"):
             self.model_runner.model.capture_mode = True
+        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
+            self.model_runner.token_to_kv_pool.capture_mode = True

         yield

         if hasattr(self.model_runner.model, "capture_mode"):
             self.model_runner.model.capture_mode = False
+        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
+            self.model_runner.token_to_kv_pool.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
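
Note: the hunk above extends an existing context manager so that CUDA graph
capture also flips capture_mode on the KV cache pool, not just the model. A
minimal standalone sketch of the pattern (the helper name and generic signature
are this note's assumptions, not the repo's API; the repo's version toggles the
flags around a bare yield):

from contextlib import contextmanager

@contextmanager
def capture_mode_on(*objects):
    # Set capture_mode on every object that defines it ...
    for obj in objects:
        if hasattr(obj, "capture_mode"):
            obj.capture_mode = True
    try:
        yield
    finally:
        # ... and clear it again on exit, even if capture raises.
        for obj in objects:
            if hasattr(obj, "capture_mode"):
                obj.capture_mode = False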

View File

@@ -38,12 +38,12 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import get_compiler_backend, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
-    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+    from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
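
Note: the utils import now also pulls in next_power_of_2; its body is not
shown in this diff. A common implementation (an assumption of this note, not
necessarily sglang's exact code):

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, e.g. 5 -> 8, 8 -> 8.
    return 1 << (n - 1).bit_length() if n > 0 else 1

Triton block shapes typically must be powers of two, which is the usual
reason for such a helper.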
@@ -51,9 +51,8 @@ if TYPE_CHECKING:
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
     # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
-    # It is also called "prefill" in common terminology.
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
@@ -153,6 +152,12 @@ class ForwardBatch:
     top_logprobs_nums: Optional[List[int]] = None
     token_ids_logprobs: Optional[List[List[int]]] = None

+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    temperature: torch.Tensor = None
+    top_p_normalized_logprobs: bool = False
+    top_p: torch.Tensor = None
+
     # Position information
     positions: torch.Tensor = None
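
Note: the four new ForwardBatch fields carry per-request temperature and top-p
so returned logprobs can be post-processed to match what was actually sampled.
A sketch of the temperature part (function name and shapes are this note's
assumptions, not the commit's code):

import torch

def temperature_scaled_logprobs(
    logits: torch.Tensor,       # [batch, vocab]
    temperature: torch.Tensor,  # [batch, 1]
) -> torch.Tensor:
    # Scale logits by per-request temperature before log_softmax so the
    # reported logprobs describe the distribution that was sampled from.
    return torch.log_softmax(logits / temperature, dim=-1)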
@@ -189,7 +194,7 @@
     # Attention backend
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool: BaseTokenToKVPool = None
+    token_to_kv_pool: KVCache = None
     attn_backend: AttentionBackend = None

     # For DP attention
@@ -229,7 +234,6 @@
             extend_input_logprob_token_ids_gpu = (
                 batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
             )
-
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -417,8 +421,8 @@ def compute_position_kernel(
     prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)

-    # TODO: optimize this?
-    cumsum_start = 0
+    # NOTE: This can be slow for large bs
+    cumsum_start = tl.cast(0, tl.int64)
     for i in range(pid):
         cumsum_start += tl.load(extend_seq_lens + i)
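
Note: casting the running offset to tl.int64 guards against overflow. With
enough requests and long sequences, the cumulative token count can exceed
2**31 - 1 and silently wrap in int32. A quick torch illustration of the
failure mode (the numbers are chosen for the example):

import torch

seq_lens = torch.full((4096,), 1_000_000, dtype=torch.int32)
# 4096 * 1e6 = 4.096e9 exceeds the int32 max (~2.147e9) and wraps:
print(seq_lens.cumsum(0, dtype=torch.int32)[-1].item())  # wrapped, negative
print(seq_lens.cumsum(0, dtype=torch.int64)[-1].item())  # 4096000000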

View File

@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
@@ -430,7 +431,7 @@ class ModelRunner:
         self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)

-        # Only support the DefaultModelLoader for now
+        # Only support DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
@@ -732,6 +733,7 @@ class ModelRunner:
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
@@ -742,6 +744,7 @@ class ModelRunner:
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
@@ -753,6 +756,7 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
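
Note: all three KV pool constructors now receive page_size. Under paging, a
flat token slot index decomposes into a page id plus an offset within the
page; a sketch of the arithmetic (helper name is this note's, not the repo's):

def token_to_page(token_idx: int, page_size: int) -> tuple[int, int]:
    # Page id and offset within the page for a flat token slot index.
    return token_idx // page_size, token_idx % page_size

assert token_to_page(13, 4) == (3, 1)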
@@ -762,12 +766,21 @@ class ModelRunner:
             )

         if self.token_to_kv_pool_allocator is None:
-            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                self.max_total_num_tokens,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-                kvcache=self.token_to_kv_pool,
-            )
+            if self.page_size == 1:
+                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
+            else:
+                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
         else:
             assert self.is_draft_worker
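
Note: with page_size == 1 the original per-token allocator still applies;
otherwise the new PagedTokenToKVPoolAllocator hands out whole pages, so a
request's token count is rounded up to a multiple of the page size. A minimal
sketch of that rounding (not the repo's implementation):

def pages_needed(num_tokens: int, page_size: int) -> int:
    # Round up: a request holding 5 tokens with page_size=4 occupies 2 pages.
    return (num_tokens + page_size - 1) // page_size

assert pages_needed(5, 4) == 2
assert pages_needed(8, 4) == 2

Keeping each request's tokens in whole pages is what lets attention backends
address the KV cache by page id, in the PagedAttention style.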