Separate allocation logic from scheduler (#11313)

This commit is contained in:
cctry
2025-10-10 17:38:54 -07:00
committed by GitHub
parent 9aa4502d11
commit b36afed4a7
7 changed files with 545 additions and 399 deletions

View File

@@ -45,8 +45,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
import numpy as np
import torch
import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
@@ -62,6 +60,7 @@ from sglang.srt.mem_cache.allocator import (
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
@@ -70,7 +69,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton
from sglang.srt.utils import flatten_nested_list
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
@@ -1001,158 +1000,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def is_empty(self):
return len(self.reqs) == 0
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
    """Reserve ``num_reqs`` request slots from ``self.req_to_token_pool``.

    Args:
        num_reqs: Number of request slots to allocate.
        reqs: Forwarded to the pool's ``alloc`` only when the pool is a
            ``HybridReqToTokenPool``; ignored otherwise.

    Returns:
        Whatever the pool's ``alloc`` returned (the allocated slot indices).

    Raises:
        RuntimeError: If the pool's ``alloc`` returned ``None``.
    """
    pool = self.req_to_token_pool
    # Only the hybrid pool's alloc() is given the request objects.
    if isinstance(pool, HybridReqToTokenPool):
        alloc_args = (num_reqs, reqs)
    else:
        alloc_args = (num_reqs,)
    req_pool_indices = pool.alloc(*alloc_args)

    if req_pool_indices is not None:
        return req_pool_indices

    raise RuntimeError(
        "alloc_req_slots runs out of memory. "
        "Please set a smaller number for `--max-running-requests`. "
        f"{self.req_to_token_pool.available_size()=}, "
        f"{num_reqs=}, "
    )
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
    """Allocate ``num_tokens`` KV-cache slots from ``self.token_to_kv_pool_allocator``.

    ``self._evict_tree_cache_if_needed(num_tokens)`` is called first. When
    ``backup_state`` is true, the allocator's ``backup_state()`` snapshot is
    taken before allocating.

    Returns:
        ``out_cache_loc`` alone, or ``(out_cache_loc, state)`` when
        ``backup_state`` is true.

    Raises:
        RuntimeError: If the allocator's ``alloc`` returned ``None``. The
            message is also logged and the tree cache (if any) is printed.
    """
    self._evict_tree_cache_if_needed(num_tokens)

    allocator = self.token_to_kv_pool_allocator
    state = allocator.backup_state() if backup_state else None
    out_cache_loc = allocator.alloc(num_tokens)

    if out_cache_loc is None:
        if self.forward_mode.is_extend():
            phase_str = "Prefill"
        else:
            phase_str = "Decode"
        error_msg = (
            f"{phase_str} out of memory. Try to lower your batch size.\n"
            f"Try to allocate {num_tokens} tokens.\n"
            f"{self._available_and_evictable_str()}"
        )
        logger.error(error_msg)
        if self.tree_cache is not None:
            self.tree_cache.pretty_print()
        raise RuntimeError(error_msg)

    return (out_cache_loc, state) if backup_state else out_cache_loc
def alloc_paged_token_slots_extend(
    self,
    prefix_lens: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    extend_num_tokens: int,
    backup_state: bool = False,
):
    """Allocate paged KV-cache slots for an extend step.

    Eviction is requested for an over-estimate of the demand:
    ``extend_num_tokens`` plus one full page per entry of ``seq_lens_cpu``.
    The actual allocation is delegated to the allocator's ``alloc_extend``.

    Returns:
        ``out_cache_loc`` alone, or ``(out_cache_loc, state)`` when
        ``backup_state`` is true, where ``state`` is the allocator's
        ``backup_state()`` snapshot taken before allocating.

    Raises:
        RuntimeError: If ``alloc_extend`` returned ``None``.
    """
    # Upper bound: pretend every request spills into a brand-new page.
    page_size = self.token_to_kv_pool_allocator.page_size
    worst_case_tokens = extend_num_tokens + len(seq_lens_cpu) * page_size
    self._evict_tree_cache_if_needed(worst_case_tokens)

    allocator = self.token_to_kv_pool_allocator
    state = allocator.backup_state() if backup_state else None
    out_cache_loc = allocator.alloc_extend(
        prefix_lens,
        prefix_lens_cpu,
        seq_lens,
        seq_lens_cpu,
        last_loc,
        extend_num_tokens,
    )

    if out_cache_loc is None:
        error_msg = (
            f"Prefill out of memory. Try to lower your batch size.\n"
            f"Try to allocate {extend_num_tokens} tokens.\n"
            f"{self._available_and_evictable_str()}"
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    if not backup_state:
        return out_cache_loc
    return out_cache_loc, state
def alloc_paged_token_slots_decode(
    self,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    backup_state: bool = False,
):
    """Allocate paged KV-cache slots for a decode step.

    Eviction is requested for an over-estimate of one full page per entry of
    ``seq_lens``. The actual allocation is delegated to the allocator's
    ``alloc_decode``.

    Returns:
        ``out_cache_loc`` alone, or ``(out_cache_loc, state)`` when
        ``backup_state`` is true, where ``state`` is the allocator's
        ``backup_state()`` snapshot taken before allocating.

    Raises:
        RuntimeError: If ``alloc_decode`` returned ``None``.
    """
    # Upper bound: pretend every request needs a brand-new page.
    worst_case_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
    self._evict_tree_cache_if_needed(worst_case_tokens)

    allocator = self.token_to_kv_pool_allocator
    state = allocator.backup_state() if backup_state else None
    out_cache_loc = allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc)

    if out_cache_loc is None:
        error_msg = (
            f"Decode out of memory. Try to lower your batch size.\n"
            f"Try to allocate {len(seq_lens)} tokens.\n"
            f"{self._available_and_evictable_str()}"
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    if not backup_state:
        return out_cache_loc
    return out_cache_loc, state
def write_cache_indices(
    self,
    req_pool_indices: List[int],
    prefix_lens: List[int],
    seq_lens: List[int],
    extend_lens: List[int],
    out_cache_loc: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
    seq_lens_tensor: torch.Tensor,
    extend_lens_tensor: torch.Tensor,
    prefix_tensors: list[torch.Tensor],
):
    """Record prefix and newly allocated cache locations in ``req_to_token_pool``.

    For each request, columns ``[0, prefix_len)`` of its row receive the
    matching entry of ``prefix_tensors`` and columns ``[prefix_len, seq_len)``
    receive the next ``extend_len`` entries of ``out_cache_loc``.

    The list arguments drive the Python fallback; the ``*_tensor`` arguments
    drive the Triton kernel, which is used when ``support_triton`` accepts the
    configured attention backend.
    """
    if not support_triton(global_server_args_dict.get("attention_backend")):
        # Fallback: one pair of pool writes per request, walking out_cache_loc.
        cursor = 0
        for i, pool_idx in enumerate(req_pool_indices):
            pre_len = prefix_lens[i]
            self.req_to_token_pool.write(
                (pool_idx, slice(0, pre_len)),
                prefix_tensors[i],
            )
            self.req_to_token_pool.write(
                (pool_idx, slice(pre_len, seq_lens[i])),
                out_cache_loc[cursor : cursor + extend_lens[i]],
            )
            cursor += extend_lens[i]
        return

    # The kernel receives raw device addresses of the per-request prefix tensors.
    prefix_pointers = torch.tensor(
        [t.data_ptr() for t in prefix_tensors], device=self.device
    )
    launch_grid = (len(req_pool_indices),)
    # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
    write_req_to_token_pool_triton[launch_grid](
        self.req_to_token_pool.req_to_token,
        req_pool_indices_tensor,
        prefix_pointers,
        prefix_lens_tensor,
        seq_lens_tensor,
        extend_lens_tensor,
        out_cache_loc,
        self.req_to_token_pool.req_to_token.shape[1],
    )
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []
self.encoder_cached = []
@@ -1253,10 +1100,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
prefix_lens_tensor = torch.tensor(
prefix_lens, dtype=torch.int64, device=self.device
)
prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
token_type_ids_tensor = None
if len(token_type_ids) > 0:
@@ -1264,48 +1107,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
sum(token_type_ids, []), dtype=torch.int64
).to(self.device, non_blocking=True)
extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
# Allocate req slots
bs = len(self.reqs)
req_pool_indices = self.alloc_req_slots(bs, self.reqs)
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
# Set batch fields needed by alloc_for_extend
self.prefix_lens = prefix_lens
self.extend_lens = extend_lens
self.seq_lens = seq_lens_tensor
self.seq_lens_cpu = seq_lens_cpu
self.extend_num_tokens = extend_num_tokens
# Allocate memory
if self.token_to_kv_pool_allocator.page_size == 1:
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
else:
last_loc = [
(
r.prefix_indices[-1:]
if len(r.prefix_indices) > 0
else torch.tensor([-1], device=self.device)
)
for r in self.reqs
]
out_cache_loc = self.alloc_paged_token_slots_extend(
prefix_lens_tensor,
prefix_lens_cpu_tensor,
seq_lens_tensor,
seq_lens_cpu,
torch.cat(last_loc),
extend_num_tokens,
)
# Write allocated tokens to req_to_token_pool
self.write_cache_indices(
req_pool_indices,
prefix_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_pool_indices_tensor,
prefix_lens_tensor,
seq_lens_tensor,
extend_lens_tensor,
[r.prefix_indices for r in reqs],
out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
self
)
# Set fields
@@ -1317,12 +1128,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
req.req_pool_idx = req_pool_indices[i]
assert seq_len - pre_len == req.extend_input_len
if pre_len > 0:
if isinstance(self.tree_cache, SWAChunkCache):
self.tree_cache.evict_swa(
req, pre_len, self.model_config.attention_chunk_size
)
# If input_embeds are available, store them
if req.input_embeds is not None:
# If req.input_embeds is already a list, append its content directly
@@ -1414,8 +1219,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.input_ids = input_ids_tensor
self.req_pool_indices = req_pool_indices_tensor
self.seq_lens = seq_lens_tensor
self.seq_lens_cpu = seq_lens_cpu
self.orig_seq_lens = orig_seq_lens_tensor
self.out_cache_loc = out_cache_loc
self.input_embeds = (
@@ -1439,9 +1242,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
self.extend_num_tokens = extend_num_tokens
self.prefix_lens = prefix_lens
self.extend_lens = extend_lens
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
if self.model_config.is_encoder_decoder:
@@ -1681,11 +1481,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.output_ids = None
if self.model_config.is_encoder_decoder:
locs = self.encoder_lens + self.seq_lens
self.prepare_encoder_info_decode()
else:
locs = self.seq_lens.clone()
# Allocate memory
self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
# Update seq_lens after allocation
if self.enable_overlap:
# Do not use in-place operations in the overlap mode
self.seq_lens = self.seq_lens + 1
@@ -1698,28 +1499,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.orig_seq_lens.add_(1)
self.seq_lens_sum += bs
# free memory
if isinstance(self.tree_cache, SWAChunkCache):
for req in self.reqs:
self.tree_cache.evict_swa(
req, req.seqlen - 1, self.model_config.attention_chunk_size
)
# Allocate memory
if self.token_to_kv_pool_allocator.page_size == 1:
self.out_cache_loc = self.alloc_token_slots(bs)
else:
last_loc = self.req_to_token_pool.req_to_token[
self.req_pool_indices, self.seq_lens - 2
]
self.out_cache_loc = self.alloc_paged_token_slots_decode(
self.seq_lens, self.seq_lens_cpu, last_loc
)
self.req_to_token_pool.write(
(self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
)
def filter_batch(
self,
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
@@ -1940,23 +1719,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
else:
return self.token_to_kv_pool_allocator.available_size() >= num_tokens
def _available_and_evictable_str(self) -> str:
    """Build a human-readable summary of available and evictable token counts.

    In the non-hybrid case a single line is returned; when ``self.is_hybrid``
    is truthy, separate lines report the full and SWA figures plus the two
    LRU-list evictable sizes taken from the tree cache.
    """
    if not self.is_hybrid:
        available_size = self.token_to_kv_pool_allocator.available_size()
        evictable_size = self.tree_cache.evictable_size()
        return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"

    # NOTE: the local names below appear verbatim in the output via `{name=}`.
    allocator = self.token_to_kv_pool_allocator
    full_available_size = allocator.full_available_size()
    swa_available_size = allocator.swa_available_size()
    full_evictable_size = self.tree_cache.full_evictable_size()
    swa_evictable_size = self.tree_cache.swa_evictable_size()
    return (
        f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
        f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
        f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
        f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
    )
def __str__(self):
return (
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
@@ -2038,128 +1800,3 @@ class ModelWorkerBatch:
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    prefix_tensors,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    """Fill one row of the req-to-token table with prefix and new cache slots.

    The program instance with id ``pid`` handles the ``pid``-th entry of
    ``req_pool_indices``. It copies ``pre_len`` values from that request's
    prefix tensor (whose address is read from ``prefix_tensors`` and treated
    as a pointer to int64) into the start of the row, then copies
    ``seq_len - pre_len`` values from ``out_cache_loc`` right after them.
    """
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)
    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
    row_start = req_pool_index * req_to_token_ptr_stride

    # Copy the cached prefix, BLOCK_SIZE entries at a time.
    for blk in range(tl.cdiv(pre_len, BLOCK_SIZE)):
        cols = tl.arange(0, BLOCK_SIZE) + blk * BLOCK_SIZE
        in_prefix = cols < pre_len
        cached = tl.load(prefix_tensor + cols, mask=in_prefix)
        tl.store(req_to_token_ptr + row_start + cols, cached, mask=in_prefix)

    # Start of this request's slice of out_cache_loc: the sum of the extend
    # lengths of all earlier requests.
    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for prev in range(pid):
        cumsum_start += tl.load(extend_lens + prev)

    # Copy the newly allocated locations right behind the prefix.
    new_len = seq_len - pre_len
    for blk in range(tl.cdiv(new_len, BLOCK_SIZE)):
        cols = tl.arange(0, BLOCK_SIZE) + blk * BLOCK_SIZE
        in_new = cols < new_len
        fresh = tl.load(out_cache_loc + cumsum_start + cols, mask=in_new)
        tl.store(
            req_to_token_ptr + row_start + cols + pre_len,
            fresh,
            mask=in_new,
        )
def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Return the last prefix location per request, picking a backend impl.

    The pure-torch implementation is used when the configured attention
    backend is ``"ascend"`` or ``"torch_native"``; every other backend goes
    through the Triton implementation.
    """
    backend = global_server_args_dict["attention_backend"]
    if backend == "ascend" or backend == "torch_native":
        return get_last_loc_torch(
            req_to_token, req_pool_indices_tensor, prefix_lens_tensor
        )
    return get_last_loc_triton(
        req_to_token, req_pool_indices_tensor, prefix_lens_tensor
    )
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Look up ``req_to_token[pool_idx, prefix_len - 1]`` for every request.

    Entries whose prefix length is not positive yield ``-1`` instead of a
    table value.
    """
    has_prefix = prefix_lens_tensor > 0
    # For a zero-length prefix the index is -1; the gathered value is then
    # discarded by the torch.where below.
    last_values = req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1]
    no_prefix_marker = torch.full_like(prefix_lens_tensor, -1)
    return torch.where(has_prefix, last_values, no_prefix_marker)
@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Write ``req_to_token[pool_idx, prefix_len - 1]`` (or ``-1``) into ``result``.

    Each program instance handles ``BLOCK_SIZE`` consecutive entries; entries
    at or beyond ``num_tokens`` are masked out. Where the prefix length is not
    positive, ``-1`` is stored instead of a table value.
    """
    pid = tl.program_id(0)
    lanes = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = lanes < num_tokens

    # Out-of-range lanes read 0, so they also fail the `> 0` test below.
    pre_lens = tl.load(prefix_lens_tensor + lanes, mask=in_range, other=0)
    pool_rows = tl.load(req_pool_indices_tensor + lanes, mask=in_range, other=0)

    has_prefix = pre_lens > 0
    flat_index = pool_rows * req_to_token_stride + (pre_lens - 1)
    last_tokens = tl.load(req_to_token + flat_index, mask=has_prefix, other=-1)

    tl.store(result + lanes, last_tokens, mask=in_range)
def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Launch ``get_last_loc_kernel`` and return its output tensor.

    The result has the same shape and dtype as ``prefix_lens_tensor``
    (allocated with ``torch.empty_like``).
    """
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)

    # One program per BLOCK_SIZE entries, rounded up.
    num_programs = triton.cdiv(num_tokens, BLOCK_SIZE)
    get_last_loc_kernel[(num_programs,)](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result