Separate allocation logic from scheduler (#11313)
This commit is contained in:
@@ -51,6 +51,7 @@ import logging
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -257,11 +258,18 @@ def prepare_synthetic_inputs_for_latency_test(
|
|||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def extend(reqs, model_runner):
|
def extend(reqs, model_runner):
|
||||||
|
# Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
|
||||||
|
dummy_tree_cache = SimpleNamespace(
|
||||||
|
page_size=1,
|
||||||
|
device=model_runner.device,
|
||||||
|
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
|
||||||
|
)
|
||||||
|
|
||||||
batch = ScheduleBatch.init_new(
|
batch = ScheduleBatch.init_new(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
req_to_token_pool=model_runner.req_to_token_pool,
|
req_to_token_pool=model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
|
||||||
tree_cache=None,
|
tree_cache=dummy_tree_cache,
|
||||||
model_config=model_runner.model_config,
|
model_config=model_runner.model_config,
|
||||||
enable_overlap=False,
|
enable_overlap=False,
|
||||||
spec_algorithm=SpeculativeAlgorithm.NONE,
|
spec_algorithm=SpeculativeAlgorithm.NONE,
|
||||||
|
|||||||
@@ -45,8 +45,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||||
@@ -62,6 +60,7 @@ from sglang.srt.mem_cache.allocator import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||||
|
from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
|
||||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||||
@@ -70,7 +69,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
|
|||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import flatten_nested_list, support_triton
|
from sglang.srt.utils import flatten_nested_list
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
@@ -1001,158 +1000,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
return len(self.reqs) == 0
|
return len(self.reqs) == 0
|
||||||
|
|
||||||
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
|
|
||||||
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
|
|
||||||
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
|
|
||||||
else:
|
|
||||||
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
|
||||||
if req_pool_indices is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"alloc_req_slots runs out of memory. "
|
|
||||||
"Please set a smaller number for `--max-running-requests`. "
|
|
||||||
f"{self.req_to_token_pool.available_size()=}, "
|
|
||||||
f"{num_reqs=}, "
|
|
||||||
)
|
|
||||||
return req_pool_indices
|
|
||||||
|
|
||||||
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
|
|
||||||
self._evict_tree_cache_if_needed(num_tokens)
|
|
||||||
|
|
||||||
if backup_state:
|
|
||||||
state = self.token_to_kv_pool_allocator.backup_state()
|
|
||||||
|
|
||||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
|
||||||
if out_cache_loc is None:
|
|
||||||
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
|
||||||
error_msg = (
|
|
||||||
f"{phase_str} out of memory. Try to lower your batch size.\n"
|
|
||||||
f"Try to allocate {num_tokens} tokens.\n"
|
|
||||||
f"{self._available_and_evictable_str()}"
|
|
||||||
)
|
|
||||||
logger.error(error_msg)
|
|
||||||
if self.tree_cache is not None:
|
|
||||||
self.tree_cache.pretty_print()
|
|
||||||
raise RuntimeError(error_msg)
|
|
||||||
|
|
||||||
if backup_state:
|
|
||||||
return out_cache_loc, state
|
|
||||||
else:
|
|
||||||
return out_cache_loc
|
|
||||||
|
|
||||||
def alloc_paged_token_slots_extend(
|
|
||||||
self,
|
|
||||||
prefix_lens: torch.Tensor,
|
|
||||||
prefix_lens_cpu: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
seq_lens_cpu: torch.Tensor,
|
|
||||||
last_loc: torch.Tensor,
|
|
||||||
extend_num_tokens: int,
|
|
||||||
backup_state: bool = False,
|
|
||||||
):
|
|
||||||
# Over estimate the number of tokens: assume each request needs a new page.
|
|
||||||
num_tokens = (
|
|
||||||
extend_num_tokens
|
|
||||||
+ len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
|
|
||||||
)
|
|
||||||
self._evict_tree_cache_if_needed(num_tokens)
|
|
||||||
|
|
||||||
if backup_state:
|
|
||||||
state = self.token_to_kv_pool_allocator.backup_state()
|
|
||||||
|
|
||||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
|
||||||
prefix_lens,
|
|
||||||
prefix_lens_cpu,
|
|
||||||
seq_lens,
|
|
||||||
seq_lens_cpu,
|
|
||||||
last_loc,
|
|
||||||
extend_num_tokens,
|
|
||||||
)
|
|
||||||
if out_cache_loc is None:
|
|
||||||
error_msg = (
|
|
||||||
f"Prefill out of memory. Try to lower your batch size.\n"
|
|
||||||
f"Try to allocate {extend_num_tokens} tokens.\n"
|
|
||||||
f"{self._available_and_evictable_str()}"
|
|
||||||
)
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise RuntimeError(error_msg)
|
|
||||||
|
|
||||||
if backup_state:
|
|
||||||
return out_cache_loc, state
|
|
||||||
else:
|
|
||||||
return out_cache_loc
|
|
||||||
|
|
||||||
def alloc_paged_token_slots_decode(
|
|
||||||
self,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
seq_lens_cpu: torch.Tensor,
|
|
||||||
last_loc: torch.Tensor,
|
|
||||||
backup_state: bool = False,
|
|
||||||
):
|
|
||||||
# Over estimate the number of tokens: assume each request needs a new page.
|
|
||||||
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
|
||||||
self._evict_tree_cache_if_needed(num_tokens)
|
|
||||||
|
|
||||||
if backup_state:
|
|
||||||
state = self.token_to_kv_pool_allocator.backup_state()
|
|
||||||
|
|
||||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
|
|
||||||
seq_lens, seq_lens_cpu, last_loc
|
|
||||||
)
|
|
||||||
if out_cache_loc is None:
|
|
||||||
error_msg = (
|
|
||||||
f"Decode out of memory. Try to lower your batch size.\n"
|
|
||||||
f"Try to allocate {len(seq_lens)} tokens.\n"
|
|
||||||
f"{self._available_and_evictable_str()}"
|
|
||||||
)
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise RuntimeError(error_msg)
|
|
||||||
|
|
||||||
if backup_state:
|
|
||||||
return out_cache_loc, state
|
|
||||||
else:
|
|
||||||
return out_cache_loc
|
|
||||||
|
|
||||||
def write_cache_indices(
|
|
||||||
self,
|
|
||||||
req_pool_indices: List[int],
|
|
||||||
prefix_lens: List[int],
|
|
||||||
seq_lens: List[int],
|
|
||||||
extend_lens: List[int],
|
|
||||||
out_cache_loc: torch.Tensor,
|
|
||||||
req_pool_indices_tensor: torch.Tensor,
|
|
||||||
prefix_lens_tensor: torch.Tensor,
|
|
||||||
seq_lens_tensor: torch.Tensor,
|
|
||||||
extend_lens_tensor: torch.Tensor,
|
|
||||||
prefix_tensors: list[torch.Tensor],
|
|
||||||
):
|
|
||||||
if support_triton(global_server_args_dict.get("attention_backend")):
|
|
||||||
prefix_pointers = torch.tensor(
|
|
||||||
[t.data_ptr() for t in prefix_tensors], device=self.device
|
|
||||||
)
|
|
||||||
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
|
||||||
write_req_to_token_pool_triton[(len(req_pool_indices),)](
|
|
||||||
self.req_to_token_pool.req_to_token,
|
|
||||||
req_pool_indices_tensor,
|
|
||||||
prefix_pointers,
|
|
||||||
prefix_lens_tensor,
|
|
||||||
seq_lens_tensor,
|
|
||||||
extend_lens_tensor,
|
|
||||||
out_cache_loc,
|
|
||||||
self.req_to_token_pool.req_to_token.shape[1],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pt = 0
|
|
||||||
for i in range(len(req_pool_indices)):
|
|
||||||
self.req_to_token_pool.write(
|
|
||||||
(req_pool_indices[i], slice(0, prefix_lens[i])),
|
|
||||||
prefix_tensors[i],
|
|
||||||
)
|
|
||||||
self.req_to_token_pool.write(
|
|
||||||
(req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
|
|
||||||
out_cache_loc[pt : pt + extend_lens[i]],
|
|
||||||
)
|
|
||||||
pt += extend_lens[i]
|
|
||||||
|
|
||||||
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||||
self.encoder_lens_cpu = []
|
self.encoder_lens_cpu = []
|
||||||
self.encoder_cached = []
|
self.encoder_cached = []
|
||||||
@@ -1253,10 +1100,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
|
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
|
||||||
self.device, non_blocking=True
|
self.device, non_blocking=True
|
||||||
)
|
)
|
||||||
prefix_lens_tensor = torch.tensor(
|
|
||||||
prefix_lens, dtype=torch.int64, device=self.device
|
|
||||||
)
|
|
||||||
prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
|
|
||||||
|
|
||||||
token_type_ids_tensor = None
|
token_type_ids_tensor = None
|
||||||
if len(token_type_ids) > 0:
|
if len(token_type_ids) > 0:
|
||||||
@@ -1264,48 +1107,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
sum(token_type_ids, []), dtype=torch.int64
|
sum(token_type_ids, []), dtype=torch.int64
|
||||||
).to(self.device, non_blocking=True)
|
).to(self.device, non_blocking=True)
|
||||||
|
|
||||||
extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
|
# Set batch fields needed by alloc_for_extend
|
||||||
|
self.prefix_lens = prefix_lens
|
||||||
# Allocate req slots
|
self.extend_lens = extend_lens
|
||||||
bs = len(self.reqs)
|
self.seq_lens = seq_lens_tensor
|
||||||
req_pool_indices = self.alloc_req_slots(bs, self.reqs)
|
self.seq_lens_cpu = seq_lens_cpu
|
||||||
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
self.extend_num_tokens = extend_num_tokens
|
||||||
self.device, non_blocking=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Allocate memory
|
# Allocate memory
|
||||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
|
||||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
self
|
||||||
else:
|
|
||||||
last_loc = [
|
|
||||||
(
|
|
||||||
r.prefix_indices[-1:]
|
|
||||||
if len(r.prefix_indices) > 0
|
|
||||||
else torch.tensor([-1], device=self.device)
|
|
||||||
)
|
|
||||||
for r in self.reqs
|
|
||||||
]
|
|
||||||
out_cache_loc = self.alloc_paged_token_slots_extend(
|
|
||||||
prefix_lens_tensor,
|
|
||||||
prefix_lens_cpu_tensor,
|
|
||||||
seq_lens_tensor,
|
|
||||||
seq_lens_cpu,
|
|
||||||
torch.cat(last_loc),
|
|
||||||
extend_num_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write allocated tokens to req_to_token_pool
|
|
||||||
self.write_cache_indices(
|
|
||||||
req_pool_indices,
|
|
||||||
prefix_lens,
|
|
||||||
seq_lens,
|
|
||||||
extend_lens,
|
|
||||||
out_cache_loc,
|
|
||||||
req_pool_indices_tensor,
|
|
||||||
prefix_lens_tensor,
|
|
||||||
seq_lens_tensor,
|
|
||||||
extend_lens_tensor,
|
|
||||||
[r.prefix_indices for r in reqs],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set fields
|
# Set fields
|
||||||
@@ -1317,12 +1128,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
req.req_pool_idx = req_pool_indices[i]
|
req.req_pool_idx = req_pool_indices[i]
|
||||||
assert seq_len - pre_len == req.extend_input_len
|
assert seq_len - pre_len == req.extend_input_len
|
||||||
|
|
||||||
if pre_len > 0:
|
|
||||||
if isinstance(self.tree_cache, SWAChunkCache):
|
|
||||||
self.tree_cache.evict_swa(
|
|
||||||
req, pre_len, self.model_config.attention_chunk_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# If input_embeds are available, store them
|
# If input_embeds are available, store them
|
||||||
if req.input_embeds is not None:
|
if req.input_embeds is not None:
|
||||||
# If req.input_embeds is already a list, append its content directly
|
# If req.input_embeds is already a list, append its content directly
|
||||||
@@ -1414,8 +1219,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
self.input_ids = input_ids_tensor
|
self.input_ids = input_ids_tensor
|
||||||
self.req_pool_indices = req_pool_indices_tensor
|
self.req_pool_indices = req_pool_indices_tensor
|
||||||
self.seq_lens = seq_lens_tensor
|
|
||||||
self.seq_lens_cpu = seq_lens_cpu
|
|
||||||
self.orig_seq_lens = orig_seq_lens_tensor
|
self.orig_seq_lens = orig_seq_lens_tensor
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
self.input_embeds = (
|
self.input_embeds = (
|
||||||
@@ -1439,9 +1242,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
|
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
|
||||||
|
|
||||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||||
self.extend_num_tokens = extend_num_tokens
|
|
||||||
self.prefix_lens = prefix_lens
|
|
||||||
self.extend_lens = extend_lens
|
|
||||||
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
||||||
|
|
||||||
if self.model_config.is_encoder_decoder:
|
if self.model_config.is_encoder_decoder:
|
||||||
@@ -1681,11 +1481,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.output_ids = None
|
self.output_ids = None
|
||||||
|
|
||||||
if self.model_config.is_encoder_decoder:
|
if self.model_config.is_encoder_decoder:
|
||||||
locs = self.encoder_lens + self.seq_lens
|
|
||||||
self.prepare_encoder_info_decode()
|
self.prepare_encoder_info_decode()
|
||||||
else:
|
|
||||||
locs = self.seq_lens.clone()
|
|
||||||
|
|
||||||
|
# Allocate memory
|
||||||
|
self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
|
||||||
|
|
||||||
|
# Update seq_lens after allocation
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
# Do not use in-place operations in the overlap mode
|
# Do not use in-place operations in the overlap mode
|
||||||
self.seq_lens = self.seq_lens + 1
|
self.seq_lens = self.seq_lens + 1
|
||||||
@@ -1698,28 +1499,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.orig_seq_lens.add_(1)
|
self.orig_seq_lens.add_(1)
|
||||||
self.seq_lens_sum += bs
|
self.seq_lens_sum += bs
|
||||||
|
|
||||||
# free memory
|
|
||||||
if isinstance(self.tree_cache, SWAChunkCache):
|
|
||||||
for req in self.reqs:
|
|
||||||
self.tree_cache.evict_swa(
|
|
||||||
req, req.seqlen - 1, self.model_config.attention_chunk_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Allocate memory
|
|
||||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
|
||||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
|
||||||
else:
|
|
||||||
last_loc = self.req_to_token_pool.req_to_token[
|
|
||||||
self.req_pool_indices, self.seq_lens - 2
|
|
||||||
]
|
|
||||||
self.out_cache_loc = self.alloc_paged_token_slots_decode(
|
|
||||||
self.seq_lens, self.seq_lens_cpu, last_loc
|
|
||||||
)
|
|
||||||
|
|
||||||
self.req_to_token_pool.write(
|
|
||||||
(self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
|
|
||||||
)
|
|
||||||
|
|
||||||
def filter_batch(
|
def filter_batch(
|
||||||
self,
|
self,
|
||||||
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
|
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
|
||||||
@@ -1940,23 +1719,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
else:
|
else:
|
||||||
return self.token_to_kv_pool_allocator.available_size() >= num_tokens
|
return self.token_to_kv_pool_allocator.available_size() >= num_tokens
|
||||||
|
|
||||||
def _available_and_evictable_str(self) -> str:
|
|
||||||
if self.is_hybrid:
|
|
||||||
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
|
||||||
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
|
|
||||||
full_evictable_size = self.tree_cache.full_evictable_size()
|
|
||||||
swa_evictable_size = self.tree_cache.swa_evictable_size()
|
|
||||||
return (
|
|
||||||
f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
|
|
||||||
f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
|
|
||||||
f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
|
|
||||||
f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
available_size = self.token_to_kv_pool_allocator.available_size()
|
|
||||||
evictable_size = self.tree_cache.evictable_size()
|
|
||||||
return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return (
|
return (
|
||||||
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
|
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
|
||||||
@@ -2038,128 +1800,3 @@ class ModelWorkerBatch:
|
|||||||
|
|
||||||
# Whether this batch is prefill-only (no token generation needed)
|
# Whether this batch is prefill-only (no token generation needed)
|
||||||
is_prefill_only: bool = False
|
is_prefill_only: bool = False
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def write_req_to_token_pool_triton(
|
|
||||||
req_to_token_ptr, # [max_batch, max_context_len]
|
|
||||||
req_pool_indices,
|
|
||||||
prefix_tensors,
|
|
||||||
pre_lens,
|
|
||||||
seq_lens,
|
|
||||||
extend_lens,
|
|
||||||
out_cache_loc,
|
|
||||||
req_to_token_ptr_stride: tl.constexpr,
|
|
||||||
):
|
|
||||||
BLOCK_SIZE: tl.constexpr = 512
|
|
||||||
pid = tl.program_id(0)
|
|
||||||
|
|
||||||
req_pool_index = tl.load(req_pool_indices + pid)
|
|
||||||
pre_len = tl.load(pre_lens + pid)
|
|
||||||
seq_len = tl.load(seq_lens + pid)
|
|
||||||
prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
|
|
||||||
|
|
||||||
# write prefix
|
|
||||||
num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
|
|
||||||
for i in range(num_loop):
|
|
||||||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
|
||||||
mask = offset < pre_len
|
|
||||||
value = tl.load(prefix_tensor + offset, mask=mask)
|
|
||||||
tl.store(
|
|
||||||
req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
|
|
||||||
value,
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE: This can be slow for large bs
|
|
||||||
cumsum_start = tl.cast(0, tl.int64)
|
|
||||||
for i in range(pid):
|
|
||||||
cumsum_start += tl.load(extend_lens + i)
|
|
||||||
|
|
||||||
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
|
|
||||||
for i in range(num_loop):
|
|
||||||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
|
||||||
mask = offset < (seq_len - pre_len)
|
|
||||||
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
|
|
||||||
tl.store(
|
|
||||||
req_to_token_ptr
|
|
||||||
+ req_pool_index * req_to_token_ptr_stride
|
|
||||||
+ offset
|
|
||||||
+ pre_len,
|
|
||||||
value,
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_last_loc(
|
|
||||||
req_to_token: torch.Tensor,
|
|
||||||
req_pool_indices_tensor: torch.Tensor,
|
|
||||||
prefix_lens_tensor: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if (
|
|
||||||
global_server_args_dict["attention_backend"] != "ascend"
|
|
||||||
and global_server_args_dict["attention_backend"] != "torch_native"
|
|
||||||
):
|
|
||||||
impl = get_last_loc_triton
|
|
||||||
else:
|
|
||||||
impl = get_last_loc_torch
|
|
||||||
|
|
||||||
return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def get_last_loc_torch(
|
|
||||||
req_to_token: torch.Tensor,
|
|
||||||
req_pool_indices_tensor: torch.Tensor,
|
|
||||||
prefix_lens_tensor: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return torch.where(
|
|
||||||
prefix_lens_tensor > 0,
|
|
||||||
req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
|
|
||||||
torch.full_like(prefix_lens_tensor, -1),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def get_last_loc_kernel(
|
|
||||||
req_to_token,
|
|
||||||
req_pool_indices_tensor,
|
|
||||||
prefix_lens_tensor,
|
|
||||||
result,
|
|
||||||
num_tokens,
|
|
||||||
req_to_token_stride,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
):
|
|
||||||
pid = tl.program_id(0)
|
|
||||||
offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
|
|
||||||
mask = offset < num_tokens
|
|
||||||
|
|
||||||
prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
|
|
||||||
req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
|
|
||||||
|
|
||||||
token_mask = prefix_lens > 0
|
|
||||||
token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
|
|
||||||
tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
|
|
||||||
|
|
||||||
tl.store(result + offset, tokens, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
def get_last_loc_triton(
|
|
||||||
req_to_token: torch.Tensor,
|
|
||||||
req_pool_indices_tensor: torch.Tensor,
|
|
||||||
prefix_lens_tensor: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
BLOCK_SIZE = 256
|
|
||||||
num_tokens = prefix_lens_tensor.shape[0]
|
|
||||||
result = torch.empty_like(prefix_lens_tensor)
|
|
||||||
grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
|
|
||||||
|
|
||||||
get_last_loc_kernel[grid](
|
|
||||||
req_to_token,
|
|
||||||
req_pool_indices_tensor,
|
|
||||||
prefix_lens_tensor,
|
|
||||||
result,
|
|
||||||
num_tokens,
|
|
||||||
req_to_token.stride(0),
|
|
||||||
BLOCK_SIZE,
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|||||||
479
python/sglang/srt/mem_cache/common.py
Normal file
479
python/sglang/srt/mem_cache/common.py
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||||
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||||
|
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import support_triton
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Names of the server arguments mirrored into `global_server_args_dict`.
GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]

# Seeded at import time from attributes read off `ServerArgs` itself (not an
# instance).
# NOTE(review): nothing in the visible part of this module updates this dict
# afterwards -- confirm the runtime attention backend is propagated here,
# otherwise backend-dependent dispatch keeps seeing the import-time value.
global_server_args_dict = {
    key: getattr(ServerArgs, key) for key in GLOBAL_SERVER_ARGS_KEYS
}
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    prefix_tensors,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    # Each program handles the request at position `pid`: it copies that
    # request's prefix indices into its row of `req_to_token_ptr`, then
    # appends its slice of `out_cache_loc` right after the prefix.
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    row = tl.load(req_pool_indices + pid)
    prefix_len = tl.load(pre_lens + pid)
    total_len = tl.load(seq_lens + pid)
    # `prefix_tensors` holds one raw address per request; reinterpret the
    # loaded value as a pointer to int64 elements.
    prefix_src = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))

    # Start of this request's row in the destination table.
    row_ptr = req_to_token_ptr + row * req_to_token_ptr_stride

    # write prefix
    num_prefix_blocks = tl.cdiv(prefix_len, BLOCK_SIZE)
    for blk in range(num_prefix_blocks):
        offset = tl.arange(0, BLOCK_SIZE) + blk * BLOCK_SIZE
        mask = offset < prefix_len
        value = tl.load(prefix_src + offset, mask=mask)
        tl.store(row_ptr + offset, value, mask=mask)

    # Offset of this request's slice inside `out_cache_loc`: the sum of the
    # extend lengths of all requests before it.
    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for prev in range(pid):
        cumsum_start += tl.load(extend_lens + prev)

    # write the newly allocated slots after the prefix
    new_len = total_len - prefix_len
    num_new_blocks = tl.cdiv(new_len, BLOCK_SIZE)
    for blk in range(num_new_blocks):
        offset = tl.arange(0, BLOCK_SIZE) + blk * BLOCK_SIZE
        mask = offset < new_len
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(row_ptr + offset + prefix_len, value, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def write_cache_indices(
    out_cache_loc: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    req_pool_indices_cpu: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens_tensor: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    extend_lens_tensor: torch.Tensor,
    extend_lens_cpu: torch.Tensor,
    prefix_tensors: list[torch.Tensor],
    req_to_token_pool: ReqToTokenPool,
):
    """Record the cache slots of a batch in ``req_to_token_pool``.

    For request ``i``, positions ``[0, prefix_len)`` of its row receive
    ``prefix_tensors[i]`` and positions ``[prefix_len, seq_len)`` receive the
    next ``extend_len`` entries of ``out_cache_loc``; the entries of
    ``out_cache_loc`` are consumed in request order.

    The ``*_tensor`` arguments feed the Triton kernel, the ``*_cpu`` arguments
    feed the Python fallback (presumably host copies of the same values --
    verify against the caller).
    """
    if not support_triton(global_server_args_dict.get("attention_backend")):
        # Fallback: one pair of pool writes per request.
        req_rows = req_pool_indices_cpu.tolist()
        prefix_lens = prefix_lens_cpu.tolist()
        seq_lens = seq_lens_cpu.tolist()
        extend_lens = extend_lens_cpu.tolist()

        cursor = 0  # read position inside out_cache_loc
        for i, req_idx in enumerate(req_rows):
            prefix_len = prefix_lens[i]
            seq_len = seq_lens[i]
            extend_len = extend_lens[i]

            req_to_token_pool.write(
                (req_idx, slice(0, prefix_len)),
                prefix_tensors[i],
            )
            req_to_token_pool.write(
                (req_idx, slice(prefix_len, seq_len)),
                out_cache_loc[cursor : cursor + extend_len],
            )
            cursor += extend_len
        return

    # Triton path: hand the kernel one raw data pointer per prefix tensor.
    prefix_pointers = torch.tensor(
        [t.data_ptr() for t in prefix_tensors],
        device=req_to_token_pool.device,
    )
    num_reqs = req_pool_indices_tensor.shape[0]
    # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
    write_req_to_token_pool_triton[(num_reqs,)](
        req_to_token_pool.req_to_token,
        req_pool_indices_tensor,
        prefix_pointers,
        prefix_lens_tensor,
        seq_lens_tensor,
        extend_lens_tensor,
        out_cache_loc,
        req_to_token_pool.req_to_token.shape[1],
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Dispatch to the torch or Triton implementation of the last-location lookup.

    The pure-torch version is used when the configured attention backend is
    ``"ascend"`` or ``"torch_native"``; every other value selects the Triton
    version.

    NOTE(review): the backend is read from this module's
    ``global_server_args_dict`` -- confirm it reflects the runtime setting.
    """
    backend = global_server_args_dict["attention_backend"]
    if backend in ("ascend", "torch_native"):
        impl = get_last_loc_torch
    else:
        impl = get_last_loc_triton

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Return, per request, the entry stored at the last prefix position.

    For each ``i`` the result is
    ``req_to_token[req_pool_indices_tensor[i], prefix_lens_tensor[i] - 1]``
    when the prefix length is positive, and ``-1`` otherwise.
    """
    has_prefix = prefix_lens_tensor > 0
    last_positions = prefix_lens_tensor - 1
    # For a zero-length prefix the index is -1 (the row's last column); the
    # gathered value is discarded by the `where` below.
    gathered = req_to_token[req_pool_indices_tensor, last_positions]
    sentinel = torch.full_like(prefix_lens_tensor, -1)
    return torch.where(has_prefix, gathered, sentinel)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles BLOCK_SIZE consecutive entries of the inputs and
    # writes, for every entry, the value stored at the last prefix position
    # of its row (or -1 when the prefix is empty).
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    in_range = offset < num_tokens

    # Out-of-range lanes read 0, which also marks them as "no prefix" below.
    lens = tl.load(prefix_lens_tensor + offset, mask=in_range, other=0)
    rows = tl.load(req_pool_indices_tensor + offset, mask=in_range, other=0)

    has_prefix = lens > 0
    flat_index = rows * req_to_token_stride + (lens - 1)
    last = tl.load(req_to_token + flat_index, mask=has_prefix, other=-1)

    tl.store(result + offset, last, mask=in_range)
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Launch ``get_last_loc_kernel`` and return its output tensor.

    The output has the same shape and dtype as ``prefix_lens_tensor``; the
    kernel is launched with one program per block of 256 entries.
    """
    BLOCK_SIZE = 256
    count = prefix_lens_tensor.shape[0]
    out = torch.empty_like(prefix_lens_tensor)
    launch_grid = (triton.cdiv(count, BLOCK_SIZE),)

    get_last_loc_kernel[launch_grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        out,
        count,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return out
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_token_slots(
    tree_cache: BasePrefixCache,
    num_tokens: int,
    backup_state: bool = False,
):
    """Allocate ``num_tokens`` slots from the allocator owned by ``tree_cache``.

    Eviction from ``tree_cache`` is attempted first via
    ``evict_from_tree_cache``.

    Args:
        tree_cache: Prefix cache whose ``token_to_kv_pool_allocator`` serves
            the allocation. It must not be ``None``: the allocator is read
            from it unconditionally.
        num_tokens: Number of slots to allocate.
        backup_state: When true, ``allocator.backup_state()`` is captured
            before allocating and returned alongside the result.

    Returns:
        ``out_cache_loc`` alone, or ``(out_cache_loc, state)`` when
        ``backup_state`` is true.

    Raises:
        RuntimeError: If the allocator returns ``None``.
    """
    allocator = tree_cache.token_to_kv_pool_allocator
    evict_from_tree_cache(tree_cache, num_tokens)

    state = None
    if backup_state:
        state = allocator.backup_state()

    out_cache_loc = allocator.alloc(num_tokens)

    if out_cache_loc is None:
        error_msg = (
            "Out of memory. Try to lower your batch size.\n"
            f"Try to allocate {num_tokens} tokens.\n"
            f"{available_and_evictable_str(tree_cache)}"
        )
        logger.error(error_msg)
        # `tree_cache` was already dereferenced above, so it cannot be None
        # here; the former `is not None` guard was dead code.
        tree_cache.pretty_print()
        raise RuntimeError(error_msg)

    return (out_cache_loc, state) if backup_state else out_cache_loc
|
||||||
|
|
||||||
|
|
||||||
|
def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int):
    """Ask ``tree_cache`` to evict when its allocator has fewer than ``num_tokens`` free.

    Does nothing when ``tree_cache`` is ``None`` or is a ``SWAChunkCache`` /
    ``ChunkCache`` instance.

    Args:
        tree_cache: Prefix cache to evict from, or ``None``.
        num_tokens: Number of free slots the caller is about to need.
    """
    # ``None`` is tested first so the isinstance check is never reached for it.
    if tree_cache is None or isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
        return

    kv_allocator = tree_cache.token_to_kv_pool_allocator

    if not hasattr(kv_allocator, "full_available_size"):
        # Standard allocator: a single pool with one availability figure.
        if kv_allocator.available_size() < num_tokens:
            tree_cache.evict(num_tokens)
        return

    # Hybrid allocator: the full and SWA pools are checked separately.
    full_free = kv_allocator.full_available_size()
    swa_free = kv_allocator.swa_available_size()
    if full_free >= num_tokens and swa_free >= num_tokens:
        return

    # Request only the shortfall of each pool (never a negative amount).
    full_shortfall = max(0, num_tokens - full_free)
    swa_shortfall = max(0, num_tokens - swa_free)
    tree_cache.evict(full_shortfall, swa_shortfall)
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_paged_token_slots_extend(
    tree_cache: BasePrefixCache,
    prefix_lens: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    extend_num_tokens: int,
    backup_state: bool = False,
):
    """Allocate paged KV-cache slots for an extend batch.

    Before allocating, ``evict_from_tree_cache`` is called with an
    over-estimate of the demand: ``extend_num_tokens`` plus one full page
    per request.

    Args:
        tree_cache: Prefix cache whose ``token_to_kv_pool_allocator`` is used.
        prefix_lens: Forwarded unchanged to ``allocator.alloc_extend``.
        prefix_lens_cpu: Forwarded unchanged to ``allocator.alloc_extend``.
        seq_lens: Forwarded unchanged to ``allocator.alloc_extend``.
        seq_lens_cpu: Forwarded to ``allocator.alloc_extend``; its length is
            used as the request count for the eviction estimate.
        last_loc: Forwarded unchanged to ``allocator.alloc_extend``.
        extend_num_tokens: Number of new tokens to allocate.
        backup_state: When true, the allocator state is captured (after the
            eviction step, before the allocation) and returned as well.

    Returns:
        The allocated locations, or ``(locations, allocator_state)`` when
        ``backup_state`` is true.

    Raises:
        RuntimeError: If the allocator returns ``None``.
    """
    kv_allocator = tree_cache.token_to_kv_pool_allocator

    # Over-estimate the number of tokens: assume each request needs a new page.
    estimated_tokens = extend_num_tokens + len(seq_lens_cpu) * kv_allocator.page_size
    evict_from_tree_cache(tree_cache, estimated_tokens)

    saved_state = kv_allocator.backup_state() if backup_state else None

    out_cache_loc = kv_allocator.alloc_extend(
        prefix_lens,
        prefix_lens_cpu,
        seq_lens,
        seq_lens_cpu,
        last_loc,
        extend_num_tokens,
    )
    if out_cache_loc is not None:
        if backup_state:
            return out_cache_loc, saved_state
        return out_cache_loc

    # Allocation failed: log the pool statistics, dump the cache, then raise.
    error_msg = (
        f"Prefill out of memory. Try to lower your batch size.\n"
        f"Try to allocate {extend_num_tokens} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_req_slots(
    req_to_token_pool: ReqToTokenPool,
    num_reqs: int,
    reqs: list[Req] | None,
) -> list[int]:
    """Allocate ``num_reqs`` request slots from ``req_to_token_pool``.

    Args:
        req_to_token_pool: Pool to allocate from.
        num_reqs: Number of request slots wanted.
        reqs: The requests themselves; only passed to the pool when it is a
            ``HybridReqToTokenPool``.

    Returns:
        Whatever the pool's ``alloc`` returned (the allocated slot indices).

    Raises:
        RuntimeError: If the pool's ``alloc`` returns ``None``.
    """
    # Only the hybrid pool's ``alloc`` takes the request objects.
    if isinstance(req_to_token_pool, HybridReqToTokenPool):
        alloc_args = (num_reqs, reqs)
    else:
        alloc_args = (num_reqs,)

    req_pool_indices = req_to_token_pool.alloc(*alloc_args)
    if req_pool_indices is not None:
        return req_pool_indices

    raise RuntimeError(
        "alloc_req_slots runs out of memory. "
        "Please set a smaller number for `--max-running-requests`. "
        f"{req_to_token_pool.available_size()=}, "
        f"{num_reqs=}, "
    )
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_for_extend(
    batch: ScheduleBatch,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """Allocate request slots and KV cache for an extend batch.

    The new cache locations are also recorded in ``batch.req_to_token_pool``
    through ``write_cache_indices``.

    Args:
        batch: The batch being extended. This function reads its ``reqs``,
            ``prefix_lens``, ``extend_lens``, ``seq_lens``, ``seq_lens_cpu``,
            ``extend_num_tokens``, ``device``, ``tree_cache``,
            ``model_config`` and ``req_to_token_pool``.

    Returns:
        A tuple ``(out_cache_loc, req_pool_indices_device, req_pool_indices)``:
        the allocated cache locations, the request pool indices as a tensor
        on ``batch.device``, and the same indices as returned by
        ``alloc_req_slots``.
    """
    # Free out-of-window SWA tokens before allocating anything new.
    if isinstance(batch.tree_cache, SWAChunkCache):
        for req, pre_len in zip(batch.reqs, batch.prefix_lens):
            batch.tree_cache.evict_swa(
                req, pre_len, batch.model_config.attention_chunk_size
            )

    num_reqs = len(batch.reqs)
    prefix_tensors = [req.prefix_indices for req in batch.reqs]

    # Length tensors are built on the host, then copied to the batch device.
    prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64)
    extend_lens_cpu = torch.tensor(batch.extend_lens, dtype=torch.int64)
    prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True)
    extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True)

    # Request slots (raises on failure).
    req_pool_indices = alloc_req_slots(batch.req_to_token_pool, num_reqs, batch.reqs)
    req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
    req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True)

    # KV cache (raises on failure).
    if batch.tree_cache.page_size == 1:
        out_cache_loc = alloc_token_slots(batch.tree_cache, batch.extend_num_tokens)
    else:
        # Paged allocation needs the last cached location of every request;
        # a request with an empty prefix contributes -1 instead.
        last_loc_parts = []
        for prefix in prefix_tensors:
            if len(prefix) > 0:
                last_loc_parts.append(prefix[-1:])
            else:
                last_loc_parts.append(
                    torch.tensor([-1], device=batch.tree_cache.device)
                )
        out_cache_loc = alloc_paged_token_slots_extend(
            tree_cache=batch.tree_cache,
            prefix_lens=prefix_lens_device,
            prefix_lens_cpu=prefix_lens_cpu,
            seq_lens=batch.seq_lens,
            seq_lens_cpu=batch.seq_lens_cpu,
            last_loc=torch.cat(last_loc_parts),
            extend_num_tokens=batch.extend_num_tokens,
        )

    # Record the new locations in req_to_token_pool.
    write_cache_indices(
        out_cache_loc,
        req_pool_indices_device,
        req_pool_indices_cpu,
        prefix_lens_device,
        prefix_lens_cpu,
        batch.seq_lens,
        batch.seq_lens_cpu,
        extend_lens_device,
        extend_lens_cpu,
        prefix_tensors,
        batch.req_to_token_pool,
    )

    return out_cache_loc, req_pool_indices_device, req_pool_indices
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_paged_token_slots_decode(
    tree_cache: BasePrefixCache,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    token_per_req: int = 1,
) -> torch.Tensor:
    """Allocate paged KV-cache slots for a decode batch.

    Before allocating, ``evict_from_tree_cache`` is called with an
    over-estimate of the demand: one full page per request.

    Args:
        tree_cache: Prefix cache whose ``token_to_kv_pool_allocator`` is used.
        seq_lens: Forwarded to ``allocator.alloc_decode``; its length is the
            request count.
        seq_lens_cpu: Forwarded unchanged to ``allocator.alloc_decode``.
        last_loc: Forwarded unchanged to ``allocator.alloc_decode``.
        token_per_req: Only used to report the requested token count in the
            out-of-memory message.

    Returns:
        The locations returned by ``allocator.alloc_decode``.

    Raises:
        RuntimeError: If the allocator returns ``None``.
    """
    kv_allocator = tree_cache.token_to_kv_pool_allocator

    # Over-estimate the number of tokens: assume each request needs a new page.
    estimated_tokens = len(seq_lens) * kv_allocator.page_size
    evict_from_tree_cache(tree_cache, estimated_tokens)

    out_cache_loc = kv_allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc)
    if out_cache_loc is not None:
        return out_cache_loc

    # Allocation failed: log the pool statistics, dump the cache, then raise.
    error_msg = (
        f"Decode out of memory. Try to lower your batch size.\n"
        f"Try to allocate {len(seq_lens) * token_per_req} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor:
    """Allocate KV cache for a decode batch and record it in ``req_to_token_pool``.

    Args:
        batch: The batch being decoded. This function reads its ``reqs``,
            ``seq_lens``, ``seq_lens_cpu``, ``req_pool_indices``,
            ``encoder_lens``, ``tree_cache``, ``model_config`` and
            ``req_to_token_pool``.
        token_per_req: Number of tokens to account for per request.

    Returns:
        out_cache_loc: the allocated cache locations.
    """
    # Free out-of-window SWA tokens before allocating anything new.
    if isinstance(batch.tree_cache, SWAChunkCache):
        for req in batch.reqs:
            batch.tree_cache.evict_swa(
                req, req.seqlen - 1, batch.model_config.attention_chunk_size
            )

    num_reqs = batch.seq_lens.shape[0]

    if batch.tree_cache.page_size == 1:
        # Non-paged allocation.
        out_cache_loc = alloc_token_slots(batch.tree_cache, num_reqs * token_per_req)
    else:
        # Paged allocation: look up each request's last written location.
        req_to_token = batch.req_to_token_pool.req_to_token
        last_loc = req_to_token[batch.req_pool_indices, batch.seq_lens - 1]
        seq_lens_next = batch.seq_lens + token_per_req
        out_cache_loc = alloc_paged_token_slots_decode(
            tree_cache=batch.tree_cache,
            seq_lens=seq_lens_next,
            seq_lens_cpu=batch.seq_lens_cpu + token_per_req,
            last_loc=last_loc,
            token_per_req=token_per_req,
        )

    # Column to write in req_to_token: for encoder-decoder models the
    # position is offset by the encoder length.
    if batch.model_config.is_encoder_decoder:
        write_positions = batch.encoder_lens + batch.seq_lens
    else:
        write_positions = batch.seq_lens.clone()

    batch.req_to_token_pool.write(
        (batch.req_pool_indices, write_positions), out_cache_loc.to(torch.int32)
    )

    return out_cache_loc
|
||||||
|
|
||||||
|
|
||||||
|
def available_and_evictable_str(tree_cache) -> str:
    """Format the available and evictable token counts of ``tree_cache``.

    For a ``SWATokenToKVPoolAllocator`` the full and SWA pools are reported
    separately, together with the cache's LRU-list evictable sizes. For any
    other allocator a single "Available tokens" line is produced.

    Args:
        tree_cache: Cache exposing ``token_to_kv_pool_allocator`` and the
            evictable-size accessors used below.

    Returns:
        A newline-terminated, human-readable summary.
    """
    kv_allocator = tree_cache.token_to_kv_pool_allocator

    # The local names below are part of the output: the ``{name=}`` format
    # specifiers print them verbatim.
    if not isinstance(kv_allocator, SWATokenToKVPoolAllocator):
        available_size = kv_allocator.available_size()
        evictable_size = tree_cache.evictable_size()
        return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"

    full_available_size = kv_allocator.full_available_size()
    swa_available_size = kv_allocator.swa_available_size()
    full_evictable_size = tree_cache.full_evictable_size()
    swa_evictable_size = tree_cache.swa_evictable_size()
    return (
        f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
        f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
        f"Full LRU list evictable size: {tree_cache.full_lru_list_evictable_size()}\n"
        f"SWA LRU list evictable size: {tree_cache.swa_lru_list_evictable_size()}\n"
    )
|
||||||
@@ -10,12 +10,13 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
|||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||||
ScheduleBatch,
|
|
||||||
get_last_loc,
|
|
||||||
global_server_args_dict,
|
|
||||||
)
|
|
||||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
|
from sglang.srt.mem_cache.common import (
|
||||||
|
alloc_paged_token_slots_extend,
|
||||||
|
alloc_token_slots,
|
||||||
|
get_last_loc,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||||
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
|
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
|
||||||
from sglang.srt.speculative.spec_utils import (
|
from sglang.srt.speculative.spec_utils import (
|
||||||
@@ -100,7 +101,10 @@ class EagleVerifyInput(SpecInput):
|
|||||||
batch.input_ids = self.draft_token
|
batch.input_ids = self.draft_token
|
||||||
|
|
||||||
if page_size == 1:
|
if page_size == 1:
|
||||||
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
|
batch.out_cache_loc = alloc_token_slots(
|
||||||
|
batch.tree_cache,
|
||||||
|
len(batch.input_ids),
|
||||||
|
)
|
||||||
end_offset = batch.seq_lens + self.draft_token_num
|
end_offset = batch.seq_lens + self.draft_token_num
|
||||||
else:
|
else:
|
||||||
prefix_lens = batch.seq_lens
|
prefix_lens = batch.seq_lens
|
||||||
@@ -112,7 +116,8 @@ class EagleVerifyInput(SpecInput):
|
|||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
)
|
)
|
||||||
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
|
batch.out_cache_loc = alloc_paged_token_slots_extend(
|
||||||
|
batch.tree_cache,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
prefix_lens_cpu,
|
prefix_lens_cpu,
|
||||||
end_offset,
|
end_offset,
|
||||||
|
|||||||
@@ -14,13 +14,14 @@ from sglang.srt.distributed import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||||
ScheduleBatch,
|
|
||||||
get_last_loc,
|
|
||||||
global_server_args_dict,
|
|
||||||
)
|
|
||||||
from sglang.srt.managers.scheduler import GenerationBatchResult
|
from sglang.srt.managers.scheduler import GenerationBatchResult
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
|
from sglang.srt.mem_cache.common import (
|
||||||
|
alloc_paged_token_slots_extend,
|
||||||
|
alloc_token_slots,
|
||||||
|
get_last_loc,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
CaptureHiddenMode,
|
CaptureHiddenMode,
|
||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
@@ -541,8 +542,10 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
# [ topk 0 ] [ topk 1 ]
|
# [ topk 0 ] [ topk 1 ]
|
||||||
# [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
|
# [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
|
||||||
if self.page_size == 1:
|
if self.page_size == 1:
|
||||||
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
|
out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots(
|
||||||
num_seqs * self.speculative_num_steps * self.topk, backup_state=True
|
batch.tree_cache,
|
||||||
|
num_seqs * self.speculative_num_steps * self.topk,
|
||||||
|
backup_state=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.topk == 1:
|
if self.topk == 1:
|
||||||
@@ -601,7 +604,8 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
|
extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
|
||||||
|
|
||||||
out_cache_loc, token_to_kv_pool_state_backup = (
|
out_cache_loc, token_to_kv_pool_state_backup = (
|
||||||
batch.alloc_paged_token_slots_extend(
|
alloc_paged_token_slots_extend(
|
||||||
|
batch.tree_cache,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
prefix_lens_cpu,
|
prefix_lens_cpu,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
|
|||||||
@@ -16,10 +16,11 @@ import torch.nn.functional as F
|
|||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||||
ScheduleBatch,
|
from sglang.srt.mem_cache.common import (
|
||||||
|
alloc_paged_token_slots_extend,
|
||||||
|
alloc_token_slots,
|
||||||
get_last_loc,
|
get_last_loc,
|
||||||
global_server_args_dict,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
|
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
|
||||||
@@ -74,7 +75,10 @@ class NgramVerifyInput(SpecInput):
|
|||||||
batch.input_ids = self.draft_token
|
batch.input_ids = self.draft_token
|
||||||
|
|
||||||
if page_size == 1:
|
if page_size == 1:
|
||||||
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
|
batch.out_cache_loc = alloc_token_slots(
|
||||||
|
batch.tree_cache,
|
||||||
|
len(batch.input_ids),
|
||||||
|
)
|
||||||
end_offset = batch.seq_lens + self.draft_token_num
|
end_offset = batch.seq_lens + self.draft_token_num
|
||||||
else:
|
else:
|
||||||
# TODO(lsyin): add prefix lens cpu here to support page size > 1
|
# TODO(lsyin): add prefix lens cpu here to support page size > 1
|
||||||
@@ -87,7 +91,8 @@ class NgramVerifyInput(SpecInput):
|
|||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
)
|
)
|
||||||
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
|
batch.out_cache_loc = alloc_paged_token_slots_extend(
|
||||||
|
batch.tree_cache,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
prefix_lens_cpu,
|
prefix_lens_cpu,
|
||||||
end_offset,
|
end_offset,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ python3 test_forward_split_prefill.py
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -95,11 +96,18 @@ class TestForwardSplitPrefill(CustomTestCase):
|
|||||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||||
reqs.append(req)
|
reqs.append(req)
|
||||||
|
|
||||||
|
# Create dummy tree_cache for tests (no prefix caching, just allocation)
|
||||||
|
dummy_tree_cache = SimpleNamespace(
|
||||||
|
page_size=1,
|
||||||
|
device=self.model_runner.device,
|
||||||
|
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
|
||||||
|
)
|
||||||
|
|
||||||
batch = ScheduleBatch.init_new(
|
batch = ScheduleBatch.init_new(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
|
||||||
tree_cache=None,
|
tree_cache=dummy_tree_cache,
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
enable_overlap=False,
|
enable_overlap=False,
|
||||||
spec_algorithm=SpeculativeAlgorithm.NONE,
|
spec_algorithm=SpeculativeAlgorithm.NONE,
|
||||||
|
|||||||
Reference in New Issue
Block a user