Clean up allocators (#9134)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -267,7 +267,6 @@ def extend(reqs, model_runner):
        model_config=model_runner.model_config,
        enable_overlap=False,
        spec_algorithm=SpeculativeAlgorithm.NONE,
        enable_custom_logit_processor=False,
    )
    batch.prepare_for_extend()
    _maybe_prepare_mlp_sync_batch(batch, model_runner)

@@ -864,7 +864,6 @@ class SchedulerDisaggregationDecodeMixin:
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )

        # construct fake completed prefill

||||
@@ -870,6 +870,8 @@ class FlashInferIndicesUpdaterPrefill:
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        if use_ragged:
            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
            # and forward_batch.extend_seq_lens_cpu
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:

@@ -57,16 +57,36 @@ class TritonAttnBackend(AttentionBackend):
        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

        # Parse args
        self.skip_prefill = skip_prefill

        max_bs = model_runner.req_to_token_pool.size
        self.sliding_window_size = model_runner.sliding_window_size
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.device_core_count = get_device_core_count(model_runner.gpu_id)
        self.static_kv_splits = get_bool_env_var(
            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
        )
        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits

        # Check arguments
        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"
        self.sliding_window_size = model_runner.sliding_window_size

        # Initialize buffers
        # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(

@@ -87,9 +107,6 @@ class TritonAttnBackend(AttentionBackend):
            # When provided a buffer, create a clone for the second buffer
            self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device

@@ -99,29 +116,9 @@ class TritonAttnBackend(AttentionBackend):
                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
            )

        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps

        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )

        self.static_kv_splits = get_bool_env_var(
            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
        )
        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

        # Initialize forward metadata
        self.forward_metadata: ForwardMetadata = None

        self.max_context_len = model_runner.model_config.context_len

        self.device = model_runner.device
        self.device_core_count = get_device_core_count(model_runner.gpu_id)

    def get_num_kv_splits(
        self,
        num_kv_splits: torch.Tensor,

@@ -333,7 +330,7 @@ class TritonAttnBackend(AttentionBackend):
            mask_indptr = None
            attn_logits = None
            attn_lse = None
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
            max_extend_len = max(forward_batch.extend_seq_lens_cpu)
            num_kv_splits = None

        self.forward_metadata = ForwardMetadata(

@@ -113,6 +113,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
    "enable_multimodal",
    "enable_symm_mem",
    "quantization",
    "enable_custom_logit_processor",
]

# Put some global args for easy access

@@ -909,9 +910,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

@@ -928,7 +926,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

@@ -955,7 +952,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            chunked_req=chunked_req,
        )

@@ -1009,6 +1005,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
        # Over estimate the number of tokens: assume each request needs a new page.
        num_tokens = (
            extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size

@@ -1041,8 +1038,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        # Over estimate the number of tokens: assume each request needs a new page.
        num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size

        self._evict_tree_cache_if_needed(num_tokens)

        if backup_state:

@@ -1721,38 +1718,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.forward_mode.is_decode_or_idle():
            attention_backend_str = global_server_args_dict["decode_attention_backend"]
        else:
            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
        # Create seq_lens_cpu when needed
        if (
            attention_backend_str
            in [
                "fa3",
                "flashinfer",
                "flashmla",
                "cutlass_mla",
                "ascend",
                "trtllm_mha",
                "aiter",
            ]
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = (
                seq_lens_cpu_cache
                if seq_lens_cpu_cache is not None
                else self.seq_lens.cpu()
            )
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        seq_lens_cpu = (
            seq_lens_cpu_cache
            if seq_lens_cpu_cache is not None
            else self.seq_lens.cpu()
        )

        global bid
        bid += 1
        return ModelWorkerBatch(

@@ -1815,18 +1792,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
        )

    def _evict_tree_cache_if_needed(
        self,
        num_tokens: int,
    ) -> None:
        if isinstance(self.tree_cache, SWAChunkCache):
    def _evict_tree_cache_if_needed(self, num_tokens: int):
        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
            return

        if self.is_hybrid:

@@ -1634,7 +1634,6 @@ class Scheduler(
                self.model_config,
                self.enable_overlap,
                self.spec_algorithm,
                self.server_args.enable_custom_logit_processor,
                chunked_req=self.chunked_req,
            )
            if self.enable_hierarchical_cache:

@@ -2031,7 +2030,6 @@ class Scheduler(
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

@@ -20,7 +20,6 @@ Page-aligned memory pool.
"""

import abc
import weakref
from typing import TYPE_CHECKING

import torch

@@ -81,9 +80,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
        if self.free_group:
            self.free(torch.cat(self.free_group))

    def estimated_num_new_pages(self, bs, extend_num_tokens):
        return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)

    def merge_and_sort_free(self):
        if len(self.release_pages) > 0:
            self.free_pages = torch.cat((self.free_pages, self.release_pages))

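The helper above (removed by this commit) is plain ceil-division arithmetic: every request is charged the full ceil(extend_num_tokens / page_size), a deliberate overestimate. A minimal standalone sketch of the idiom, with made-up numbers, not code from the repository:

    def pages_needed(num_tokens: int, page_size: int) -> int:
        # Ceil division via the (a + b - 1) // b idiom used in the diff.
        return (num_tokens + page_size - 1) // page_size

    assert pages_needed(7, 4) == 2  # 7 tokens span two 4-token pages
    assert pages_needed(8, 4) == 2  # exact multiple: no extra page
    assert pages_needed(9, 4) == 3  # one token spills into a third page
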
@@ -149,6 +145,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    def alloc(self, need_size: int):
        if self.need_sort and need_size > len(self.free_pages):
            self.merge_and_sort_free()

        if need_size > len(self.free_pages):
            return None

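The alloc path follows a try / merge / re-check pattern: consult the free list, lazily merge released pages back only when the list looks short, and give up (return None) if that still is not enough. A simplified list-based sketch of that control flow (illustrative only; the real implementation works on tensors):

    def alloc_sketch(free_pages, release_pages, need_size, need_sort=True):
        # Fast path: enough pages already on the free list.
        if need_sort and need_size > len(free_pages):
            # Slow path: merge lazily released pages back, restoring order.
            free_pages = sorted(free_pages + release_pages)
            release_pages.clear()
        if need_size > len(free_pages):
            return None  # caller must evict or retry later
        return free_pages[:need_size]
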
@@ -437,9 +434,13 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
        device: str,
        kvcache: KVCache,
        need_sort: bool,
        max_num_extend_tokens: int,
    ):
        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
        self.num_pages = size // page_size
        self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
            max_num_extend_tokens
        )
        self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
        self.clear()

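Rounding max_num_extend_tokens up to a power of two presumably lets the allocation kernel be specialized for a fixed compile-time bound (the value is passed as a kernel argument in the hunk below). A stand-in for the helper with made-up inputs; the real next_power_of_2 comes from the project's utilities:

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n.
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    assert next_power_of_2(8192) == 8192
    assert next_power_of_2(8193) == 16384
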
@@ -480,7 +481,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
        )

        bs = len(prefix_lens)
        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
        if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
            self.free_pages
        ):
            self.merge_and_sort_free()

@@ -497,7 +498,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
            next_power_of_2(extend_num_tokens),
            self.max_num_extend_tokens_next_power_of_2,
        )

        if self.debug_mode:

@@ -522,9 +523,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
        )

        bs = len(seq_lens)
        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
            self.free_pages
        ):
        if self.need_sort and bs > len(self.free_pages):
            self.merge_and_sort_free()

        out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)

@@ -578,151 +577,3 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):

    def load_cpu_copy(self, kv_cache_cpu, indices):
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)


def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    extend_lens = seq_lens - prefix_lens
    end_pos = torch.cumsum(extend_lens, 0)
    start_pos = end_pos - extend_lens
    num_new_pages = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    num_full_new_pages = (seq_lens) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    need_page = num_new_pages - num_full_new_pages
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
    for i in range(len(prefix_lens)):
        num1 = (
            min(
                seq_lens[i],
                (prefix_lens[i] + page_size - 1) // page_size * page_size,
            )
            - prefix_lens[i]
        )
        if num1:
            out_indices[start_pos[i] : start_pos[i] + num1] = (
                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
            )

        num2 = (
            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
        ) * page_size
        if num2:
            pages = (
                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
                * page_size
            )
            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
                pages.view(-1, 1) + pos_in_page.view(1, -1)
            ).view(-1)

        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
        if num3:
            out_indices[end_pos[i] - num3 : end_pos[i]] = (
                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
            ).view(-1)

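To make the kernel's three-way split concrete, a worked example with made-up numbers (page_size = 4, one request with prefix_len = 6 growing to seq_len = 15): num1 = 2 tokens fill out the partially used last prefix page (slots 6-7), num2 = 4 tokens fill one whole new page (slots 8-11), and num3 = 3 tokens start a final partial page (slots 12-14). The same arithmetic in scalar form:

    page_size, prefix_len, seq_len = 4, 6, 15
    ceil_prefix = (prefix_len + page_size - 1) // page_size       # 2 pages hold the prefix
    num1 = min(seq_len, ceil_prefix * page_size) - prefix_len     # 2: finish the last prefix page
    num2 = (seq_len // page_size - ceil_prefix) * page_size       # 4: whole new pages
    num3 = seq_len - seq_len // page_size * page_size             # 3: trailing partial page
    assert num1 + num2 + num3 == seq_len - prefix_len             # all 9 new tokens placed
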
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
        need_sort: bool,
    ):
        super().__init__(size, page_size, dtype, device, kvcache, need_sort)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        estimated_num_new_pages = (
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (prefix_lens + self.page_size - 1) // self.page_size
            )
            .sum()
            .item()
        )
        if self.need_sort and estimated_num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if estimated_num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[estimated_num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        need_new_pages = (seq_lens % self.page_size == 1).int()
        num_new_pages = need_new_pages.sum().item()

        if num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        end_new_pages = torch.cumsum(need_new_pages, 0)
        start_new_pages = end_new_pages - need_new_pages
        if num_new_pages == 0:
            out_indices = last_loc + 1
        else:
            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
                start_new_pages
            ] * self.page_size * need_new_pages

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices.int()

python/sglang/srt/mem_cache/allocator_ascend.py (new file, 158 lines)
@@ -0,0 +1,158 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator

if TYPE_CHECKING:
    from sglang.srt.mem_cache.memory_pool import KVCache

def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    extend_lens = seq_lens - prefix_lens
    end_pos = torch.cumsum(extend_lens, 0)
    start_pos = end_pos - extend_lens
    num_new_pages = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    num_full_new_pages = (seq_lens) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    need_page = num_new_pages - num_full_new_pages
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
    for i in range(len(prefix_lens)):
        num1 = (
            min(
                seq_lens[i],
                (prefix_lens[i] + page_size - 1) // page_size * page_size,
            )
            - prefix_lens[i]
        )
        if num1:
            out_indices[start_pos[i] : start_pos[i] + num1] = (
                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
            )

        num2 = (
            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
        ) * page_size
        if num2:
            pages = (
                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
                * page_size
            )
            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
                pages.view(-1, 1) + pos_in_page.view(1, -1)
            ).view(-1)

        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
        if num3:
            out_indices[end_pos[i] - num3 : end_pos[i]] = (
                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
            ).view(-1)

class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
        need_sort: bool,
    ):
        super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        num_new_pages = (
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (prefix_lens + self.page_size - 1) // self.page_size
            )
            .sum()
            .item()
        )
        if self.need_sort and num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        need_new_pages = (seq_lens % self.page_size == 1).int()
        num_new_pages = need_new_pages.sum().item()

        if num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        end_new_pages = torch.cumsum(need_new_pages, 0)
        start_new_pages = end_new_pages - need_new_pages
        if num_new_pages == 0:
            out_indices = last_loc + 1
        else:
            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
                start_new_pages
            ] * self.page_size * need_new_pages

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices.int()

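The decode branch appends exactly one token per request, so a fresh page is needed precisely when the new length lands on slot 1 of a page (seq_lens % page_size == 1). A worked example with made-up values (page_size = 4) showing the branchless select above:

    import torch

    page_size = 4
    seq_lens = torch.tensor([5, 8, 9])     # lengths after appending one token
    last_loc = torch.tensor([31, 42, 55])  # KV index of each request's previous token
    free_pages = torch.tensor([7, 9])      # page ids to hand out, in order

    need_new_pages = (seq_lens % page_size == 1).int()  # [1, 0, 1]
    start_new_pages = torch.cumsum(need_new_pages, 0) - need_new_pages
    # Continue in-page (last_loc + 1) or jump to the start of a fresh page.
    out_indices = (last_loc + 1) * (1 - need_new_pages) + free_pages[
        start_new_pages
    ] * page_size * need_new_pages
    assert out_indices.tolist() == [28, 43, 36]  # pages 7 and 9 start at 28 and 36
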
@@ -2,7 +2,7 @@ from __future__ import annotations

"""Cache for chunked prefill, used when RadixCache is disabled."""

from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional

import torch

@@ -75,12 +75,12 @@ from sglang.srt.managers.schedule_batch import (
    global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
    AscendPagedTokenToKVPoolAllocator,
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
    AscendMLAPagedTokenToKVPool,
    AscendTokenToKVPool,

@@ -176,10 +176,6 @@ class ModelRunner:
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id

        # Apply the rank zero filter to logger
        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
            logger.addFilter(RankZeroFilter(tp_rank == 0))
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.moe_ep_rank = moe_ep_rank

@@ -205,15 +201,17 @@ class ModelRunner:
        self.is_hybrid = model_config.is_hybrid
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size

        self.forward_pass_id = 0

        # Apply the rank zero filter to logger
        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
            logger.addFilter(RankZeroFilter(tp_rank == 0))
        if server_args.show_time_cost:
            enable_show_time_cost()

        # Model-specific adjustment
        self.model_specific_adjustment()

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Global vars
        global_server_args_dict.update(
            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}

@@ -221,8 +219,6 @@ class ModelRunner:
            # TODO it is indeed not a "server args"
            "use_mla_backend": self.use_mla_backend,
            "speculative_algorithm": self.spec_algorithm,
        }
        | {
            "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
            "deepep_mode": DeepEPMode(server_args.deepep_mode),
        }

@@ -242,13 +238,15 @@ class ModelRunner:
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

        # If it is a draft model, tp_group can be different
        # Initialize the model runner
        self.initialize(min_per_gpu_memory)

        # temporary cached values
        # Temporary cached values
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )

        # For weight updates
        self._model_update_group = {}

    def initialize(self, min_per_gpu_memory: float):

@@ -277,6 +275,7 @@ class ModelRunner:
            )
        )

        # Expert parallelism
        self.eplb_manager = (
            EPLBManager(self)
            if self.server_args.enable_eplb and (not self.is_draft_worker)

@@ -1160,6 +1159,7 @@ class ModelRunner:
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        # Determine the kv cache dtype
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":

@@ -1178,6 +1178,8 @@ class ModelRunner:
        )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if max_num_reqs is None:
            max_num_reqs = min(

@@ -1190,9 +1192,6 @@ class ModelRunner:
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size

@@ -1239,6 +1238,7 @@ class ModelRunner:
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        # Initialize req_to_token_pool
        if self.req_to_token_pool is None:
            if self.server_args.disaggregation_mode == "decode":
                from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

@@ -1264,6 +1264,7 @@ class ModelRunner:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        # Initialize token_to_kv_pool
        if self.server_args.attention_backend == "ascend":
            if self.use_mla_backend:
                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(

@@ -1349,28 +1350,44 @@ class ModelRunner:
                end_layer=self.end_layer,
            )

        # Initialize token_to_kv_pool_allocator
        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
        max_num_extend_tokens = (
            self.server_args.chunked_prefill_size
            if self.server_args.chunked_prefill_size > 0
            else self.server_args.max_prefill_tokens
        )
        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                if self.is_hybrid:
                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
                        self.full_max_total_num_tokens,
                        self.swa_max_total_num_tokens,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
                else:
                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
            if self.server_args.attention_backend == "ascend":
                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                    need_sort=need_sort,
                )
            else:
                if not _is_npu:
                    if self.page_size == 1:
                        if self.is_hybrid:
                            self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
                                self.full_max_total_num_tokens,
                                self.swa_max_total_num_tokens,
                                dtype=self.kv_cache_dtype,
                                device=self.device,
                                kvcache=self.token_to_kv_pool,
                                need_sort=need_sort,
                            )
                        else:
                            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                                self.max_total_num_tokens,
                                dtype=self.kv_cache_dtype,
                                device=self.device,
                                kvcache=self.token_to_kv_pool,
                                need_sort=need_sort,
                            )
                else:
                    assert not self.is_hybrid
                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,

@@ -1378,15 +1395,7 @@ class ModelRunner:
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
            else:
                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                    need_sort=need_sort,
                    max_num_extend_tokens=max_num_extend_tokens,
                )
        else:
            assert self.is_draft_worker

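Reading the two hunks together, the allocator choice after this commit appears to run roughly as follows (editor's condensed sketch, not code from the repository; the make_* helpers are hypothetical placeholders for the constructor calls above):

    def pick_allocator(attention_backend, page_size, is_hybrid):
        # Ascend deployments get the dedicated allocator regardless of page size.
        if attention_backend == "ascend":
            return make_ascend_paged_allocator()
        # page_size == 1 keeps the token-granularity allocators.
        if page_size == 1:
            return make_swa_allocator() if is_hybrid else make_token_allocator()
        # Otherwise the paged allocator, which now also receives max_num_extend_tokens.
        assert not is_hybrid
        return make_paged_allocator()
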
@@ -1554,15 +1563,13 @@ class ModelRunner:
            )

            return TRTLLMHAAttnBackend(self)

        elif backend_str == "intel_amx":
            from sglang.srt.layers.attention.intel_amx_backend import (
                IntelAMXAttnBackend,
            )

            logger.info(f"Intel AMX attention backend is enabled.")
            return IntelAMXAttnBackend(self)
        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
        elif backend_str == "dual_chunk_flash_attn":
            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
                DualChunkFlashAttentionBackend,
            )

@@ -1606,6 +1613,7 @@ class ModelRunner:
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)

        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        self.cuda_graph_mem_usage = before_mem - after_mem
        logger.info(

@@ -68,6 +68,8 @@ class SamplingBatchInfo:

    @classmethod
    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
        from sglang.srt.managers.schedule_batch import global_server_args_dict

        reqs = batch.reqs
        device = batch.device
        temperatures = (

@@ -97,10 +99,11 @@ class SamplingBatchInfo:
                logit_bias[i, int(key)] = value

        # Check if any request has custom logit processor
        has_custom_logit_processor = (
            batch.enable_custom_logit_processor  # check the flag first.
            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
        )
        has_custom_logit_processor = global_server_args_dict[
            "enable_custom_logit_processor"
        ] and any(  # check the flag first.
            r.custom_logit_processor for r in reqs
        )  # then check the requests.

        if has_custom_logit_processor:
            # Merge the same type of custom logit processors together

@@ -575,6 +575,7 @@ class ServerArgs:
                "Pipeline parallelism is incompatible with overlap schedule."
            )

        # Hicache
        if self.hicache_storage_backend == "mooncake":
            # to use mooncake storage backend, the following conditions must be met:
            self.hicache_io_backend = "kernel"

@@ -1316,19 +1317,23 @@ class ServerArgs:

        # Kernel backend
        ATTN_BACKENDS = [
            "aiter",
            # Common
            "triton",
            "torch_native",
            # NVIDIA specific
            "cutlass_mla",
            "fa3",
            "flashinfer",
            "flashmla",
            "intel_amx",
            "torch_native",
            "ascend",
            "triton",
            "trtllm_mla",
            "trtllm_mha",
            "dual_chunk_flash_attn",
            # AMD specific
            "aiter",
            "wave",
            # Other platforms
            "intel_amx",
            "ascend",
        ]
        parser.add_argument(
            "--attention-backend",