Support page size > 1 (#4356)
This commit is contained in:
@@ -264,11 +264,15 @@ class CudaGraphRunner:
|
||||
def model_capture_mode(self):
|
||||
if hasattr(self.model_runner.model, "capture_mode"):
|
||||
self.model_runner.model.capture_mode = True
|
||||
if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
|
||||
self.model_runner.token_to_kv_pool.capture_mode = True
|
||||
|
||||
yield
|
||||
|
||||
if hasattr(self.model_runner.model, "capture_mode"):
|
||||
self.model_runner.model.capture_mode = False
|
||||
if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
|
||||
self.model_runner.token_to_kv_pool.capture_mode = False
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention:
|
||||
|
||||
@@ -38,12 +38,12 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
from sglang.srt.utils import get_compiler_backend, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
@@ -51,9 +51,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ForwardMode(IntEnum):
|
||||
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
||||
PREFILL = auto()
|
||||
# Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
|
||||
# It is also called "prefill" in common terminology.
|
||||
EXTEND = auto()
|
||||
# Decode one token.
|
||||
DECODE = auto()
|
||||
@@ -153,6 +152,12 @@ class ForwardBatch:
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
token_ids_logprobs: Optional[List[List[int]]] = None
|
||||
|
||||
# For logits and logprobs post processing
|
||||
temp_scaled_logprobs: bool = False
|
||||
temperature: torch.Tensor = None
|
||||
top_p_normalized_logprobs: bool = False
|
||||
top_p: torch.Tensor = None
|
||||
|
||||
# Position information
|
||||
positions: torch.Tensor = None
|
||||
|
||||
@@ -189,7 +194,7 @@ class ForwardBatch:
|
||||
|
||||
# Attention backend
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
token_to_kv_pool: BaseTokenToKVPool = None
|
||||
token_to_kv_pool: KVCache = None
|
||||
attn_backend: AttentionBackend = None
|
||||
|
||||
# For DP attention
|
||||
@@ -229,7 +234,6 @@ class ForwardBatch:
|
||||
extend_input_logprob_token_ids_gpu = (
|
||||
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
|
||||
)
|
||||
|
||||
ret = cls(
|
||||
forward_mode=batch.forward_mode,
|
||||
batch_size=len(batch.seq_lens),
|
||||
@@ -417,8 +421,8 @@ def compute_position_kernel(
|
||||
prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
|
||||
seq_len = tl.load(extend_seq_lens + pid)
|
||||
|
||||
# TODO: optimize this?
|
||||
cumsum_start = 0
|
||||
# NOTE: This can be slow for large bs
|
||||
cumsum_start = tl.cast(0, tl.int64)
|
||||
for i in range(pid):
|
||||
cumsum_start += tl.load(extend_seq_lens + i)
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader import get_model
|
||||
@@ -430,7 +431,7 @@ class ModelRunner:
|
||||
self.model_config.model_path = model_path
|
||||
load_config = LoadConfig(load_format=load_format)
|
||||
|
||||
# Only support the DefaultModelLoader for now
|
||||
# Only support DefaultModelLoader for now
|
||||
loader = get_model_loader(load_config)
|
||||
if not isinstance(loader, DefaultModelLoader):
|
||||
message = f"Failed to get model loader: {loader}."
|
||||
@@ -732,6 +733,7 @@ class ModelRunner:
|
||||
):
|
||||
self.token_to_kv_pool = MLATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
kv_lora_rank=self.model_config.kv_lora_rank,
|
||||
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||
@@ -742,6 +744,7 @@ class ModelRunner:
|
||||
elif self.server_args.enable_double_sparsity:
|
||||
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||
head_dim=self.model_config.head_dim,
|
||||
@@ -753,6 +756,7 @@ class ModelRunner:
|
||||
else:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||
head_dim=self.model_config.head_dim,
|
||||
@@ -762,12 +766,21 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
if self.token_to_kv_pool_allocator is None:
|
||||
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
)
|
||||
if self.page_size == 1:
|
||||
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
)
|
||||
else:
|
||||
assert self.is_draft_worker
|
||||
|
||||
|
||||
Reference in New Issue
Block a user