Clean up allocators (#9134)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -63,12 +63,12 @@ You can find additional accuracy eval examples in:
|
|||||||
## Benchmark the speed
|
## Benchmark the speed
|
||||||
Refer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md).
|
Refer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md).
|
||||||
|
|
||||||
## Request a Review
|
## Request a review
|
||||||
You can identify potential reviewers for your code by checking the [code owners](https://github.com/sgl-project/sglang/blob/main/.github/CODEOWNERS) and [reviewers](https://github.com/sgl-project/sglang/blob/main/.github/REVIEWERS.md) files.
|
You can identify potential reviewers for your code by checking the [code owners](https://github.com/sgl-project/sglang/blob/main/.github/CODEOWNERS) and [reviewers](https://github.com/sgl-project/sglang/blob/main/.github/REVIEWERS.md) files.
|
||||||
Another effective strategy is to review the file modification history and contact individuals who have frequently edited the files.
|
Another effective strategy is to review the file modification history and contact individuals who have frequently edited the files.
|
||||||
If you modify files protected by code owners, their approval is required to merge the code.
|
If you modify files protected by code owners, their approval is required to merge the code.
|
||||||
|
|
||||||
## General Code Style
|
## General code style
|
||||||
- Avoid code duplication. If the same code snippet (more than five lines) appears multiple times, extract it into a shared function.
|
- Avoid code duplication. If the same code snippet (more than five lines) appears multiple times, extract it into a shared function.
|
||||||
- Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code.
|
- Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code.
|
||||||
- Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files.
|
- Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files.
|
||||||
|
|||||||
@@ -267,7 +267,6 @@ def extend(reqs, model_runner):
|
|||||||
model_config=model_runner.model_config,
|
model_config=model_runner.model_config,
|
||||||
enable_overlap=False,
|
enable_overlap=False,
|
||||||
spec_algorithm=SpeculativeAlgorithm.NONE,
|
spec_algorithm=SpeculativeAlgorithm.NONE,
|
||||||
enable_custom_logit_processor=False,
|
|
||||||
)
|
)
|
||||||
batch.prepare_for_extend()
|
batch.prepare_for_extend()
|
||||||
_maybe_prepare_mlp_sync_batch(batch, model_runner)
|
_maybe_prepare_mlp_sync_batch(batch, model_runner)
|
||||||
|
|||||||
@@ -864,7 +864,6 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
self.model_config,
|
self.model_config,
|
||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
self.spec_algorithm,
|
self.spec_algorithm,
|
||||||
self.server_args.enable_custom_logit_processor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# construct fake completed prefill
|
# construct fake completed prefill
|
||||||
|
|||||||
@@ -870,6 +870,8 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
|
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
|
||||||
|
# and forward_batch.extend_seq_lens_cpu
|
||||||
paged_kernel_lens = prefix_lens
|
paged_kernel_lens = prefix_lens
|
||||||
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
|
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -57,16 +57,36 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
|
self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
|
||||||
self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
|
self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
|
||||||
|
|
||||||
|
# Parse args
|
||||||
self.skip_prefill = skip_prefill
|
self.skip_prefill = skip_prefill
|
||||||
|
|
||||||
max_bs = model_runner.req_to_token_pool.size
|
max_bs = model_runner.req_to_token_pool.size
|
||||||
|
self.sliding_window_size = model_runner.sliding_window_size
|
||||||
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
|
||||||
|
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||||
|
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||||
|
self.num_head = (
|
||||||
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
|
)
|
||||||
|
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
|
||||||
|
get_attention_tp_size()
|
||||||
|
)
|
||||||
|
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
|
||||||
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
|
self.device = model_runner.device
|
||||||
|
self.device_core_count = get_device_core_count(model_runner.gpu_id)
|
||||||
|
self.static_kv_splits = get_bool_env_var(
|
||||||
|
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
|
||||||
|
)
|
||||||
|
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
|
||||||
|
|
||||||
|
# Check arguments
|
||||||
assert not (
|
assert not (
|
||||||
model_runner.sliding_window_size is not None
|
model_runner.sliding_window_size is not None
|
||||||
and model_runner.model_config.is_encoder_decoder
|
and model_runner.model_config.is_encoder_decoder
|
||||||
), "Sliding window and cross attention are not supported together"
|
), "Sliding window and cross attention are not supported together"
|
||||||
self.sliding_window_size = model_runner.sliding_window_size
|
|
||||||
|
|
||||||
|
# Initialize buffers
|
||||||
# TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
|
# TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
|
||||||
if kv_indptr_buf is None:
|
if kv_indptr_buf is None:
|
||||||
self.kv_indptr = torch.zeros(
|
self.kv_indptr = torch.zeros(
|
||||||
@@ -87,9 +107,6 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
# When provided a buffer, create a clone for the second buffer
|
# When provided a buffer, create a clone for the second buffer
|
||||||
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
|
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
|
||||||
|
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
|
||||||
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
|
|
||||||
|
|
||||||
if not self.skip_prefill:
|
if not self.skip_prefill:
|
||||||
self.qo_indptr = torch.zeros(
|
self.qo_indptr = torch.zeros(
|
||||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
@@ -99,29 +116,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
|
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
|
||||||
)
|
)
|
||||||
|
|
||||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
# Initialize forward metadata
|
||||||
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
|
||||||
|
|
||||||
self.num_head = (
|
|
||||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
|
||||||
)
|
|
||||||
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
|
|
||||||
get_attention_tp_size()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.static_kv_splits = get_bool_env_var(
|
|
||||||
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
|
|
||||||
)
|
|
||||||
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
|
|
||||||
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
|
|
||||||
|
|
||||||
self.forward_metadata: ForwardMetadata = None
|
self.forward_metadata: ForwardMetadata = None
|
||||||
|
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
|
||||||
|
|
||||||
self.device = model_runner.device
|
|
||||||
self.device_core_count = get_device_core_count(model_runner.gpu_id)
|
|
||||||
|
|
||||||
def get_num_kv_splits(
|
def get_num_kv_splits(
|
||||||
self,
|
self,
|
||||||
num_kv_splits: torch.Tensor,
|
num_kv_splits: torch.Tensor,
|
||||||
@@ -333,7 +330,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
mask_indptr = None
|
mask_indptr = None
|
||||||
attn_logits = None
|
attn_logits = None
|
||||||
attn_lse = None
|
attn_lse = None
|
||||||
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
max_extend_len = max(forward_batch.extend_seq_lens_cpu)
|
||||||
num_kv_splits = None
|
num_kv_splits = None
|
||||||
|
|
||||||
self.forward_metadata = ForwardMetadata(
|
self.forward_metadata = ForwardMetadata(
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"enable_multimodal",
|
"enable_multimodal",
|
||||||
"enable_symm_mem",
|
"enable_symm_mem",
|
||||||
"quantization",
|
"quantization",
|
||||||
|
"enable_custom_logit_processor",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
@@ -909,9 +910,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
|
||||||
|
|
||||||
# Enable custom logit processor
|
|
||||||
enable_custom_logit_processor: bool = False
|
|
||||||
|
|
||||||
# Whether to return hidden states
|
# Whether to return hidden states
|
||||||
return_hidden_states: bool = False
|
return_hidden_states: bool = False
|
||||||
|
|
||||||
@@ -928,7 +926,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
enable_overlap: bool,
|
enable_overlap: bool,
|
||||||
spec_algorithm: SpeculativeAlgorithm,
|
spec_algorithm: SpeculativeAlgorithm,
|
||||||
enable_custom_logit_processor: bool,
|
|
||||||
chunked_req: Optional[Req] = None,
|
chunked_req: Optional[Req] = None,
|
||||||
):
|
):
|
||||||
return_logprob = any(req.return_logprob for req in reqs)
|
return_logprob = any(req.return_logprob for req in reqs)
|
||||||
@@ -955,7 +952,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
has_grammar=any(req.grammar for req in reqs),
|
has_grammar=any(req.grammar for req in reqs),
|
||||||
device=req_to_token_pool.device,
|
device=req_to_token_pool.device,
|
||||||
spec_algorithm=spec_algorithm,
|
spec_algorithm=spec_algorithm,
|
||||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
|
||||||
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
||||||
chunked_req=chunked_req,
|
chunked_req=chunked_req,
|
||||||
)
|
)
|
||||||
@@ -1009,6 +1005,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
extend_num_tokens: int,
|
extend_num_tokens: int,
|
||||||
backup_state: bool = False,
|
backup_state: bool = False,
|
||||||
):
|
):
|
||||||
|
# Over estimate the number of tokens: assume each request needs a new page.
|
||||||
num_tokens = (
|
num_tokens = (
|
||||||
extend_num_tokens
|
extend_num_tokens
|
||||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||||
@@ -1041,8 +1038,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
last_loc: torch.Tensor,
|
last_loc: torch.Tensor,
|
||||||
backup_state: bool = False,
|
backup_state: bool = False,
|
||||||
):
|
):
|
||||||
|
# Over estimate the number of tokens: assume each request needs a new page.
|
||||||
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||||
|
|
||||||
self._evict_tree_cache_if_needed(num_tokens)
|
self._evict_tree_cache_if_needed(num_tokens)
|
||||||
|
|
||||||
if backup_state:
|
if backup_state:
|
||||||
@@ -1721,38 +1718,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
extend_prefix_lens = self.prefix_lens
|
extend_prefix_lens = self.prefix_lens
|
||||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||||
|
|
||||||
if self.forward_mode.is_decode_or_idle():
|
|
||||||
attention_backend_str = global_server_args_dict["decode_attention_backend"]
|
|
||||||
else:
|
|
||||||
attention_backend_str = global_server_args_dict["prefill_attention_backend"]
|
|
||||||
# Create seq_lens_cpu when needed
|
|
||||||
if (
|
|
||||||
attention_backend_str
|
|
||||||
in [
|
|
||||||
"fa3",
|
|
||||||
"flashinfer",
|
|
||||||
"flashmla",
|
|
||||||
"cutlass_mla",
|
|
||||||
"ascend",
|
|
||||||
"trtllm_mha",
|
|
||||||
"aiter",
|
|
||||||
]
|
|
||||||
or global_server_args_dict["enable_two_batch_overlap"]
|
|
||||||
):
|
|
||||||
seq_lens_cpu = (
|
|
||||||
seq_lens_cpu_cache
|
|
||||||
if seq_lens_cpu_cache is not None
|
|
||||||
else self.seq_lens.cpu()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
seq_lens_cpu = None
|
|
||||||
|
|
||||||
if self.sampling_info:
|
if self.sampling_info:
|
||||||
if self.has_grammar:
|
if self.has_grammar:
|
||||||
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
||||||
else:
|
else:
|
||||||
self.sampling_info.grammars = None
|
self.sampling_info.grammars = None
|
||||||
|
|
||||||
|
seq_lens_cpu = (
|
||||||
|
seq_lens_cpu_cache
|
||||||
|
if seq_lens_cpu_cache is not None
|
||||||
|
else self.seq_lens.cpu()
|
||||||
|
)
|
||||||
|
|
||||||
global bid
|
global bid
|
||||||
bid += 1
|
bid += 1
|
||||||
return ModelWorkerBatch(
|
return ModelWorkerBatch(
|
||||||
@@ -1815,18 +1792,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
return_logprob=self.return_logprob,
|
return_logprob=self.return_logprob,
|
||||||
decoding_reqs=self.decoding_reqs,
|
decoding_reqs=self.decoding_reqs,
|
||||||
spec_algorithm=self.spec_algorithm,
|
spec_algorithm=self.spec_algorithm,
|
||||||
enable_custom_logit_processor=self.enable_custom_logit_processor,
|
|
||||||
global_num_tokens=self.global_num_tokens,
|
global_num_tokens=self.global_num_tokens,
|
||||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||||
is_extend_in_batch=self.is_extend_in_batch,
|
is_extend_in_batch=self.is_extend_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _evict_tree_cache_if_needed(
|
def _evict_tree_cache_if_needed(self, num_tokens: int):
|
||||||
self,
|
if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
|
||||||
num_tokens: int,
|
|
||||||
) -> None:
|
|
||||||
if isinstance(self.tree_cache, SWAChunkCache):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.is_hybrid:
|
if self.is_hybrid:
|
||||||
|
|||||||
@@ -1634,7 +1634,6 @@ class Scheduler(
|
|||||||
self.model_config,
|
self.model_config,
|
||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
self.spec_algorithm,
|
self.spec_algorithm,
|
||||||
self.server_args.enable_custom_logit_processor,
|
|
||||||
chunked_req=self.chunked_req,
|
chunked_req=self.chunked_req,
|
||||||
)
|
)
|
||||||
if self.enable_hierarchical_cache:
|
if self.enable_hierarchical_cache:
|
||||||
@@ -2031,7 +2030,6 @@ class Scheduler(
|
|||||||
self.model_config,
|
self.model_config,
|
||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
self.spec_algorithm,
|
self.spec_algorithm,
|
||||||
self.server_args.enable_custom_logit_processor,
|
|
||||||
)
|
)
|
||||||
idle_batch.prepare_for_idle()
|
idle_batch.prepare_for_idle()
|
||||||
return idle_batch
|
return idle_batch
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ Page-aligned memory pool.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import weakref
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -81,9 +80,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
|||||||
if self.free_group:
|
if self.free_group:
|
||||||
self.free(torch.cat(self.free_group))
|
self.free(torch.cat(self.free_group))
|
||||||
|
|
||||||
def estimated_num_new_pages(self, bs, extend_num_tokens):
|
|
||||||
return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
|
|
||||||
|
|
||||||
def merge_and_sort_free(self):
|
def merge_and_sort_free(self):
|
||||||
if len(self.release_pages) > 0:
|
if len(self.release_pages) > 0:
|
||||||
self.free_pages = torch.cat((self.free_pages, self.release_pages))
|
self.free_pages = torch.cat((self.free_pages, self.release_pages))
|
||||||
@@ -149,6 +145,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
def alloc(self, need_size: int):
|
def alloc(self, need_size: int):
|
||||||
if self.need_sort and need_size > len(self.free_pages):
|
if self.need_sort and need_size > len(self.free_pages):
|
||||||
self.merge_and_sort_free()
|
self.merge_and_sort_free()
|
||||||
|
|
||||||
if need_size > len(self.free_pages):
|
if need_size > len(self.free_pages):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -437,9 +434,13 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
device: str,
|
device: str,
|
||||||
kvcache: KVCache,
|
kvcache: KVCache,
|
||||||
need_sort: bool,
|
need_sort: bool,
|
||||||
|
max_num_extend_tokens: int,
|
||||||
):
|
):
|
||||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||||
self.num_pages = size // page_size
|
self.num_pages = size // page_size
|
||||||
|
self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
|
||||||
|
max_num_extend_tokens
|
||||||
|
)
|
||||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||||
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
||||||
self.clear()
|
self.clear()
|
||||||
@@ -480,7 +481,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
bs = len(prefix_lens)
|
bs = len(prefix_lens)
|
||||||
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
|
if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
|
||||||
self.free_pages
|
self.free_pages
|
||||||
):
|
):
|
||||||
self.merge_and_sort_free()
|
self.merge_and_sort_free()
|
||||||
@@ -497,7 +498,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
self.ret_values,
|
self.ret_values,
|
||||||
next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
self.page_size,
|
self.page_size,
|
||||||
next_power_of_2(extend_num_tokens),
|
self.max_num_extend_tokens_next_power_of_2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.debug_mode:
|
if self.debug_mode:
|
||||||
@@ -522,9 +523,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
bs = len(seq_lens)
|
bs = len(seq_lens)
|
||||||
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
|
if self.need_sort and bs > len(self.free_pages):
|
||||||
self.free_pages
|
|
||||||
):
|
|
||||||
self.merge_and_sort_free()
|
self.merge_and_sort_free()
|
||||||
|
|
||||||
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||||
@@ -578,151 +577,3 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
|
|
||||||
def load_cpu_copy(self, kv_cache_cpu, indices):
|
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||||
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
||||||
|
|
||||||
|
|
||||||
def alloc_extend_kernel_ascend(
|
|
||||||
prefix_lens,
|
|
||||||
seq_lens,
|
|
||||||
last_loc,
|
|
||||||
free_pages,
|
|
||||||
out_indices,
|
|
||||||
page_size,
|
|
||||||
device,
|
|
||||||
):
|
|
||||||
extend_lens = seq_lens - prefix_lens
|
|
||||||
end_pos = torch.cumsum(extend_lens, 0)
|
|
||||||
start_pos = end_pos - extend_lens
|
|
||||||
num_new_pages = (seq_lens + page_size - 1) // page_size - (
|
|
||||||
prefix_lens + page_size - 1
|
|
||||||
) // page_size
|
|
||||||
num_full_new_pages = (seq_lens) // page_size - (
|
|
||||||
prefix_lens + page_size - 1
|
|
||||||
) // page_size
|
|
||||||
need_page = num_new_pages - num_full_new_pages
|
|
||||||
end_new_pages = torch.cumsum(num_new_pages, 0)
|
|
||||||
start_new_pages = end_new_pages - num_new_pages
|
|
||||||
pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
|
|
||||||
for i in range(len(prefix_lens)):
|
|
||||||
num1 = (
|
|
||||||
min(
|
|
||||||
seq_lens[i],
|
|
||||||
(prefix_lens[i] + page_size - 1) // page_size * page_size,
|
|
||||||
)
|
|
||||||
- prefix_lens[i]
|
|
||||||
)
|
|
||||||
if num1:
|
|
||||||
out_indices[start_pos[i] : start_pos[i] + num1] = (
|
|
||||||
last_loc[i] + 1 + pos_in_page[:num1].view(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
num2 = (
|
|
||||||
seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
|
|
||||||
) * page_size
|
|
||||||
if num2:
|
|
||||||
pages = (
|
|
||||||
free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
|
|
||||||
* page_size
|
|
||||||
)
|
|
||||||
out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
|
|
||||||
pages.view(-1, 1) + pos_in_page.view(1, -1)
|
|
||||||
).view(-1)
|
|
||||||
|
|
||||||
num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
|
|
||||||
if num3:
|
|
||||||
out_indices[end_pos[i] - num3 : end_pos[i]] = (
|
|
||||||
free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
|
|
||||||
).view(-1)
|
|
||||||
|
|
||||||
|
|
||||||
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
size: int,
|
|
||||||
page_size: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: str,
|
|
||||||
kvcache: KVCache,
|
|
||||||
need_sort: bool,
|
|
||||||
):
|
|
||||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
|
||||||
|
|
||||||
def alloc_extend(
|
|
||||||
self,
|
|
||||||
prefix_lens: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
last_loc: torch.Tensor,
|
|
||||||
extend_num_tokens: int,
|
|
||||||
):
|
|
||||||
if self.debug_mode:
|
|
||||||
assert torch.all(
|
|
||||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
|
||||||
)
|
|
||||||
|
|
||||||
estimated_num_new_pages = (
|
|
||||||
(
|
|
||||||
(seq_lens + self.page_size - 1) // self.page_size
|
|
||||||
- (prefix_lens + self.page_size - 1) // self.page_size
|
|
||||||
)
|
|
||||||
.sum()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
if self.need_sort and estimated_num_new_pages > len(self.free_pages):
|
|
||||||
self.merge_and_sort_free()
|
|
||||||
|
|
||||||
if estimated_num_new_pages > len(self.free_pages):
|
|
||||||
return None
|
|
||||||
|
|
||||||
out_indices = torch.empty(
|
|
||||||
(extend_num_tokens,), dtype=torch.int32, device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
alloc_extend_kernel_ascend(
|
|
||||||
prefix_lens,
|
|
||||||
seq_lens,
|
|
||||||
last_loc,
|
|
||||||
self.free_pages,
|
|
||||||
out_indices,
|
|
||||||
self.page_size,
|
|
||||||
self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.debug_mode:
|
|
||||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
|
||||||
|
|
||||||
self.free_pages = self.free_pages[estimated_num_new_pages:]
|
|
||||||
return out_indices
|
|
||||||
|
|
||||||
def alloc_decode(
|
|
||||||
self,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
last_loc: torch.Tensor,
|
|
||||||
):
|
|
||||||
if self.debug_mode:
|
|
||||||
assert torch.all(
|
|
||||||
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
|
||||||
)
|
|
||||||
|
|
||||||
need_new_pages = (seq_lens % self.page_size == 1).int()
|
|
||||||
num_new_pages = need_new_pages.sum().item()
|
|
||||||
|
|
||||||
if num_new_pages > len(self.free_pages):
|
|
||||||
self.merge_and_sort_free()
|
|
||||||
|
|
||||||
if num_new_pages > len(self.free_pages):
|
|
||||||
return None
|
|
||||||
|
|
||||||
end_new_pages = torch.cumsum(need_new_pages, 0)
|
|
||||||
start_new_pages = end_new_pages - need_new_pages
|
|
||||||
if num_new_pages == 0:
|
|
||||||
out_indices = last_loc + 1
|
|
||||||
else:
|
|
||||||
out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
|
|
||||||
start_new_pages
|
|
||||||
] * self.page_size * need_new_pages
|
|
||||||
|
|
||||||
if self.debug_mode:
|
|
||||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
|
||||||
|
|
||||||
self.free_pages = self.free_pages[num_new_pages:]
|
|
||||||
return out_indices.int()
|
|
||||||
|
|||||||
158
python/sglang/srt/mem_cache/allocator_ascend.py
Normal file
158
python/sglang/srt/mem_cache/allocator_ascend.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    """Write the KV-cache slot index of every newly extended token into ``out_indices``.

    For request ``i`` the ``seq_lens[i] - prefix_lens[i]`` new tokens are laid
    out contiguously in ``out_indices`` (requests in batch order) and are
    assigned slots in up to three parts:

    1. the unused tail of the page that already holds the end of the prefix
       (slots ``last_loc[i] + 1, ...``),
    2. whole new pages taken from ``free_pages``,
    3. the head of one more new page for the remaining tokens.

    Args:
        prefix_lens: 1-D integer tensor, tokens already allocated per request.
        seq_lens: 1-D integer tensor, total tokens per request after the extend.
        last_loc: 1-D integer tensor, slot index of the last prefix token.
        free_pages: 1-D tensor of free page ids; pages are consumed from the
            front in request order. This function does not shrink it.
        out_indices: 1-D output tensor, filled in place.
        page_size: number of token slots per page.
        device: device on which the in-page offset tensor is created.

    Returns:
        None. The result is written into ``out_indices``.
    """
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)

    # Read the lengths back once instead of indexing 0-d tensors on every
    # loop iteration; all bookkeeping below is plain Python integer math.
    prefix_list = prefix_lens.tolist()
    seq_list = seq_lens.tolist()

    out_start = 0  # first position in out_indices owned by the current request
    page_start = 0  # first entry in free_pages owned by the current request
    for i, (pre_len, seq_len) in enumerate(zip(prefix_list, seq_list)):
        extend_len = seq_len - pre_len
        pre_pages = (pre_len + page_size - 1) // page_size
        num_new_pages = (seq_len + page_size - 1) // page_size - pre_pages

        # Part 1: finish the partially filled page that ends the prefix.
        num1 = min(seq_len, pre_pages * page_size) - pre_len
        if num1 > 0:
            out_indices[out_start : out_start + num1] = (
                last_loc[i] + 1 + pos_in_page[:num1]
            )

        # Part 2: whole new pages. The raw difference is negative when the
        # extension ends inside the prefix's last page, so clamp it at zero
        # (a negative count used to be treated as "non-empty").
        num_full_pages = max(seq_len // page_size - pre_pages, 0)
        num2 = num_full_pages * page_size
        if num2 > 0:
            pages = free_pages[page_start : page_start + num_full_pages] * page_size
            out_indices[out_start + num1 : out_start + num1 + num2] = (
                pages.view(-1, 1) + pos_in_page.view(1, -1)
            ).view(-1)

        # Part 3: whatever is still unassigned goes to the head of the last
        # new page. Deriving the count from the remainder (rather than from
        # seq_len % page_size) keeps it at zero when parts 1-2 already
        # covered the whole extension.
        num3 = extend_len - num1 - num2
        if num3 > 0:
            out_end = out_start + extend_len
            out_indices[out_end - num3 : out_end] = (
                free_pages[page_start + num_new_pages - 1] * page_size
                + pos_in_page[:num3]
            )

        out_start += extend_len
        page_start += num_new_pages
|
||||||
|
|
||||||
|
|
||||||
|
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged token-to-KV-slot allocator whose allocation paths use plain torch ops.

    Overrides ``alloc_extend`` and ``alloc_decode`` of
    ``PagedTokenToKVPoolAllocator``; the page bookkeeping (``free_pages``,
    ``merge_and_sort_free``, ``debug_mode`` ...) is inherited from the parent.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
        need_sort: bool,
    ):
        # The trailing ``1`` fills the parent's ``max_num_extend_tokens``
        # argument. NOTE(review): presumably a placeholder, since this class
        # replaces the parent's ``alloc_extend`` -- confirm against the parent.
        super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate slots for the new tokens of a batch of extend requests.

        Args:
            prefix_lens: tokens already allocated per request.
            seq_lens: total tokens per request after the extend.
            last_loc: slot index of the last prefix token of each request.
            extend_num_tokens: total number of new tokens in the batch; the
                length of the returned tensor.

        Returns:
            An int32 tensor of ``extend_num_tokens`` slot indices, or ``None``
            when there are not enough free pages.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        # Exact number of pages the batch needs: pages covering seq_len minus
        # pages already covering the prefix, summed over requests.
        num_new_pages = (
            (
                (seq_lens + self.page_size - 1) // self.page_size
                - (prefix_lens + self.page_size - 1) // self.page_size
            )
            .sum()
            .item()
        )
        if self.need_sort and num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # Drop the pages that were just handed out.
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one slot per request for a decode step.

        A request gets the slot right after ``last_loc`` unless its new token
        is the first one of a page (``seq_len % page_size == 1``), in which
        case it gets the first slot of a fresh page.

        Args:
            seq_lens: total tokens per request including the new token.
            last_loc: slot index of the previous token of each request.

        Returns:
            An int32 tensor with one slot index per request, or ``None`` when
            there are not enough free pages.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        # 1 for requests whose new token opens a new page, else 0.
        need_new_pages = (seq_lens % self.page_size == 1).int()
        num_new_pages = need_new_pages.sum().item()

        if num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        if num_new_pages == 0:
            out_indices = last_loc + 1
        else:
            end_new_pages = torch.cumsum(need_new_pages, 0)
            # Exclusive prefix sum = index into free_pages of each request's
            # new page. Requests after the last page-needing one get an index
            # equal to num_new_pages, which is one past the end when exactly
            # num_new_pages pages are free; clamp it. Their gathered value is
            # multiplied by need_new_pages == 0 below, so it never matters.
            start_new_pages = torch.clamp(
                end_new_pages - need_new_pages, max=num_new_pages - 1
            )
            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
                start_new_pages
            ] * self.page_size * need_new_pages

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # Drop the pages that were just handed out.
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices.int()
|
||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -75,12 +75,12 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
global_server_args_dict,
|
global_server_args_dict,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.allocator import (
|
from sglang.srt.mem_cache.allocator import (
|
||||||
AscendPagedTokenToKVPoolAllocator,
|
|
||||||
BaseTokenToKVPoolAllocator,
|
BaseTokenToKVPoolAllocator,
|
||||||
PagedTokenToKVPoolAllocator,
|
PagedTokenToKVPoolAllocator,
|
||||||
SWATokenToKVPoolAllocator,
|
SWATokenToKVPoolAllocator,
|
||||||
TokenToKVPoolAllocator,
|
TokenToKVPoolAllocator,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
AscendMLAPagedTokenToKVPool,
|
AscendMLAPagedTokenToKVPool,
|
||||||
AscendTokenToKVPool,
|
AscendTokenToKVPool,
|
||||||
@@ -176,10 +176,6 @@ class ModelRunner:
|
|||||||
self.mem_fraction_static = mem_fraction_static
|
self.mem_fraction_static = mem_fraction_static
|
||||||
self.device = server_args.device
|
self.device = server_args.device
|
||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
|
|
||||||
# Apply the rank zero filter to logger
|
|
||||||
if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
|
|
||||||
logger.addFilter(RankZeroFilter(tp_rank == 0))
|
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.moe_ep_rank = moe_ep_rank
|
self.moe_ep_rank = moe_ep_rank
|
||||||
@@ -205,15 +201,17 @@ class ModelRunner:
|
|||||||
self.is_hybrid = model_config.is_hybrid
|
self.is_hybrid = model_config.is_hybrid
|
||||||
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
|
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
|
||||||
self.attention_chunk_size = model_config.attention_chunk_size
|
self.attention_chunk_size = model_config.attention_chunk_size
|
||||||
|
|
||||||
self.forward_pass_id = 0
|
self.forward_pass_id = 0
|
||||||
|
|
||||||
|
# Apply the rank zero filter to logger
|
||||||
|
if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
|
||||||
|
logger.addFilter(RankZeroFilter(tp_rank == 0))
|
||||||
|
if server_args.show_time_cost:
|
||||||
|
enable_show_time_cost()
|
||||||
|
|
||||||
# Model-specific adjustment
|
# Model-specific adjustment
|
||||||
self.model_specific_adjustment()
|
self.model_specific_adjustment()
|
||||||
|
|
||||||
if server_args.show_time_cost:
|
|
||||||
enable_show_time_cost()
|
|
||||||
|
|
||||||
# Global vars
|
# Global vars
|
||||||
global_server_args_dict.update(
|
global_server_args_dict.update(
|
||||||
{k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
|
{k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
|
||||||
@@ -221,8 +219,6 @@ class ModelRunner:
|
|||||||
# TODO it is indeed not a "server args"
|
# TODO it is indeed not a "server args"
|
||||||
"use_mla_backend": self.use_mla_backend,
|
"use_mla_backend": self.use_mla_backend,
|
||||||
"speculative_algorithm": self.spec_algorithm,
|
"speculative_algorithm": self.spec_algorithm,
|
||||||
}
|
|
||||||
| {
|
|
||||||
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
|
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
|
||||||
"deepep_mode": DeepEPMode(server_args.deepep_mode),
|
"deepep_mode": DeepEPMode(server_args.deepep_mode),
|
||||||
}
|
}
|
||||||
@@ -242,13 +238,15 @@ class ModelRunner:
|
|||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||||
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
|
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
|
||||||
|
|
||||||
# If it is a draft model, tp_group can be different
|
# Initialize the model runner
|
||||||
self.initialize(min_per_gpu_memory)
|
self.initialize(min_per_gpu_memory)
|
||||||
|
|
||||||
# temporary cached values
|
# Temporary cached values
|
||||||
self.support_pp = (
|
self.support_pp = (
|
||||||
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
|
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For weight updates
|
||||||
self._model_update_group = {}
|
self._model_update_group = {}
|
||||||
|
|
||||||
def initialize(self, min_per_gpu_memory: float):
|
def initialize(self, min_per_gpu_memory: float):
|
||||||
@@ -277,6 +275,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Expert parallelism
|
||||||
self.eplb_manager = (
|
self.eplb_manager = (
|
||||||
EPLBManager(self)
|
EPLBManager(self)
|
||||||
if self.server_args.enable_eplb and (not self.is_draft_worker)
|
if self.server_args.enable_eplb and (not self.is_draft_worker)
|
||||||
@@ -1160,6 +1159,7 @@ class ModelRunner:
|
|||||||
max_num_reqs: Optional[int] = None,
|
max_num_reqs: Optional[int] = None,
|
||||||
max_total_tokens: Optional[int] = None,
|
max_total_tokens: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
# Determine the kv cache dtype
|
||||||
if self.server_args.kv_cache_dtype == "auto":
|
if self.server_args.kv_cache_dtype == "auto":
|
||||||
self.kv_cache_dtype = self.dtype
|
self.kv_cache_dtype = self.dtype
|
||||||
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
|
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
|
||||||
@@ -1178,6 +1178,8 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
|
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
|
||||||
|
if SGLANG_CI_SMALL_KV_SIZE:
|
||||||
|
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
|
||||||
|
|
||||||
if max_num_reqs is None:
|
if max_num_reqs is None:
|
||||||
max_num_reqs = min(
|
max_num_reqs = min(
|
||||||
@@ -1190,9 +1192,6 @@ class ModelRunner:
|
|||||||
4096,
|
4096,
|
||||||
)
|
)
|
||||||
|
|
||||||
if SGLANG_CI_SMALL_KV_SIZE:
|
|
||||||
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
|
|
||||||
|
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
if self.is_draft_worker:
|
if self.is_draft_worker:
|
||||||
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
||||||
@@ -1239,6 +1238,7 @@ class ModelRunner:
|
|||||||
"Not enough memory. Please try to increase --mem-fraction-static."
|
"Not enough memory. Please try to increase --mem-fraction-static."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize req_to_token_pool
|
||||||
if self.req_to_token_pool is None:
|
if self.req_to_token_pool is None:
|
||||||
if self.server_args.disaggregation_mode == "decode":
|
if self.server_args.disaggregation_mode == "decode":
|
||||||
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
||||||
@@ -1264,6 +1264,7 @@ class ModelRunner:
|
|||||||
# Draft worker shares req_to_token_pool with the target worker.
|
# Draft worker shares req_to_token_pool with the target worker.
|
||||||
assert self.is_draft_worker
|
assert self.is_draft_worker
|
||||||
|
|
||||||
|
# Initialize token_to_kv_pool
|
||||||
if self.server_args.attention_backend == "ascend":
|
if self.server_args.attention_backend == "ascend":
|
||||||
if self.use_mla_backend:
|
if self.use_mla_backend:
|
||||||
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
|
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
|
||||||
@@ -1349,28 +1350,44 @@ class ModelRunner:
|
|||||||
end_layer=self.end_layer,
|
end_layer=self.end_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize token_to_kv_pool_allocator
|
||||||
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
|
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
|
||||||
|
max_num_extend_tokens = (
|
||||||
|
self.server_args.chunked_prefill_size
|
||||||
|
if self.server_args.chunked_prefill_size > 0
|
||||||
|
else self.server_args.max_prefill_tokens
|
||||||
|
)
|
||||||
if self.token_to_kv_pool_allocator is None:
|
if self.token_to_kv_pool_allocator is None:
|
||||||
if self.page_size == 1:
|
if self.server_args.attention_backend == "ascend":
|
||||||
if self.is_hybrid:
|
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
||||||
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
|
self.max_total_num_tokens,
|
||||||
self.full_max_total_num_tokens,
|
page_size=self.page_size,
|
||||||
self.swa_max_total_num_tokens,
|
dtype=self.kv_cache_dtype,
|
||||||
dtype=self.kv_cache_dtype,
|
device=self.device,
|
||||||
device=self.device,
|
kvcache=self.token_to_kv_pool,
|
||||||
kvcache=self.token_to_kv_pool,
|
need_sort=need_sort,
|
||||||
need_sort=need_sort,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
|
||||||
self.max_total_num_tokens,
|
|
||||||
dtype=self.kv_cache_dtype,
|
|
||||||
device=self.device,
|
|
||||||
kvcache=self.token_to_kv_pool,
|
|
||||||
need_sort=need_sort,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if not _is_npu:
|
if self.page_size == 1:
|
||||||
|
if self.is_hybrid:
|
||||||
|
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
|
||||||
|
self.full_max_total_num_tokens,
|
||||||
|
self.swa_max_total_num_tokens,
|
||||||
|
dtype=self.kv_cache_dtype,
|
||||||
|
device=self.device,
|
||||||
|
kvcache=self.token_to_kv_pool,
|
||||||
|
need_sort=need_sort,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||||
|
self.max_total_num_tokens,
|
||||||
|
dtype=self.kv_cache_dtype,
|
||||||
|
device=self.device,
|
||||||
|
kvcache=self.token_to_kv_pool,
|
||||||
|
need_sort=need_sort,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert not self.is_hybrid
|
||||||
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
|
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
page_size=self.page_size,
|
page_size=self.page_size,
|
||||||
@@ -1378,15 +1395,7 @@ class ModelRunner:
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
kvcache=self.token_to_kv_pool,
|
kvcache=self.token_to_kv_pool,
|
||||||
need_sort=need_sort,
|
need_sort=need_sort,
|
||||||
)
|
max_num_extend_tokens=max_num_extend_tokens,
|
||||||
else:
|
|
||||||
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
|
||||||
self.max_total_num_tokens,
|
|
||||||
page_size=self.page_size,
|
|
||||||
dtype=self.kv_cache_dtype,
|
|
||||||
device=self.device,
|
|
||||||
kvcache=self.token_to_kv_pool,
|
|
||||||
need_sort=need_sort,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert self.is_draft_worker
|
assert self.is_draft_worker
|
||||||
@@ -1554,15 +1563,13 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return TRTLLMHAAttnBackend(self)
|
return TRTLLMHAAttnBackend(self)
|
||||||
|
|
||||||
elif backend_str == "intel_amx":
|
elif backend_str == "intel_amx":
|
||||||
from sglang.srt.layers.attention.intel_amx_backend import (
|
from sglang.srt.layers.attention.intel_amx_backend import (
|
||||||
IntelAMXAttnBackend,
|
IntelAMXAttnBackend,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Intel AMX attention backend is enabled.")
|
|
||||||
return IntelAMXAttnBackend(self)
|
return IntelAMXAttnBackend(self)
|
||||||
elif self.server_args.attention_backend == "dual_chunk_flash_attn":
|
elif backend_str == "dual_chunk_flash_attn":
|
||||||
from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
|
from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
|
||||||
DualChunkFlashAttentionBackend,
|
DualChunkFlashAttentionBackend,
|
||||||
)
|
)
|
||||||
@@ -1606,6 +1613,7 @@ class ModelRunner:
|
|||||||
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||||
)
|
)
|
||||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||||
|
|
||||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
self.cuda_graph_mem_usage = before_mem - after_mem
|
self.cuda_graph_mem_usage = before_mem - after_mem
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -68,6 +68,8 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
device = batch.device
|
device = batch.device
|
||||||
temperatures = (
|
temperatures = (
|
||||||
@@ -97,10 +99,11 @@ class SamplingBatchInfo:
|
|||||||
logit_bias[i, int(key)] = value
|
logit_bias[i, int(key)] = value
|
||||||
|
|
||||||
# Check if any request has custom logit processor
|
# Check if any request has custom logit processor
|
||||||
has_custom_logit_processor = (
|
has_custom_logit_processor = global_server_args_dict[
|
||||||
batch.enable_custom_logit_processor # check the flag first.
|
"enable_custom_logit_processor"
|
||||||
and any(r.custom_logit_processor for r in reqs) # then check the requests.
|
] and any( # check the flag first.
|
||||||
)
|
r.custom_logit_processor for r in reqs
|
||||||
|
) # then check the requests.
|
||||||
|
|
||||||
if has_custom_logit_processor:
|
if has_custom_logit_processor:
|
||||||
# Merge the same type of custom logit processors together
|
# Merge the same type of custom logit processors together
|
||||||
|
|||||||
@@ -575,6 +575,7 @@ class ServerArgs:
|
|||||||
"Pipeline parallelism is incompatible with overlap schedule."
|
"Pipeline parallelism is incompatible with overlap schedule."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Hicache
|
||||||
if self.hicache_storage_backend == "mooncake":
|
if self.hicache_storage_backend == "mooncake":
|
||||||
# to use mooncake storage backend, the following conditions must be met:
|
# to use mooncake storage backend, the following conditions must be met:
|
||||||
self.hicache_io_backend = "kernel"
|
self.hicache_io_backend = "kernel"
|
||||||
@@ -1316,19 +1317,23 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Kernel backend
|
# Kernel backend
|
||||||
ATTN_BACKENDS = [
|
ATTN_BACKENDS = [
|
||||||
"aiter",
|
# Common
|
||||||
|
"triton",
|
||||||
|
"torch_native",
|
||||||
|
# NVIDIA specific
|
||||||
"cutlass_mla",
|
"cutlass_mla",
|
||||||
"fa3",
|
"fa3",
|
||||||
"flashinfer",
|
"flashinfer",
|
||||||
"flashmla",
|
"flashmla",
|
||||||
"intel_amx",
|
|
||||||
"torch_native",
|
|
||||||
"ascend",
|
|
||||||
"triton",
|
|
||||||
"trtllm_mla",
|
"trtllm_mla",
|
||||||
"trtllm_mha",
|
"trtllm_mha",
|
||||||
"dual_chunk_flash_attn",
|
"dual_chunk_flash_attn",
|
||||||
|
# AMD specific
|
||||||
|
"aiter",
|
||||||
"wave",
|
"wave",
|
||||||
|
# Other platforms
|
||||||
|
"intel_amx",
|
||||||
|
"ascend",
|
||||||
]
|
]
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--attention-backend",
|
"--attention-backend",
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
/*
|
/*
|
||||||
* From csrc/allreduce
|
* From csrc/allreduce
|
||||||
*/
|
*/
|
||||||
|
|
||||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||||
m.def("register_graph_buffers", &register_graph_buffers);
|
m.def("register_graph_buffers", &register_graph_buffers);
|
||||||
m.def("dispose", &dispose);
|
m.def("dispose", &dispose);
|
||||||
@@ -46,6 +45,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
|
|
||||||
m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
|
m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
|
||||||
m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
|
m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/attention
|
* From csrc/attention
|
||||||
*/
|
*/
|
||||||
@@ -284,6 +284,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
"page_size) -> ()");
|
"page_size) -> ()");
|
||||||
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
|
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* From csrc/memory
|
||||||
|
*/
|
||||||
|
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
|
||||||
|
m.impl("store_kv_cache", &store_kv_cache);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/moe/cutlass_moe/w4a8
|
* From csrc/moe/cutlass_moe/w4a8
|
||||||
*/
|
*/
|
||||||
@@ -390,13 +396,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, &convert_vertical_slash_indexes_mergehead);
|
m.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, &convert_vertical_slash_indexes_mergehead);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From XGrammar
|
* From csrc/grammar
|
||||||
*/
|
*/
|
||||||
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
|
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
|
||||||
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
|
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From QServe
|
* From csrc/gemm (QServe)
|
||||||
*/
|
*/
|
||||||
m.def(
|
m.def(
|
||||||
"qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
|
"qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
|
||||||
@@ -413,12 +419,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
*/
|
*/
|
||||||
m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
|
m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
|
||||||
m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
|
m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
|
||||||
|
|
||||||
/*
|
|
||||||
* From csrc/memory
|
|
||||||
*/
|
|
||||||
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
|
|
||||||
m.impl("store_kv_cache", &store_kv_cache);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(common_ops)
|
REGISTER_EXTENSION(common_ops)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ sources = [
|
|||||||
"csrc/moe/moe_align_kernel.cu",
|
"csrc/moe/moe_align_kernel.cu",
|
||||||
"csrc/moe/moe_topk_softmax_kernels.cu",
|
"csrc/moe/moe_topk_softmax_kernels.cu",
|
||||||
"csrc/speculative/eagle_utils.cu",
|
"csrc/speculative/eagle_utils.cu",
|
||||||
"csrc/torch_extension_rocm.cc",
|
"csrc/common_extension_rocm.cc",
|
||||||
]
|
]
|
||||||
|
|
||||||
cxx_flags = ["-O3"]
|
cxx_flags = ["-O3"]
|
||||||
|
|||||||
Reference in New Issue
Block a user