diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 6dcfc0d3b..213ef2715 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -51,6 +51,7 @@ import logging
 import multiprocessing
 import os
 import time
+from types import SimpleNamespace
 from typing import Tuple
 
 import numpy as np
@@ -257,11 +258,18 @@ def prepare_synthetic_inputs_for_latency_test(
 
 @torch.no_grad
 def extend(reqs, model_runner):
+    # Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
+    dummy_tree_cache = SimpleNamespace(
+        page_size=1,
+        device=model_runner.device,
+        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
+    )
+
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
-        tree_cache=None,
+        tree_cache=dummy_tree_cache,
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 075d90477..3b18b9452 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -45,8 +45,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
-import triton
-import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
@@ -62,6 +60,7 @@ from sglang.srt.mem_cache.allocator import (
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
@@ -70,7 +69,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import flatten_nested_list, support_triton
+from sglang.srt.utils import flatten_nested_list
 
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
@@ -1001,158 +1000,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
-        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
-            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
-        else:
-            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
-        if req_pool_indices is None:
-            raise RuntimeError(
-                "alloc_req_slots runs out of memory. "
-                "Please set a smaller number for `--max-running-requests`. "
-                f"{self.req_to_token_pool.available_size()=}, "
-                f"{num_reqs=}, "
-            )
-        return req_pool_indices
-
-    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
-        self._evict_tree_cache_if_needed(num_tokens)
-
-        if backup_state:
-            state = self.token_to_kv_pool_allocator.backup_state()
-
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
-        if out_cache_loc is None:
-            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
-            error_msg = (
-                f"{phase_str} out of memory. Try to lower your batch size.\n"
-                f"Try to allocate {num_tokens} tokens.\n"
-                f"{self._available_and_evictable_str()}"
-            )
-            logger.error(error_msg)
-            if self.tree_cache is not None:
-                self.tree_cache.pretty_print()
-            raise RuntimeError(error_msg)
-
-        if backup_state:
-            return out_cache_loc, state
-        else:
-            return out_cache_loc
-
-    def alloc_paged_token_slots_extend(
-        self,
-        prefix_lens: torch.Tensor,
-        prefix_lens_cpu: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_cpu: torch.Tensor,
-        last_loc: torch.Tensor,
-        extend_num_tokens: int,
-        backup_state: bool = False,
-    ):
-        # Over estimate the number of tokens: assume each request needs a new page.
-        num_tokens = (
-            extend_num_tokens
-            + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
-        )
-        self._evict_tree_cache_if_needed(num_tokens)
-
-        if backup_state:
-            state = self.token_to_kv_pool_allocator.backup_state()
-
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
-            prefix_lens,
-            prefix_lens_cpu,
-            seq_lens,
-            seq_lens_cpu,
-            last_loc,
-            extend_num_tokens,
-        )
-        if out_cache_loc is None:
-            error_msg = (
-                f"Prefill out of memory. Try to lower your batch size.\n"
-                f"Try to allocate {extend_num_tokens} tokens.\n"
-                f"{self._available_and_evictable_str()}"
-            )
-            logger.error(error_msg)
-            raise RuntimeError(error_msg)
-
-        if backup_state:
-            return out_cache_loc, state
-        else:
-            return out_cache_loc
-
-    def alloc_paged_token_slots_decode(
-        self,
-        seq_lens: torch.Tensor,
-        seq_lens_cpu: torch.Tensor,
-        last_loc: torch.Tensor,
-        backup_state: bool = False,
-    ):
-        # Over estimate the number of tokens: assume each request needs a new page.
-        num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-        self._evict_tree_cache_if_needed(num_tokens)
-
-        if backup_state:
-            state = self.token_to_kv_pool_allocator.backup_state()
-
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
-            seq_lens, seq_lens_cpu, last_loc
-        )
-        if out_cache_loc is None:
-            error_msg = (
-                f"Decode out of memory. Try to lower your batch size.\n"
-                f"Try to allocate {len(seq_lens)} tokens.\n"
-                f"{self._available_and_evictable_str()}"
-            )
-            logger.error(error_msg)
-            raise RuntimeError(error_msg)
-
-        if backup_state:
-            return out_cache_loc, state
-        else:
-            return out_cache_loc
-
-    def write_cache_indices(
-        self,
-        req_pool_indices: List[int],
-        prefix_lens: List[int],
-        seq_lens: List[int],
-        extend_lens: List[int],
-        out_cache_loc: torch.Tensor,
-        req_pool_indices_tensor: torch.Tensor,
-        prefix_lens_tensor: torch.Tensor,
-        seq_lens_tensor: torch.Tensor,
-        extend_lens_tensor: torch.Tensor,
-        prefix_tensors: list[torch.Tensor],
-    ):
-        if support_triton(global_server_args_dict.get("attention_backend")):
-            prefix_pointers = torch.tensor(
-                [t.data_ptr() for t in prefix_tensors], device=self.device
-            )
-            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-            write_req_to_token_pool_triton[(len(req_pool_indices),)](
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_pointers,
-                prefix_lens_tensor,
-                seq_lens_tensor,
-                extend_lens_tensor,
-                out_cache_loc,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
-        else:
-            pt = 0
-            for i in range(len(req_pool_indices)):
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(0, prefix_lens[i])),
-                    prefix_tensors[i],
-                )
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                    out_cache_loc[pt : pt + extend_lens[i]],
-                )
-                pt += extend_lens[i]
-
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
@@ -1253,10 +1100,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        prefix_lens_tensor = torch.tensor(
-            prefix_lens, dtype=torch.int64, device=self.device
-        )
-        prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
 
         token_type_ids_tensor = None
         if len(token_type_ids) > 0:
@@ -1264,48 +1107,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 sum(token_type_ids, []), dtype=torch.int64
             ).to(self.device, non_blocking=True)
 
-        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
-
-        # Allocate req slots
-        bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
-        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
+        # Set batch fields needed by alloc_for_extend
+        self.prefix_lens = prefix_lens
+        self.extend_lens = extend_lens
+        self.seq_lens = seq_lens_tensor
+        self.seq_lens_cpu = seq_lens_cpu
+        self.extend_num_tokens = extend_num_tokens
 
         # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-        else:
-            last_loc = [
-                (
-                    r.prefix_indices[-1:]
-                    if len(r.prefix_indices) > 0
-                    else torch.tensor([-1], device=self.device)
-                )
-                for r in self.reqs
-            ]
-            out_cache_loc = self.alloc_paged_token_slots_extend(
-                prefix_lens_tensor,
-                prefix_lens_cpu_tensor,
-                seq_lens_tensor,
-                seq_lens_cpu,
-                torch.cat(last_loc),
-                extend_num_tokens,
-            )
-
-        # Write allocated tokens to req_to_token_pool
-        self.write_cache_indices(
-            req_pool_indices,
-            prefix_lens,
-            seq_lens,
-            extend_lens,
-            out_cache_loc,
-            req_pool_indices_tensor,
-            prefix_lens_tensor,
-            seq_lens_tensor,
-            extend_lens_tensor,
-            [r.prefix_indices for r in reqs],
+        out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
+            self
         )
 
         # Set fields
@@ -1317,12 +1128,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req.req_pool_idx = req_pool_indices[i]
             assert seq_len - pre_len == req.extend_input_len
 
-            if pre_len > 0:
-                if isinstance(self.tree_cache, SWAChunkCache):
-                    self.tree_cache.evict_swa(
-                        req, pre_len, self.model_config.attention_chunk_size
-                    )
-
             # If input_embeds are available, store them
             if req.input_embeds is not None:
                 # If req.input_embeds is already a list, append its content directly
@@ -1414,8 +1219,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
-        self.seq_lens = seq_lens_tensor
-        self.seq_lens_cpu = seq_lens_cpu
         self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
@@ -1439,9 +1242,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
 
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
-        self.extend_num_tokens = extend_num_tokens
-        self.prefix_lens = prefix_lens
-        self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
         if self.model_config.is_encoder_decoder:
@@ -1681,11 +1481,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.output_ids = None
 
         if self.model_config.is_encoder_decoder:
-            locs = self.encoder_lens + self.seq_lens
             self.prepare_encoder_info_decode()
-        else:
-            locs = self.seq_lens.clone()
 
+        # Allocate memory
+        self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
+
+        # Update seq_lens after allocation
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
@@ -1698,28 +1499,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
 
-        # free memory
-        if isinstance(self.tree_cache, SWAChunkCache):
-            for req in self.reqs:
-                self.tree_cache.evict_swa(
-                    req, req.seqlen - 1, self.model_config.attention_chunk_size
-                )
-
-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            self.out_cache_loc = self.alloc_token_slots(bs)
-        else:
-            last_loc = self.req_to_token_pool.req_to_token[
-                self.req_pool_indices, self.seq_lens - 2
-            ]
-            self.out_cache_loc = self.alloc_paged_token_slots_decode(
-                self.seq_lens, self.seq_lens_cpu, last_loc
-            )
-
-        self.req_to_token_pool.write(
-            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
-        )
-
     def filter_batch(
         self,
         chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
@@ -1940,23 +1719,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             return self.token_to_kv_pool_allocator.available_size() >= num_tokens
 
-    def _available_and_evictable_str(self) -> str:
-        if self.is_hybrid:
-            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
-            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
-            full_evictable_size = self.tree_cache.full_evictable_size()
-            swa_evictable_size = self.tree_cache.swa_evictable_size()
-            return (
-                f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
-                f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
-                f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
-                f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
-            )
-        else:
-            available_size = self.token_to_kv_pool_allocator.available_size()
-            evictable_size = self.tree_cache.evictable_size()
-            return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
-
     def __str__(self):
         return (
             f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
@@ -2038,128 +1800,3 @@ class ModelWorkerBatch:
 
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
-
-
-@triton.jit
-def write_req_to_token_pool_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices,
-    prefix_tensors,
-    pre_lens,
-    seq_lens,
-    extend_lens,
-    out_cache_loc,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(0)
-
-    req_pool_index = tl.load(req_pool_indices + pid)
-    pre_len = tl.load(pre_lens + pid)
-    seq_len = tl.load(seq_lens + pid)
-    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
-
-    # write prefix
-    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < pre_len
-        value = tl.load(prefix_tensor + offset, mask=mask)
-        tl.store(
-            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
-            value,
-            mask=mask,
-        )
-
-    # NOTE: This can be slow for large bs
-    cumsum_start = tl.cast(0, tl.int64)
-    for i in range(pid):
-        cumsum_start += tl.load(extend_lens + i)
-
-    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < (seq_len - pre_len)
-        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
-        tl.store(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + offset
-            + pre_len,
-            value,
-            mask=mask,
-        )
-
-
-def get_last_loc(
-    req_to_token: torch.Tensor,
-    req_pool_indices_tensor: torch.Tensor,
-    prefix_lens_tensor: torch.Tensor,
-) -> torch.Tensor:
-    if (
-        global_server_args_dict["attention_backend"] != "ascend"
-        and global_server_args_dict["attention_backend"] != "torch_native"
-    ):
-        impl = get_last_loc_triton
-    else:
-        impl = get_last_loc_torch
-
-    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
-
-
-def get_last_loc_torch(
-    req_to_token: torch.Tensor,
-    req_pool_indices_tensor: torch.Tensor,
-    prefix_lens_tensor: torch.Tensor,
-) -> torch.Tensor:
-    return torch.where(
-        prefix_lens_tensor > 0,
-        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
-        torch.full_like(prefix_lens_tensor, -1),
-    )
-
-
-@triton.jit
-def get_last_loc_kernel(
-    req_to_token,
-    req_pool_indices_tensor,
-    prefix_lens_tensor,
-    result,
-    num_tokens,
-    req_to_token_stride,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
-    mask = offset < num_tokens
-
-    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
-    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
-
-    token_mask = prefix_lens > 0
-    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
-    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
-
-    tl.store(result + offset, tokens, mask=mask)
-
-
-def get_last_loc_triton(
-    req_to_token: torch.Tensor,
-    req_pool_indices_tensor: torch.Tensor,
-    prefix_lens_tensor: torch.Tensor,
-) -> torch.Tensor:
-    BLOCK_SIZE = 256
-    num_tokens = prefix_lens_tensor.shape[0]
-    result = torch.empty_like(prefix_lens_tensor)
-    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
-
-    get_last_loc_kernel[grid](
-        req_to_token,
-        req_pool_indices_tensor,
-        prefix_lens_tensor,
-        result,
-        num_tokens,
-        req_to_token.stride(0),
-        BLOCK_SIZE,
-    )
-    return result
diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py
new file mode 100644
index 000000000..040bc45bf
--- /dev/null
+++ b/python/sglang/srt/mem_cache/common.py
@@ -0,0 +1,479 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import support_triton
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+
+logger = logging.getLogger(__name__)
+
+GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]
+
+global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
+
+
+@triton.jit
+def write_req_to_token_pool_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices,
+    prefix_tensors,
+    pre_lens,
+    seq_lens,
+    extend_lens,
+    out_cache_loc,
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    req_pool_index = tl.load(req_pool_indices + pid)
+    pre_len = tl.load(pre_lens + pid)
+    seq_len = tl.load(seq_lens + pid)
+    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
+
+    # write prefix
+    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < pre_len
+        value = tl.load(prefix_tensor + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
+            value,
+            mask=mask,
+        )
+
+    # NOTE: This can be slow for large bs
+    cumsum_start = tl.cast(0, tl.int64)
+    for i in range(pid):
+        cumsum_start += tl.load(extend_lens + i)
+
+    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < (seq_len - pre_len)
+        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + offset
+            + pre_len,
+            value,
+            mask=mask,
+        )
+
+
+def write_cache_indices(
+    out_cache_loc: torch.Tensor,
+    req_pool_indices_tensor: torch.Tensor,
+    req_pool_indices_cpu: torch.Tensor,
+    prefix_lens_tensor: torch.Tensor,
+    prefix_lens_cpu: torch.Tensor,
+    seq_lens_tensor: torch.Tensor,
+    seq_lens_cpu: torch.Tensor,
+    extend_lens_tensor: torch.Tensor,
+    extend_lens_cpu: torch.Tensor,
+    prefix_tensors: list[torch.Tensor],
+    req_to_token_pool: ReqToTokenPool,
+):
+    if support_triton(global_server_args_dict.get("attention_backend")):
+        prefix_pointers = torch.tensor(
+            [t.data_ptr() for t in prefix_tensors],
+            device=req_to_token_pool.device,
+        )
+        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+        write_req_to_token_pool_triton[(req_pool_indices_tensor.shape[0],)](
+            req_to_token_pool.req_to_token,
+            req_pool_indices_tensor,
+            prefix_pointers,
+            prefix_lens_tensor,
+            seq_lens_tensor,
+            extend_lens_tensor,
+            out_cache_loc,
+            req_to_token_pool.req_to_token.shape[1],
+        )
+    else:
+        pt = 0
+        for i in range(req_pool_indices_cpu.shape[0]):
+            req_idx = req_pool_indices_cpu[i].item()
+            prefix_len = prefix_lens_cpu[i].item()
+            seq_len = seq_lens_cpu[i].item()
+            extend_len = extend_lens_cpu[i].item()
+
+            req_to_token_pool.write(
+                (req_idx, slice(0, prefix_len)),
+                prefix_tensors[i],
+            )
+            req_to_token_pool.write(
+                (req_idx, slice(prefix_len, seq_len)),
+                out_cache_loc[pt : pt + extend_len],
+            )
+            pt += extend_len
+
+
+def get_last_loc(
+    req_to_token: torch.Tensor,
+    req_pool_indices_tensor: torch.Tensor,
+    prefix_lens_tensor: torch.Tensor,
+) -> torch.Tensor:
+    if (
+        global_server_args_dict["attention_backend"] != "ascend"
+        and global_server_args_dict["attention_backend"] != "torch_native"
+    ):
+        impl = get_last_loc_triton
+    else:
+        impl = get_last_loc_torch
+
+    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
+
+
+def get_last_loc_torch(
+    req_to_token: torch.Tensor,
+    req_pool_indices_tensor: torch.Tensor,
+    prefix_lens_tensor: torch.Tensor,
+) -> torch.Tensor:
+    return torch.where(
+        prefix_lens_tensor > 0,
+        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
+        torch.full_like(prefix_lens_tensor, -1),
+    )
+
+
+@triton.jit
+def get_last_loc_kernel(
+    req_to_token,
+    req_pool_indices_tensor,
+    prefix_lens_tensor,
+    result,
+    num_tokens,
+    req_to_token_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+    mask = offset < num_tokens
+
+    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
+    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
+
+    token_mask = prefix_lens > 0
+    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
+    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
+
+    tl.store(result + offset, tokens, mask=mask)
+
+
+def get_last_loc_triton(
+    req_to_token: torch.Tensor,
+    req_pool_indices_tensor: torch.Tensor,
+    prefix_lens_tensor: torch.Tensor,
+) -> torch.Tensor:
+    BLOCK_SIZE = 256
+    num_tokens = prefix_lens_tensor.shape[0]
+    result = torch.empty_like(prefix_lens_tensor)
+    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
+
+    get_last_loc_kernel[grid](
+        req_to_token,
+        req_pool_indices_tensor,
+        prefix_lens_tensor,
+        result,
+        num_tokens,
+        req_to_token.stride(0),
+        BLOCK_SIZE,
+    )
+    return result
+
+
+def alloc_token_slots(
+    tree_cache: BasePrefixCache,
+    num_tokens: int,
+    backup_state: bool = False,
+):
+    allocator = tree_cache.token_to_kv_pool_allocator
+    evict_from_tree_cache(tree_cache, num_tokens)
+
+    state = None
+    if backup_state:
+        state = allocator.backup_state()
+
+    out_cache_loc = allocator.alloc(num_tokens)
+
+    if out_cache_loc is None:
+        error_msg = (
+            f"Out of memory. Try to lower your batch size.\n"
+            f"Try to allocate {num_tokens} tokens.\n"
+            f"{available_and_evictable_str(tree_cache)}"
+        )
+        logger.error(error_msg)
+        if tree_cache is not None:
+            tree_cache.pretty_print()
+        raise RuntimeError(error_msg)
+
+    return (out_cache_loc, state) if backup_state else out_cache_loc
+
+
+def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int):
+    if tree_cache is None:
+        return
+
+    if isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
+        return
+
+    allocator = tree_cache.token_to_kv_pool_allocator
+
+    # Check if this is a hybrid allocator
+    if hasattr(allocator, "full_available_size"):
+        # Hybrid allocator
+        full_available_size = allocator.full_available_size()
+        swa_available_size = allocator.swa_available_size()
+
+        if full_available_size < num_tokens or swa_available_size < num_tokens:
+            full_num_tokens = max(0, num_tokens - full_available_size)
+            swa_num_tokens = max(0, num_tokens - swa_available_size)
+            tree_cache.evict(full_num_tokens, swa_num_tokens)
+    else:
+        # Standard allocator
+        if allocator.available_size() < num_tokens:
+            tree_cache.evict(num_tokens)
+
+
+def alloc_paged_token_slots_extend(
+    tree_cache: BasePrefixCache,
+    prefix_lens: torch.Tensor,
+    prefix_lens_cpu: torch.Tensor,
+    seq_lens: torch.Tensor,
+    seq_lens_cpu: torch.Tensor,
+    last_loc: torch.Tensor,
+    extend_num_tokens: int,
+    backup_state: bool = False,
+):
+    # Over estimate the number of tokens: assume each request needs a new page.
+    allocator = tree_cache.token_to_kv_pool_allocator
+    num_tokens = extend_num_tokens + len(seq_lens_cpu) * allocator.page_size
+    evict_from_tree_cache(tree_cache, num_tokens)
+
+    state = None
+    if backup_state:
+        state = allocator.backup_state()
+
+    out_cache_loc = allocator.alloc_extend(
+        prefix_lens,
+        prefix_lens_cpu,
+        seq_lens,
+        seq_lens_cpu,
+        last_loc,
+        extend_num_tokens,
+    )
+
+    if out_cache_loc is None:
+        error_msg = (
+            f"Prefill out of memory. Try to lower your batch size.\n"
+            f"Try to allocate {extend_num_tokens} tokens.\n"
+            f"{available_and_evictable_str(tree_cache)}"
+        )
+        logger.error(error_msg)
+        if tree_cache is not None:
+            tree_cache.pretty_print()
+        raise RuntimeError(error_msg)
+
+    return (out_cache_loc, state) if backup_state else out_cache_loc
+
+
+def alloc_req_slots(
+    req_to_token_pool: ReqToTokenPool,
+    num_reqs: int,
+    reqs: list[Req] | None,
+) -> list[int]:
+    """Allocate request slots from the pool."""
+    if isinstance(req_to_token_pool, HybridReqToTokenPool):
+        req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)
+    else:
+        req_pool_indices = req_to_token_pool.alloc(num_reqs)
+
+    if req_pool_indices is None:
+        raise RuntimeError(
+            "alloc_req_slots runs out of memory. "
+            "Please set a smaller number for `--max-running-requests`. "
+            f"{req_to_token_pool.available_size()=}, "
+            f"{num_reqs=}, "
+        )
+    return req_pool_indices
+
+
+def alloc_for_extend(
+    batch: ScheduleBatch,
+) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
+    """
+    Allocate KV cache for an extend batch and write it to req_to_token_pool.
+
+    Returns:
+        out_cache_loc: allocated cache locations
+        req_pool_indices_device: request pool indices as a device tensor
+        req_pool_indices: request pool indices as a Python list
+    """
+    # free out-of-window swa tokens
+    if isinstance(batch.tree_cache, SWAChunkCache):
+        for req, pre_len in zip(batch.reqs, batch.prefix_lens):
+            batch.tree_cache.evict_swa(
+                req, pre_len, batch.model_config.attention_chunk_size
+            )
+
+    bs = len(batch.reqs)
+    prefix_tensors = [r.prefix_indices for r in batch.reqs]
+
+    # Create tensors for allocation
+    prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64)
+    extend_lens_cpu = torch.tensor(batch.extend_lens, dtype=torch.int64)
+    prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True)
+    extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True)
+
+    # Allocate req slots
+    req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, batch.reqs)
+    req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
+    req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True)
+
+    # Allocate KV cache (throws exception on failure)
+    if batch.tree_cache.page_size == 1:
+        out_cache_loc = alloc_token_slots(batch.tree_cache, batch.extend_num_tokens)
+    else:
+        # Paged allocation - build last_loc
+        last_loc = [
+            (
+                t[-1:]
+                if len(t) > 0
+                else torch.tensor([-1], device=batch.tree_cache.device)
+            )
+            for t in prefix_tensors
+        ]
+        out_cache_loc = alloc_paged_token_slots_extend(
+            tree_cache=batch.tree_cache,
+            prefix_lens=prefix_lens_device,
+            prefix_lens_cpu=prefix_lens_cpu,
+            seq_lens=batch.seq_lens,
+            seq_lens_cpu=batch.seq_lens_cpu,
+            last_loc=torch.cat(last_loc),
+            extend_num_tokens=batch.extend_num_tokens,
+        )
+
+    # Write to req_to_token_pool
+    write_cache_indices(
+        out_cache_loc,
+        req_pool_indices_device,
+        req_pool_indices_cpu,
+        prefix_lens_device,
+        prefix_lens_cpu,
+        batch.seq_lens,
+        batch.seq_lens_cpu,
+        extend_lens_device,
+        extend_lens_cpu,
+        prefix_tensors,
+        batch.req_to_token_pool,
+    )
+
+    return out_cache_loc, req_pool_indices_device, req_pool_indices
+
+
+def alloc_paged_token_slots_decode(
+    tree_cache: BasePrefixCache,
+    seq_lens: torch.Tensor,
+    seq_lens_cpu: torch.Tensor,
+    last_loc: torch.Tensor,
+    token_per_req: int = 1,
+) -> torch.Tensor:
+    """Allocate paged KV cache for decode batch."""
+    allocator = tree_cache.token_to_kv_pool_allocator
+    # Over estimate the number of tokens: assume each request needs a new page.
+    num_tokens = len(seq_lens) * allocator.page_size
+    evict_from_tree_cache(tree_cache, num_tokens)
+
+    out_cache_loc = allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc)
+
+    if out_cache_loc is None:
+        error_msg = (
+            f"Decode out of memory. Try to lower your batch size.\n"
+            f"Try to allocate {len(seq_lens) * token_per_req} tokens.\n"
+            f"{available_and_evictable_str(tree_cache)}"
+        )
+        logger.error(error_msg)
+        if tree_cache is not None:
+            tree_cache.pretty_print()
+        raise RuntimeError(error_msg)
+
+    return out_cache_loc
+
+
+def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor:
+    """
+    Allocate KV cache for a decode batch and write it to req_to_token_pool.
+
+    Returns:
+        out_cache_loc: allocated cache locations
+    """
+    if isinstance(batch.tree_cache, SWAChunkCache):
+        for req in batch.reqs:
+            batch.tree_cache.evict_swa(
+                req, req.seqlen - 1, batch.model_config.attention_chunk_size
+            )
+
+    bs = batch.seq_lens.shape[0]
+
+    if batch.tree_cache.page_size == 1:
+        # Non-paged allocation
+        out_cache_loc = alloc_token_slots(batch.tree_cache, bs * token_per_req)
+    else:
+        # Paged allocation
+        last_loc = batch.req_to_token_pool.req_to_token[
+            batch.req_pool_indices, batch.seq_lens - 1
+        ]
+        seq_lens_next = batch.seq_lens + token_per_req
+        out_cache_loc = alloc_paged_token_slots_decode(
+            tree_cache=batch.tree_cache,
+            seq_lens=seq_lens_next,
+            seq_lens_cpu=batch.seq_lens_cpu + token_per_req,
+            last_loc=last_loc,
+            token_per_req=token_per_req,
+        )
+
+    # Write to req_to_token_pool
+    if batch.model_config.is_encoder_decoder:
+        locs = batch.encoder_lens + batch.seq_lens
+    else:
+        locs = batch.seq_lens.clone()
+
+    batch.req_to_token_pool.write(
+        (batch.req_pool_indices, locs), out_cache_loc.to(torch.int32)
+    )
+
+    return out_cache_loc
+
+
+def available_and_evictable_str(tree_cache) -> str:
+    token_to_kv_pool_allocator = tree_cache.token_to_kv_pool_allocator
+    if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
+        full_available_size = token_to_kv_pool_allocator.full_available_size()
+        swa_available_size = token_to_kv_pool_allocator.swa_available_size()
+        full_evictable_size = tree_cache.full_evictable_size()
+        swa_evictable_size = tree_cache.swa_evictable_size()
+        return (
+            f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
+            f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
+            f"Full LRU list evictable size: {tree_cache.full_lru_list_evictable_size()}\n"
+            f"SWA LRU list evictable size: {tree_cache.swa_lru_list_evictable_size()}\n"
+        )
+    else:
+        available_size = token_to_kv_pool_allocator.available_size()
+        evictable_size = tree_cache.evictable_size()
+        return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
index 5d8c920c4..46ecc1b32 100644
--- a/python/sglang/srt/speculative/eagle_info.py
+++ b/python/sglang/srt/speculative/eagle_info.py
@@ -10,12 +10,13 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import apply_custom_logit_processor
-from sglang.srt.managers.schedule_batch import (
-    ScheduleBatch,
-    get_last_loc,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.common import (
+    alloc_paged_token_slots_extend,
+    alloc_token_slots,
+    get_last_loc,
+)
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
 from sglang.srt.speculative.spec_utils import (
@@ -100,7 +101,10 @@ class EagleVerifyInput(SpecInput):
         batch.input_ids = self.draft_token
 
         if page_size == 1:
-            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            batch.out_cache_loc = alloc_token_slots(
+                batch.tree_cache,
+                len(batch.input_ids),
+            )
             end_offset = batch.seq_lens + self.draft_token_num
         else:
             prefix_lens = batch.seq_lens
@@ -112,7 +116,8 @@ class EagleVerifyInput(SpecInput):
                 batch.req_pool_indices,
                 prefix_lens,
             )
-            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+            batch.out_cache_loc = alloc_paged_token_slots_extend(
+                batch.tree_cache,
                 prefix_lens,
                 prefix_lens_cpu,
                 end_offset,
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 430e38eb4..162ce53ec 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -14,13 +14,14 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.schedule_batch import (
-    ScheduleBatch,
-    get_last_loc,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.mem_cache.common import (
+    alloc_paged_token_slots_extend,
+    alloc_token_slots,
+    get_last_loc,
+)
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
@@ -541,8 +542,10 @@ class EAGLEWorker(TpModelWorker):
         #   [ topk 0 ] [ topk 1 ]
         # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
         if self.page_size == 1:
-            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
-                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
+            out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots(
+                batch.tree_cache,
+                num_seqs * self.speculative_num_steps * self.topk,
+                backup_state=True,
             )
         else:
             if self.topk == 1:
@@ -601,7 +604,8 @@ class EAGLEWorker(TpModelWorker):
             extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
 
             out_cache_loc, token_to_kv_pool_state_backup = (
-                batch.alloc_paged_token_slots_extend(
+                alloc_paged_token_slots_extend(
+                    batch.tree_cache,
                     prefix_lens,
                     prefix_lens_cpu,
                     seq_lens,
diff --git a/python/sglang/srt/speculative/ngram_info.py b/python/sglang/srt/speculative/ngram_info.py
index 345fcbd66..ce4557b89 100644
--- a/python/sglang/srt/speculative/ngram_info.py
+++ b/python/sglang/srt/speculative/ngram_info.py
@@ -16,10 +16,11 @@ import torch.nn.functional as F
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import apply_custom_logit_processor
-from sglang.srt.managers.schedule_batch import (
-    ScheduleBatch,
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.mem_cache.common import (
+    alloc_paged_token_slots_extend,
+    alloc_token_slots,
     get_last_loc,
-    global_server_args_dict,
 )
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
@@ -74,7 +75,10 @@ class NgramVerifyInput(SpecInput):
         batch.input_ids = self.draft_token
 
         if page_size == 1:
-            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            batch.out_cache_loc = alloc_token_slots(
+                batch.tree_cache,
+                len(batch.input_ids),
+            )
             end_offset = batch.seq_lens + self.draft_token_num
         else:
             # TODO(lsyin): add prefix lens cpu here to support page size > 1
@@ -87,7 +91,8 @@ class NgramVerifyInput(SpecInput):
                 batch.req_pool_indices,
                 prefix_lens,
             )
-            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+            batch.out_cache_loc = alloc_paged_token_slots_extend(
+                batch.tree_cache,
                 prefix_lens,
                 prefix_lens_cpu,
                 end_offset,
diff --git a/test/srt/test_forward_split_prefill.py b/test/srt/test_forward_split_prefill.py
index 314e35ec9..4ca3c12fe 100644
--- a/test/srt/test_forward_split_prefill.py
+++ b/test/srt/test_forward_split_prefill.py
@@ -8,6 +8,7 @@ python3 test_forward_split_prefill.py
 """
 
 import unittest
+from types import SimpleNamespace
 
 import numpy as np
 import torch
@@ -95,11 +96,18 @@ class TestForwardSplitPrefill(CustomTestCase):
             req.logprob_start_len = len(req.origin_input_ids) - 1
             reqs.append(req)
 
+        # Create dummy tree_cache for tests (no prefix caching, just allocation)
+        dummy_tree_cache = SimpleNamespace(
+            page_size=1,
+            device=self.model_runner.device,
+            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
+        )
+
         batch = ScheduleBatch.init_new(
             reqs=reqs,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
-            tree_cache=None,
+            tree_cache=dummy_tree_cache,
             model_config=self.model_config,
            enable_overlap=False,
             spec_algorithm=SpeculativeAlgorithm.NONE,
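--
Usage sketch (reviewer aid, not part of the patch): the refactor moves allocation off ScheduleBatch into module-level helpers in sglang.srt.mem_cache.common. A minimal sketch of the new call pattern, assuming a ScheduleBatch whose scheduling fields are already populated, as prepare_for_extend()/prepare_for_decode() do above; the wrapper names allocate_extend, allocate_decode, and make_dummy_tree_cache are hypothetical, introduced only for illustration:

    from types import SimpleNamespace

    from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend


    def make_dummy_tree_cache(model_runner):
        # Mirrors the bench_one_batch.py hunk: benchmarks/tests that do no
        # prefix caching pass a SimpleNamespace exposing only the fields the
        # helpers read (page_size, device, token_to_kv_pool_allocator).
        return SimpleNamespace(
            page_size=1,
            device=model_runner.device,
            token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        )


    def allocate_extend(batch):
        # prepare_for_extend() must first set batch.prefix_lens,
        # batch.extend_lens, batch.seq_lens, batch.seq_lens_cpu, and
        # batch.extend_num_tokens. alloc_for_extend() then allocates req slots
        # plus KV cache, writes req_to_token_pool, and raises RuntimeError on
        # out-of-memory.
        out_cache_loc, req_pool_indices_device, req_pool_indices = alloc_for_extend(
            batch
        )
        batch.out_cache_loc = out_cache_loc
        batch.req_pool_indices = req_pool_indices_device
        return req_pool_indices


    def allocate_decode(batch):
        # One new token per request; the page_size == 1 vs. paged split and the
        # req_to_token_pool write happen inside alloc_for_decode().
        batch.out_cache_loc = alloc_for_decode(batch, token_per_req=1)

The design choice worth noting: the helpers take tree_cache (or the whole batch) rather than living on ScheduleBatch, so speculative-decoding call sites (eagle_info.py, eagle_worker.py, ngram_info.py) share one allocation/eviction path with the scheduler.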