Clean up allocators (#9134)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Lianmin Zheng
2025-08-13 13:56:04 -07:00
committed by GitHub
parent 2f20f43026
commit 9e426466af
16 changed files with 288 additions and 295 deletions

View File

@@ -267,7 +267,6 @@ def extend(reqs, model_runner):
model_config=model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
enable_custom_logit_processor=False,
)
batch.prepare_for_extend()
_maybe_prepare_mlp_sync_batch(batch, model_runner)

View File

@@ -864,7 +864,6 @@ class SchedulerDisaggregationDecodeMixin:
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
)
# construct fake completed prefill

View File

@@ -870,6 +870,8 @@ class FlashInferIndicesUpdaterPrefill:
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if use_ragged:
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
# and forward_batch.extend_seq_lens_cpu
paged_kernel_lens = prefix_lens
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
else:

View File

@@ -57,16 +57,36 @@ class TritonAttnBackend(AttentionBackend):
self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
# Parse args
self.skip_prefill = skip_prefill
max_bs = model_runner.req_to_token_pool.size
self.sliding_window_size = model_runner.sliding_window_size
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.device_core_count = get_device_core_count(model_runner.gpu_id)
self.static_kv_splits = get_bool_env_var(
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
)
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
# Check arguments
assert not (
model_runner.sliding_window_size is not None
and model_runner.model_config.is_encoder_decoder
), "Sliding window and cross attention are not supported together"
self.sliding_window_size = model_runner.sliding_window_size
# Initialize buffers
# TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
if kv_indptr_buf is None:
self.kv_indptr = torch.zeros(
@@ -87,9 +107,6 @@ class TritonAttnBackend(AttentionBackend):
# When provided a buffer, create a clone for the second buffer
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
if not self.skip_prefill:
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
@@ -99,29 +116,9 @@ class TritonAttnBackend(AttentionBackend):
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
)
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.static_kv_splits = get_bool_env_var(
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
)
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
# Initialize forward metadata
self.forward_metadata: ForwardMetadata = None
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.device_core_count = get_device_core_count(model_runner.gpu_id)
def get_num_kv_splits(
self,
num_kv_splits: torch.Tensor,
@@ -333,7 +330,7 @@ class TritonAttnBackend(AttentionBackend):
mask_indptr = None
attn_logits = None
attn_lse = None
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
max_extend_len = max(forward_batch.extend_seq_lens_cpu)
num_kv_splits = None
self.forward_metadata = ForwardMetadata(

View File

@@ -113,6 +113,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_multimodal",
"enable_symm_mem",
"quantization",
"enable_custom_logit_processor",
]
# Put some global args for easy access
@@ -909,9 +910,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
# Enable custom logit processor
enable_custom_logit_processor: bool = False
# Whether to return hidden states
return_hidden_states: bool = False
@@ -928,7 +926,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
model_config: ModelConfig,
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
chunked_req: Optional[Req] = None,
):
return_logprob = any(req.return_logprob for req in reqs)
@@ -955,7 +952,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
has_grammar=any(req.grammar for req in reqs),
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=any(req.return_hidden_states for req in reqs),
chunked_req=chunked_req,
)
@@ -1009,6 +1005,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_num_tokens: int,
backup_state: bool = False,
):
# Over estimate the number of tokens: assume each request needs a new page.
num_tokens = (
extend_num_tokens
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
@@ -1041,8 +1038,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
last_loc: torch.Tensor,
backup_state: bool = False,
):
# Over estimate the number of tokens: assume each request needs a new page.
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
@@ -1721,38 +1718,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
if self.forward_mode.is_decode_or_idle():
attention_backend_str = global_server_args_dict["decode_attention_backend"]
else:
attention_backend_str = global_server_args_dict["prefill_attention_backend"]
# Create seq_lens_cpu when needed
if (
attention_backend_str
in [
"fa3",
"flashinfer",
"flashmla",
"cutlass_mla",
"ascend",
"trtllm_mha",
"aiter",
]
or global_server_args_dict["enable_two_batch_overlap"]
):
seq_lens_cpu = (
seq_lens_cpu_cache
if seq_lens_cpu_cache is not None
else self.seq_lens.cpu()
)
else:
seq_lens_cpu = None
if self.sampling_info:
if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
self.sampling_info.grammars = None
seq_lens_cpu = (
seq_lens_cpu_cache
if seq_lens_cpu_cache is not None
else self.seq_lens.cpu()
)
global bid
bid += 1
return ModelWorkerBatch(
@@ -1815,18 +1792,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
spec_algorithm=self.spec_algorithm,
enable_custom_logit_processor=self.enable_custom_logit_processor,
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
is_extend_in_batch=self.is_extend_in_batch,
)
def _evict_tree_cache_if_needed(
self,
num_tokens: int,
) -> None:
if isinstance(self.tree_cache, SWAChunkCache):
def _evict_tree_cache_if_needed(self, num_tokens: int):
if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
return
if self.is_hybrid:

View File

@@ -1634,7 +1634,6 @@ class Scheduler(
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
chunked_req=self.chunked_req,
)
if self.enable_hierarchical_cache:
@@ -2031,7 +2030,6 @@ class Scheduler(
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
)
idle_batch.prepare_for_idle()
return idle_batch

View File

@@ -20,7 +20,6 @@ Page-aligned memory pool.
"""
import abc
import weakref
from typing import TYPE_CHECKING
import torch
@@ -81,9 +80,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
if self.free_group:
self.free(torch.cat(self.free_group))
def estimated_num_new_pages(self, bs, extend_num_tokens):
    """Upper bound on pages needed: assume each of the ``bs`` requests
    rounds its ``extend_num_tokens`` up to whole pages independently."""
    pages_per_request = (extend_num_tokens + self.page_size - 1) // self.page_size
    return bs * pages_per_request
def merge_and_sort_free(self):
if len(self.release_pages) > 0:
self.free_pages = torch.cat((self.free_pages, self.release_pages))
@@ -149,6 +145,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
def alloc(self, need_size: int):
if self.need_sort and need_size > len(self.free_pages):
self.merge_and_sort_free()
if need_size > len(self.free_pages):
return None
@@ -437,9 +434,13 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
device: str,
kvcache: KVCache,
need_sort: bool,
max_num_extend_tokens: int,
):
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size
self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
max_num_extend_tokens
)
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
self.clear()
@@ -480,7 +481,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
)
bs = len(prefix_lens)
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
self.free_pages
):
self.merge_and_sort_free()
@@ -497,7 +498,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self.ret_values,
next_power_of_2(bs),
self.page_size,
next_power_of_2(extend_num_tokens),
self.max_num_extend_tokens_next_power_of_2,
)
if self.debug_mode:
@@ -522,9 +523,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
)
bs = len(seq_lens)
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
self.free_pages
):
if self.need_sort and bs > len(self.free_pages):
self.merge_and_sort_free()
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
@@ -578,151 +577,3 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
def load_cpu_copy(self, kv_cache_cpu, indices):
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
def alloc_extend_kernel_ascend(
    prefix_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    last_loc: torch.Tensor,
    free_pages: torch.Tensor,
    out_indices: torch.Tensor,
    page_size: int,
    device,
):
    """Fill ``out_indices`` with KV-cache slot indices for each request's extend tokens.

    For request ``i``, the ``seq_lens[i] - prefix_lens[i]`` new tokens are placed,
    in order, into:
      1. the remaining slots of the prefix's partially-filled last page (num1),
      2. whole new pages drawn from ``free_pages`` (num2),
      3. a final partially-filled new page for leftover tokens (num3).

    Slot index convention: ``page_id * page_size + offset_in_page``.
    ``last_loc[i] + 1`` is the next free slot in the prefix's current page.
    ``free_pages`` is consumed front-to-back; the caller advances it afterwards.
    """
    extend_lens = seq_lens - prefix_lens
    # Per-request write window [start_pos, end_pos) inside out_indices.
    end_pos = torch.cumsum(extend_lens, 0)
    start_pos = end_pos - extend_lens
    # Pages newly touched per request: ceil(seq/page) - ceil(prefix/page).
    num_new_pages = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    # Of those, the pages that end up completely full.
    num_full_new_pages = (seq_lens) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    # 1 if the request also needs a trailing partially-filled page, else 0.
    need_page = num_new_pages - num_full_new_pages
    # Per-request slice [start_new_pages, end_new_pages) into free_pages.
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
    for i in range(len(prefix_lens)):
        # num1: tokens that still fit in the prefix's partially-filled last page.
        num1 = (
            min(
                seq_lens[i],
                (prefix_lens[i] + page_size - 1) // page_size * page_size,
            )
            - prefix_lens[i]
        )
        if num1:
            out_indices[start_pos[i] : start_pos[i] + num1] = (
                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
            )
        # num2: tokens filling whole new pages (vectorized page_id * size + offset).
        num2 = (
            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
        ) * page_size
        if num2:
            pages = (
                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
                * page_size
            )
            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
                pages.view(-1, 1) + pos_in_page.view(1, -1)
            ).view(-1)
        # num3: leftover tokens on the final, partially-filled new page.
        # NOTE(review): if a request's extend never crosses a page boundary
        # (num_new_pages == 0 but seq_lens[i] % page_size != 0), num3 is still
        # nonzero and this write draws from a page slot that was not freshly
        # allocated — looks suspicious; confirm callers always cross a boundary.
        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
        if num3:
            out_indices[end_pos[i] - num3 : end_pos[i]] = (
                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
            ).view(-1)
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged KV-cache allocator for Ascend NPUs.

    Overrides the parent's allocation entry points with a pure-PyTorch
    implementation (``alloc_extend_kernel_ascend``) instead of the parent's
    kernels. Allocation state (``free_pages``, ``release_pages``,
    ``need_sort``, ``debug_mode``) lives on the parent class.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
        need_sort: bool,
    ):
        # No Ascend-specific state; all bookkeeping is inherited.
        super().__init__(size, page_size, dtype, device, kvcache, need_sort)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate KV slots for the extend tokens of a batch.

        Returns an int32 tensor of ``extend_num_tokens`` slot indices, or
        ``None`` when not enough free pages are available.
        """
        if self.debug_mode:
            # Sanity check: the next free slot must directly follow the prefix
            # within the same page.
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        # Exact count of pages newly touched by this batch.
        # NOTE: .item() forces a device sync.
        estimated_num_new_pages = (
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (prefix_lens + self.page_size - 1) // self.page_size
            )
            .sum()
            .item()
        )
        if self.need_sort and estimated_num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if estimated_num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        # Pure-PyTorch index computation (fills out_indices in place).
        alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # Consume the pages handed out above.
        self.free_pages = self.free_pages[estimated_num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one KV slot per request for a decode step.

        Returns an int32 tensor of per-request slot indices, or ``None``
        when not enough free pages are available.
        """
        if self.debug_mode:
            # The new token goes right after last_loc, hence the +2 check.
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        # A request needs a fresh page exactly when its new token lands on
        # the first slot of a page (seq_len % page_size == 1).
        need_new_pages = (seq_lens % self.page_size == 1).int()
        num_new_pages = need_new_pages.sum().item()

        # NOTE(review): unlike alloc_extend, this merges without checking
        # self.need_sort first — confirm whether that is intentional.
        if num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        end_new_pages = torch.cumsum(need_new_pages, 0)
        start_new_pages = end_new_pages - need_new_pages
        if num_new_pages == 0:
            # Fast path: every request continues in its current page.
            out_indices = last_loc + 1
        else:
            # Per-request select: next slot in the current page (mask 0) or
            # the first slot of the freshly assigned page (mask 1).
            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
                start_new_pages
            ] * self.page_size * need_new_pages

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices.int()

View File

@@ -0,0 +1,158 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache
def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    """Write KV-cache slot indices for each request's extend tokens into ``out_indices``.

    Per request, the new tokens (``seq_lens - prefix_lens`` of them) are laid out as:
    (1) the tail slots of the prefix's partially-filled last page, then
    (2) whole new pages taken from ``free_pages``, then
    (3) a final partial new page for any remainder.
    Slot index convention: ``page_id * page_size + offset_in_page``.
    ``free_pages`` is consumed front-to-back; the caller advances it afterwards.

    NOTE(review): when a request's extend stays inside the prefix's last page
    (no new pages, yet ``seq_len % page_size != 0``), segment (3) still fires
    with no freshly allocated page — confirm callers always cross a boundary.
    """
    new_token_counts = seq_lens - prefix_lens
    # Per-request write window [write_start, write_end) inside out_indices.
    write_end = torch.cumsum(new_token_counts, 0)
    write_start = write_end - new_token_counts

    # Page accounting: pages touched after vs. before the extend.
    pages_before = (prefix_lens + page_size - 1) // page_size
    pages_after = (seq_lens + page_size - 1) // page_size
    total_new_pages = pages_after - pages_before
    full_new_pages = seq_lens // page_size - pages_before
    trailing_page = total_new_pages - full_new_pages  # 1 iff a partial last page

    # Per-request slice [page_start, page_end) into free_pages.
    page_end = torch.cumsum(total_new_pages, 0)
    page_start = page_end - total_new_pages
    offsets = torch.arange(page_size, device=device, dtype=torch.int32)

    for req in range(len(prefix_lens)):
        # (1) tokens that still fit in the prefix's partially-filled last page
        head = min(seq_lens[req], pages_before[req] * page_size) - prefix_lens[req]
        if head:
            out_indices[write_start[req] : write_start[req] + head] = (
                last_loc[req] + 1 + offsets[:head].view(-1)
            )

        # (2) tokens that fill whole new pages
        body = (seq_lens[req] // page_size - pages_before[req]) * page_size
        if body:
            whole_page_bases = (
                free_pages[page_start[req] : page_end[req] - trailing_page[req]]
                * page_size
            )
            out_indices[
                write_start[req] + head : write_start[req] + head + body
            ] = (whole_page_bases.view(-1, 1) + offsets.view(1, -1)).view(-1)

        # (3) leftover tokens on the final, partially-filled page
        tail = seq_lens[req] - seq_lens[req] // page_size * page_size
        if tail:
            out_indices[write_end[req] - tail : write_end[req]] = (
                free_pages[page_end[req] - 1] * page_size + offsets[:tail]
            ).view(-1)
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged KV-cache allocator for Ascend NPUs.

    Overrides the parent's allocation entry points with a pure-PyTorch
    implementation (``alloc_extend_kernel_ascend``). All bookkeeping state
    (``free_pages``, ``release_pages``, ``need_sort``, ``debug_mode``) is
    inherited from ``PagedTokenToKVPoolAllocator``.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
        need_sort: bool,
    ):
        # The trailing 1 is max_num_extend_tokens for the parent; this class
        # overrides alloc_extend, so the parent's value derived from it
        # (max_num_extend_tokens_next_power_of_2) is never consumed here.
        super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate KV slots for the extend tokens of a batch.

        Returns an int32 tensor of ``extend_num_tokens`` slot indices, or
        ``None`` when not enough free pages are available.
        """
        if self.debug_mode:
            # Sanity check: the next free slot must directly follow the prefix
            # within the same page.
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        # Exact count of pages newly touched by this batch.
        # NOTE: .item() forces a device sync.
        num_new_pages = (
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (prefix_lens + self.page_size - 1) // self.page_size
            )
            .sum()
            .item()
        )
        if self.need_sort and num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        # Pure-PyTorch index computation (fills out_indices in place).
        alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # Consume the pages handed out above.
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one KV slot per request for a decode step.

        Returns an int32 tensor of per-request slot indices, or ``None``
        when not enough free pages are available.
        """
        if self.debug_mode:
            # The new token goes right after last_loc, hence the +2 check.
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        # A request needs a fresh page exactly when its new token lands on
        # the first slot of a page (seq_len % page_size == 1).
        need_new_pages = (seq_lens % self.page_size == 1).int()
        num_new_pages = need_new_pages.sum().item()

        # NOTE(review): unlike alloc_extend, this merges without checking
        # self.need_sort first — confirm whether that is intentional.
        if num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        end_new_pages = torch.cumsum(need_new_pages, 0)
        start_new_pages = end_new_pages - need_new_pages
        if num_new_pages == 0:
            # Fast path: every request continues in its current page.
            out_indices = last_loc + 1
        else:
            # Per-request select: next slot in the current page (mask 0) or
            # the first slot of the freshly assigned page (mask 1).
            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
                start_new_pages
            ] * self.page_size * need_new_pages

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices.int()

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional
import torch

View File

@@ -75,12 +75,12 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
AscendPagedTokenToKVPoolAllocator,
BaseTokenToKVPoolAllocator,
PagedTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
AscendMLAPagedTokenToKVPool,
AscendTokenToKVPool,
@@ -176,10 +176,6 @@ class ModelRunner:
self.mem_fraction_static = mem_fraction_static
self.device = server_args.device
self.gpu_id = gpu_id
# Apply the rank zero filter to logger
if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
logger.addFilter(RankZeroFilter(tp_rank == 0))
self.tp_rank = tp_rank
self.tp_size = tp_size
self.moe_ep_rank = moe_ep_rank
@@ -205,15 +201,17 @@ class ModelRunner:
self.is_hybrid = model_config.is_hybrid
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
self.attention_chunk_size = model_config.attention_chunk_size
self.forward_pass_id = 0
# Apply the rank zero filter to logger
if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
logger.addFilter(RankZeroFilter(tp_rank == 0))
if server_args.show_time_cost:
enable_show_time_cost()
# Model-specific adjustment
self.model_specific_adjustment()
if server_args.show_time_cost:
enable_show_time_cost()
# Global vars
global_server_args_dict.update(
{k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
@@ -221,8 +219,6 @@ class ModelRunner:
# TODO it is indeed not a "server args"
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
}
| {
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
"deepep_mode": DeepEPMode(server_args.deepep_mode),
}
@@ -242,13 +238,15 @@ class ModelRunner:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
# If it is a draft model, tp_group can be different
# Initialize the model runner
self.initialize(min_per_gpu_memory)
# temporary cached values
# Temporary cached values
self.support_pp = (
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
)
# For weight updates
self._model_update_group = {}
def initialize(self, min_per_gpu_memory: float):
@@ -277,6 +275,7 @@ class ModelRunner:
)
)
# Expert parallelism
self.eplb_manager = (
EPLBManager(self)
if self.server_args.enable_eplb and (not self.is_draft_worker)
@@ -1160,6 +1159,7 @@ class ModelRunner:
max_num_reqs: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
# Determine the kv cache dtype
if self.server_args.kv_cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
@@ -1178,6 +1178,8 @@ class ModelRunner:
)
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if SGLANG_CI_SMALL_KV_SIZE:
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
if max_num_reqs is None:
max_num_reqs = min(
@@ -1190,9 +1192,6 @@ class ModelRunner:
4096,
)
if SGLANG_CI_SMALL_KV_SIZE:
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
if not self.spec_algorithm.is_none():
if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
@@ -1239,6 +1238,7 @@ class ModelRunner:
"Not enough memory. Please try to increase --mem-fraction-static."
)
# Initialize req_to_token_pool
if self.req_to_token_pool is None:
if self.server_args.disaggregation_mode == "decode":
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
@@ -1264,6 +1264,7 @@ class ModelRunner:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
# Initialize token_to_kv_pool
if self.server_args.attention_backend == "ascend":
if self.use_mla_backend:
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1349,28 +1350,44 @@ class ModelRunner:
end_layer=self.end_layer,
)
# Initialize token_to_kv_pool_allocator
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
max_num_extend_tokens = (
self.server_args.chunked_prefill_size
if self.server_args.chunked_prefill_size > 0
else self.server_args.max_prefill_tokens
)
if self.token_to_kv_pool_allocator is None:
if self.page_size == 1:
if self.is_hybrid:
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
self.full_max_total_num_tokens,
self.swa_max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
if self.server_args.attention_backend == "ascend":
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
if not _is_npu:
if self.page_size == 1:
if self.is_hybrid:
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
self.full_max_total_num_tokens,
self.swa_max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
assert not self.is_hybrid
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
@@ -1378,15 +1395,7 @@ class ModelRunner:
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
max_num_extend_tokens=max_num_extend_tokens,
)
else:
assert self.is_draft_worker
@@ -1554,15 +1563,13 @@ class ModelRunner:
)
return TRTLLMHAAttnBackend(self)
elif backend_str == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend,
)
logger.info(f"Intel AMX attention backend is enabled.")
return IntelAMXAttnBackend(self)
elif self.server_args.attention_backend == "dual_chunk_flash_attn":
elif backend_str == "dual_chunk_flash_attn":
from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
DualChunkFlashAttentionBackend,
)
@@ -1606,6 +1613,7 @@ class ModelRunner:
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner = CudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
self.cuda_graph_mem_usage = before_mem - after_mem
logger.info(

View File

@@ -68,6 +68,8 @@ class SamplingBatchInfo:
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
from sglang.srt.managers.schedule_batch import global_server_args_dict
reqs = batch.reqs
device = batch.device
temperatures = (
@@ -97,10 +99,11 @@ class SamplingBatchInfo:
logit_bias[i, int(key)] = value
# Check if any request has custom logit processor
has_custom_logit_processor = (
batch.enable_custom_logit_processor # check the flag first.
and any(r.custom_logit_processor for r in reqs) # then check the requests.
)
has_custom_logit_processor = global_server_args_dict[
"enable_custom_logit_processor"
] and any( # check the flag first.
r.custom_logit_processor for r in reqs
) # then check the requests.
if has_custom_logit_processor:
# Merge the same type of custom logit processors together

View File

@@ -575,6 +575,7 @@ class ServerArgs:
"Pipeline parallelism is incompatible with overlap schedule."
)
# Hicache
if self.hicache_storage_backend == "mooncake":
# to use mooncake storage backend, the following conditions must be met:
self.hicache_io_backend = "kernel"
@@ -1316,19 +1317,23 @@ class ServerArgs:
# Kernel backend
ATTN_BACKENDS = [
"aiter",
# Common
"triton",
"torch_native",
# NVIDIA specific
"cutlass_mla",
"fa3",
"flashinfer",
"flashmla",
"intel_amx",
"torch_native",
"ascend",
"triton",
"trtllm_mla",
"trtllm_mha",
"dual_chunk_flash_attn",
# AMD specific
"aiter",
"wave",
# Other platforms
"intel_amx",
"ascend",
]
parser.add_argument(
"--attention-backend",