diff --git a/docs/developer_guide/contribution_guide.md b/docs/developer_guide/contribution_guide.md
index 55de73a0b..db406a544 100644
--- a/docs/developer_guide/contribution_guide.md
+++ b/docs/developer_guide/contribution_guide.md
@@ -63,12 +63,12 @@ You can find additional accuracy eval examples in:
 ## Benchmark the speed
 Refer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md).
 
-## Request a Review
+## Request a review
 You can identify potential reviewers for your code by checking the [code owners](https://github.com/sgl-project/sglang/blob/main/.github/CODEOWNERS) and [reviewers](https://github.com/sgl-project/sglang/blob/main/.github/REVIEWERS.md) files. Another effective strategy is to review the file modification history and contact individuals who have frequently edited the files.
 If you modify files protected by code owners, their approval is required to merge the code.
 
-## General Code Style
+## General code style
 - Avoid code duplication. If the same code snippet (more than five lines) appears multiple times, extract it into a shared function.
 - Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible. Use vectorized code.
 - Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files.
diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 36530445a..8401e4708 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -267,7 +267,6 @@ def extend(reqs, model_runner):
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
-        enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
     _maybe_prepare_mlp_sync_batch(batch, model_runner)
diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index 02f297d6a..1570b8b32 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -864,7 +864,6 @@ class SchedulerDisaggregationDecodeMixin:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
-            self.server_args.enable_custom_logit_processor,
         )
 
         # construct fake completed prefill
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 73cf574dd..f50676c3b 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -870,6 +870,8 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
+            # TODO: remove this device sync; we can use forward_batch.extend_prefix_lens_cpu
+            # and forward_batch.extend_seq_lens_cpu instead
             paged_kernel_lens = prefix_lens
             paged_kernel_lens_sum = paged_kernel_lens.sum().item()
         else:
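The TODO in the flashinfer hunk above is an instance of the "minimize device synchronization" guideline from the contribution guide: `paged_kernel_lens.sum().item()` stalls the host until the GPU catches up, whereas host-side metadata such as `extend_prefix_lens_cpu` costs nothing to reduce. A minimal sketch of the two patterns (illustrative function names, not sglang code):

```python
import torch

def total_len_with_sync(prefix_lens_gpu: torch.Tensor) -> int:
    # Anti-pattern: .item() forces a blocking GPU -> CPU transfer.
    return int(prefix_lens_gpu.sum().item())

def total_len_without_sync(prefix_lens_cpu: list) -> int:
    # Preferred: reduce a host-side copy that the scheduler already keeps.
    return sum(prefix_lens_cpu)
```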
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 10d242ebe..a3d8f88eb 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -57,16 +57,36 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
         self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
+        # Parse args
         self.skip_prefill = skip_prefill
-
         max_bs = model_runner.req_to_token_pool.size
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+        self.max_context_len = model_runner.model_config.context_len
+        self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
 
+        # Check arguments
         assert not (
             model_runner.sliding_window_size is not None
             and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"
 
-        self.sliding_window_size = model_runner.sliding_window_size
+        # Initialize buffers
         # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
         if kv_indptr_buf is None:
             self.kv_indptr = torch.zeros(
@@ -87,9 +107,6 @@ class TritonAttnBackend(AttentionBackend):
             # When provided a buffer, create a clone for the second buffer
             self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
 
-        self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
-
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
@@ -99,29 +116,9 @@ class TritonAttnBackend(AttentionBackend):
                 (max_bs + 1,), dtype=torch.int64, device=model_runner.device
             )
 
-        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
-        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
-
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // get_attention_tp_size()
-        )
-        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
-
-        self.static_kv_splits = get_bool_env_var(
-            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
-        )
-        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
-        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
-
+        # Initialize forward metadata
         self.forward_metadata: ForwardMetadata = None
 
-        self.max_context_len = model_runner.model_config.context_len
-
-        self.device = model_runner.device
-        self.device_core_count = get_device_core_count(model_runner.gpu_id)
-
     def get_num_kv_splits(
         self,
         num_kv_splits: torch.Tensor,
@@ -333,7 +330,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             attn_logits = None
             attn_lse = None
-            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+            max_extend_len = max(forward_batch.extend_seq_lens_cpu)
             num_kv_splits = None
 
         self.forward_metadata = ForwardMetadata(
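The `(max_bs + 1,)`-sized `kv_indptr`/`qo_indptr` buffers above follow the usual CSR convention for ragged batches: entry `i` is the cumulative token count of the first `i` requests, so request `i` owns `indices[indptr[i]:indptr[i+1]]`. A toy illustration (assumed convention, not the backend's exact code):

```python
import torch

seq_lens = torch.tensor([3, 1, 4])  # tokens per request
kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
assert kv_indptr.tolist() == [0, 3, 4, 8]  # request 2 owns slots 4..7
```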
+ "enable_custom_logit_processor", ] # Put some global args for easy access @@ -909,9 +910,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): spec_algorithm: SpeculativeAlgorithm = None spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None - # Enable custom logit processor - enable_custom_logit_processor: bool = False - # Whether to return hidden states return_hidden_states: bool = False @@ -928,7 +926,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): model_config: ModelConfig, enable_overlap: bool, spec_algorithm: SpeculativeAlgorithm, - enable_custom_logit_processor: bool, chunked_req: Optional[Req] = None, ): return_logprob = any(req.return_logprob for req in reqs) @@ -955,7 +952,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): has_grammar=any(req.grammar for req in reqs), device=req_to_token_pool.device, spec_algorithm=spec_algorithm, - enable_custom_logit_processor=enable_custom_logit_processor, return_hidden_states=any(req.return_hidden_states for req in reqs), chunked_req=chunked_req, ) @@ -1009,6 +1005,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): extend_num_tokens: int, backup_state: bool = False, ): + # Over estimate the number of tokens: assume each request needs a new page. num_tokens = ( extend_num_tokens + len(seq_lens) * self.token_to_kv_pool_allocator.page_size @@ -1041,8 +1038,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): last_loc: torch.Tensor, backup_state: bool = False, ): + # Over estimate the number of tokens: assume each request needs a new page. num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size - self._evict_tree_cache_if_needed(num_tokens) if backup_state: @@ -1721,38 +1718,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): extend_prefix_lens = self.prefix_lens extend_logprob_start_lens = self.extend_logprob_start_lens - if self.forward_mode.is_decode_or_idle(): - attention_backend_str = global_server_args_dict["decode_attention_backend"] - else: - attention_backend_str = global_server_args_dict["prefill_attention_backend"] - # Create seq_lens_cpu when needed - if ( - attention_backend_str - in [ - "fa3", - "flashinfer", - "flashmla", - "cutlass_mla", - "ascend", - "trtllm_mha", - "aiter", - ] - or global_server_args_dict["enable_two_batch_overlap"] - ): - seq_lens_cpu = ( - seq_lens_cpu_cache - if seq_lens_cpu_cache is not None - else self.seq_lens.cpu() - ) - else: - seq_lens_cpu = None - if self.sampling_info: if self.has_grammar: self.sampling_info.grammars = [req.grammar for req in self.reqs] else: self.sampling_info.grammars = None + seq_lens_cpu = ( + seq_lens_cpu_cache + if seq_lens_cpu_cache is not None + else self.seq_lens.cpu() + ) + global bid bid += 1 return ModelWorkerBatch( @@ -1815,18 +1792,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): return_logprob=self.return_logprob, decoding_reqs=self.decoding_reqs, spec_algorithm=self.spec_algorithm, - enable_custom_logit_processor=self.enable_custom_logit_processor, global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, is_extend_in_batch=self.is_extend_in_batch, ) - def _evict_tree_cache_if_needed( - self, - num_tokens: int, - ) -> None: - if isinstance(self.tree_cache, SWAChunkCache): + def _evict_tree_cache_if_needed(self, num_tokens: int): + if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)): return if self.is_hybrid: diff --git 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 5b10eef59..fc0055b2b 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1634,7 +1634,6 @@ class Scheduler(
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
-            self.server_args.enable_custom_logit_processor,
             chunked_req=self.chunked_req,
         )
         if self.enable_hierarchical_cache:
@@ -2031,7 +2030,6 @@ class Scheduler(
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
-            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py
index 0bf8cc2e1..64c2fe318 100644
--- a/python/sglang/srt/mem_cache/allocator.py
+++ b/python/sglang/srt/mem_cache/allocator.py
@@ -20,7 +20,6 @@ Page-aligned memory pool.
 """
 
 import abc
-import weakref
 from typing import TYPE_CHECKING
 
 import torch
@@ -81,9 +80,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         if self.free_group:
             self.free(torch.cat(self.free_group))
 
-    def estimated_num_new_pages(self, bs, extend_num_tokens):
-        return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
-
     def merge_and_sort_free(self):
         if len(self.release_pages) > 0:
             self.free_pages = torch.cat((self.free_pages, self.release_pages))
@@ -149,6 +145,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     def alloc(self, need_size: int):
         if self.need_sort and need_size > len(self.free_pages):
             self.merge_and_sort_free()
+
         if need_size > len(self.free_pages):
             return None
 
@@ -437,9 +434,13 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         device: str,
         kvcache: KVCache,
         need_sort: bool,
+        max_num_extend_tokens: int,
     ):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
+        self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
+            max_num_extend_tokens
+        )
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
         self.clear()
@@ -480,7 +481,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
 
         bs = len(prefix_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
+        if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
             self.free_pages
         ):
             self.merge_and_sort_free()
@@ -497,7 +498,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.ret_values,
             next_power_of_2(bs),
             self.page_size,
-            next_power_of_2(extend_num_tokens),
+            self.max_num_extend_tokens_next_power_of_2,
         )
 
         if self.debug_mode:
@@ -522,9 +523,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
 
         bs = len(seq_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
-            self.free_pages
-        ):
+        if self.need_sort and bs > len(self.free_pages):
             self.merge_and_sort_free()
 
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
@@ -578,151 +577,3 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
 
     def load_cpu_copy(self, kv_cache_cpu, indices):
         return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
-
-
-def alloc_extend_kernel_ascend(
-    prefix_lens,
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-    device,
-):
-    extend_lens = seq_lens - prefix_lens
-    end_pos = torch.cumsum(extend_lens, 0)
-    start_pos = end_pos - extend_lens
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    num_full_new_pages = (seq_lens) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    need_page = num_new_pages - num_full_new_pages
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
-    for i in range(len(prefix_lens)):
-        num1 = (
-            min(
-                seq_lens[i],
-                (prefix_lens[i] + page_size - 1) // page_size * page_size,
-            )
-            - prefix_lens[i]
-        )
-        if num1:
-            out_indices[start_pos[i] : start_pos[i] + num1] = (
-                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
-            )
-
-        num2 = (
-            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
-        ) * page_size
-        if num2:
-            pages = (
-                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
-                * page_size
-            )
-            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
-                pages.view(-1, 1) + pos_in_page.view(1, -1)
-            ).view(-1)
-
-        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
-        if num3:
-            out_indices[end_pos[i] - num3 : end_pos[i]] = (
-                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
-            ).view(-1)
-
-
-class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
-
-    def __init__(
-        self,
-        size: int,
-        page_size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-        need_sort: bool,
-    ):
-        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
-
-    def alloc_extend(
-        self,
-        prefix_lens: torch.Tensor,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-        extend_num_tokens: int,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
-            )
-
-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if self.need_sort and estimated_num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        if estimated_num_new_pages > len(self.free_pages):
-            return None
-
-        out_indices = torch.empty(
-            (extend_num_tokens,), dtype=torch.int32, device=self.device
-        )
-
-        alloc_extend_kernel_ascend(
-            prefix_lens,
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-            self.device,
-        )
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        self.free_pages = self.free_pages[estimated_num_new_pages:]
-        return out_indices
-
-    def alloc_decode(
-        self,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 2) % self.page_size == seq_lens % self.page_size
-            )
-
-        need_new_pages = (seq_lens % self.page_size == 1).int()
-        num_new_pages = need_new_pages.sum().item()
-
-        if num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        end_new_pages = torch.cumsum(need_new_pages, 0)
-        start_new_pages = end_new_pages - need_new_pages
-        if num_new_pages == 0:
-            out_indices = last_loc + 1
-        else:
-            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
-                start_new_pages
-            ] * self.page_size * need_new_pages
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices.int()
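Passing the precomputed `self.max_num_extend_tokens_next_power_of_2` instead of `next_power_of_2(extend_num_tokens)` makes that kernel argument a constant across calls; presumably this keeps the allocator kernel specialized once rather than re-JITed whenever the batch's token count crosses a power-of-two boundary. For reference, the helper's usual definition (a minimal re-implementation for illustration; sglang imports its own):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, e.g. 1 -> 1, 3 -> 4, 100 -> 128.
    return 1 << (n - 1).bit_length()

assert [next_power_of_2(n) for n in (1, 2, 3, 100)] == [1, 2, 4, 128]
```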
diff --git a/python/sglang/srt/mem_cache/allocator_ascend.py b/python/sglang/srt/mem_cache/allocator_ascend.py
new file mode 100644
index 000000000..94bbaafeb
--- /dev/null
+++ b/python/sglang/srt/mem_cache/allocator_ascend.py
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
+
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+
+def alloc_extend_kernel_ascend(
+    prefix_lens,
+    seq_lens,
+    last_loc,
+    free_pages,
+    out_indices,
+    page_size,
+    device,
+):
+    extend_lens = seq_lens - prefix_lens
+    end_pos = torch.cumsum(extend_lens, 0)
+    start_pos = end_pos - extend_lens
+    num_new_pages = (seq_lens + page_size - 1) // page_size - (
+        prefix_lens + page_size - 1
+    ) // page_size
+    num_full_new_pages = (seq_lens) // page_size - (
+        prefix_lens + page_size - 1
+    ) // page_size
+    need_page = num_new_pages - num_full_new_pages
+    end_new_pages = torch.cumsum(num_new_pages, 0)
+    start_new_pages = end_new_pages - num_new_pages
+    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
+    for i in range(len(prefix_lens)):
+        num1 = (
+            min(
+                seq_lens[i],
+                (prefix_lens[i] + page_size - 1) // page_size * page_size,
+            )
+            - prefix_lens[i]
+        )
+        if num1:
+            out_indices[start_pos[i] : start_pos[i] + num1] = (
+                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
+            )
+
+        num2 = (
+            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
+        ) * page_size
+        if num2:
+            pages = (
+                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
+                * page_size
+            )
+            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
+                pages.view(-1, 1) + pos_in_page.view(1, -1)
+            ).view(-1)
+
+        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
+        if num3:
+            out_indices[end_pos[i] - num3 : end_pos[i]] = (
+                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
+            ).view(-1)
+
+
+class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
+
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+        need_sort: bool,
+    ):
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)
+
+    def alloc_extend(
+        self,
+        prefix_lens: torch.Tensor,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+        extend_num_tokens: int,
+    ):
+        if self.debug_mode:
+            assert torch.all(
+                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
+            )
+
+        num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if self.need_sort and num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        out_indices = torch.empty(
+            (extend_num_tokens,), dtype=torch.int32, device=self.device
+        )
+
+        alloc_extend_kernel_ascend(
+            prefix_lens,
+            seq_lens,
+            last_loc,
+            self.free_pages,
+            out_indices,
+            self.page_size,
+            self.device,
+        )
+
+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
+        self.free_pages = self.free_pages[num_new_pages:]
+        return out_indices
+
+    def alloc_decode(
+        self,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+    ):
+        if self.debug_mode:
+            assert torch.all(
+                (last_loc + 2) % self.page_size == seq_lens % self.page_size
+            )
+
+        need_new_pages = (seq_lens % self.page_size == 1).int()
+        num_new_pages = need_new_pages.sum().item()
+
+        if num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        end_new_pages = torch.cumsum(need_new_pages, 0)
+        start_new_pages = end_new_pages - need_new_pages
+        if num_new_pages == 0:
+            out_indices = last_loc + 1
+        else:
+            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
+                start_new_pages
+            ] * self.page_size * need_new_pages
+
+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
+        self.free_pages = self.free_pages[num_new_pages:]
+        return out_indices.int()
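`alloc_extend_kernel_ascend` above splits each request's new tokens into three spans: `num1` fills the tail of the last partially used page, `num2` covers the whole new pages, and `num3` is the leftover tail on a final new page. A worked example with hypothetical lengths:

```python
page_size = 4
prefix_len, seq_len = 6, 15  # 9 new tokens for this request

num1 = min(seq_len, (prefix_len + page_size - 1) // page_size * page_size) - prefix_len
num2 = (seq_len // page_size - (prefix_len + page_size - 1) // page_size) * page_size
num3 = seq_len - seq_len // page_size * page_size

assert (num1, num2, num3) == (2, 4, 3)  # slots 6-7, 8-11, 12-14
assert num1 + num2 + num3 == seq_len - prefix_len
```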
diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index 1cec3d21b..88d923b46 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index fe56a208b..bb67c79f3 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -75,12 +75,12 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import (
-    AscendPagedTokenToKVPoolAllocator,
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,
@@ -176,10 +176,6 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
-
-        # Apply the rank zero filter to logger
-        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
-            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.moe_ep_rank = moe_ep_rank
@@ -205,15 +201,17 @@ class ModelRunner:
         self.is_hybrid = model_config.is_hybrid
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
-
         self.forward_pass_id = 0
 
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
+
+        if server_args.show_time_cost:
+            enable_show_time_cost()
+
         # Model-specific adjustment
         self.model_specific_adjustment()
 
-        if server_args.show_time_cost:
-            enable_show_time_cost()
-
         # Global vars
         global_server_args_dict.update(
             {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
@@ -221,8 +219,6 @@ class ModelRunner:
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
-            }
-            | {
                 "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
                 "deepep_mode": DeepEPMode(server_args.deepep_mode),
             }
@@ -242,13 +238,15 @@ class ModelRunner:
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
 
-        # If it is a draft model, tp_group can be different
+        # Initialize the model runner
         self.initialize(min_per_gpu_memory)
 
-        # temporary cached values
+        # Temporary cached values
         self.support_pp = (
             "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
         )
+
+        # For weight updates
         self._model_update_group = {}
 
     def initialize(self, min_per_gpu_memory: float):
@@ -277,6 +275,7 @@ class ModelRunner:
                 )
             )
 
+        # Expert parallelism
         self.eplb_manager = (
             EPLBManager(self)
             if self.server_args.enable_eplb and (not self.is_draft_worker)
@@ -1160,6 +1159,7 @@ class ModelRunner:
         max_num_reqs: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
+        # Determine the kv cache dtype
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
@@ -1178,6 +1178,8 @@ class ModelRunner:
             )
 
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
 
         if max_num_reqs is None:
             max_num_reqs = min(
@@ -1190,9 +1192,6 @@ class ModelRunner:
                 4096,
             )
 
-        if SGLANG_CI_SMALL_KV_SIZE:
-            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
-
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
@@ -1239,6 +1238,7 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
+        # Initialize req_to_token_pool
         if self.req_to_token_pool is None:
             if self.server_args.disaggregation_mode == "decode":
                 from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
@@ -1264,6 +1264,7 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
+        # Initialize token_to_kv_pool
         if self.server_args.attention_backend == "ascend":
             if self.use_mla_backend:
                 self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1349,28 +1350,44 @@ class ModelRunner:
                 end_layer=self.end_layer,
             )
 
+        # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
+        max_num_extend_tokens = (
+            self.server_args.chunked_prefill_size
+            if self.server_args.chunked_prefill_size > 0
+            else self.server_args.max_prefill_tokens
+        )
         if self.token_to_kv_pool_allocator is None:
-            if self.page_size == 1:
-                if self.is_hybrid:
-                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
-                        self.full_max_total_num_tokens,
-                        self.swa_max_total_num_tokens,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                        need_sort=need_sort,
-                    )
-                else:
-                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                        self.max_total_num_tokens,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                        need_sort=need_sort,
-                    )
+            if self.server_args.attention_backend == "ascend":
+                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                    need_sort=need_sort,
+                )
             else:
-                if not _is_npu:
+                if self.page_size == 1:
+                    if self.is_hybrid:
+                        self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
+                            self.full_max_total_num_tokens,
+                            self.swa_max_total_num_tokens,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device,
+                            kvcache=self.token_to_kv_pool,
+                            need_sort=need_sort,
+                        )
+                    else:
+                        self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                            self.max_total_num_tokens,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device,
+                            kvcache=self.token_to_kv_pool,
+                            need_sort=need_sort,
+                        )
+                else:
+                    assert not self.is_hybrid
                     self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
@@ -1378,15 +1395,7 @@ class ModelRunner:
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
                         need_sort=need_sort,
-                    )
-                else:
-                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
-                        self.max_total_num_tokens,
-                        page_size=self.page_size,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                        need_sort=need_sort,
+                        max_num_extend_tokens=max_num_extend_tokens,
                     )
         else:
             assert self.is_draft_worker
@@ -1554,15 +1563,13 @@ class ModelRunner:
             )
 
             return TRTLLMHAAttnBackend(self)
-
         elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )
-            logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
-        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+        elif backend_str == "dual_chunk_flash_attn":
             from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
                 DualChunkFlashAttentionBackend,
             )
@@ -1606,6 +1613,7 @@ class ModelRunner:
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index bcdadbe11..ec649f479 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -68,6 +68,8 @@ class SamplingBatchInfo:
 
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
         reqs = batch.reqs
         device = batch.device
         temperatures = (
@@ -97,10 +99,11 @@ class SamplingBatchInfo:
                 logit_bias[i, int(key)] = value
 
         # Check if any request has custom logit processor
-        has_custom_logit_processor = (
-            batch.enable_custom_logit_processor  # check the flag first.
-            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
-        )
+        has_custom_logit_processor = global_server_args_dict[
+            "enable_custom_logit_processor"
+        ] and any(  # check the flag first.
+            r.custom_logit_processor for r in reqs
+        )  # then check the requests.
 
         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index b7e053fd9..0f2879fde 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -575,6 +575,7 @@ class ServerArgs:
                 "Pipeline parallelism is incompatible with overlap schedule."
             )
 
+        # Hicache
         if self.hicache_storage_backend == "mooncake":
             # to use mooncake storage backend, the following conditions must be met:
             self.hicache_io_backend = "kernel"
@@ -1316,19 +1317,23 @@ class ServerArgs:
 
         # Kernel backend
         ATTN_BACKENDS = [
-            "aiter",
+            # Common
+            "triton",
+            "torch_native",
+            # NVIDIA specific
             "cutlass_mla",
             "fa3",
             "flashinfer",
             "flashmla",
-            "intel_amx",
-            "torch_native",
-            "ascend",
-            "triton",
             "trtllm_mla",
             "trtllm_mha",
             "dual_chunk_flash_attn",
+            # AMD specific
+            "aiter",
             "wave",
+            # Other platforms
+            "intel_amx",
+            "ascend",
         ]
         parser.add_argument(
             "--attention-backend",
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 093d1f739..90f9b843c 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -21,7 +21,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   /*
    * From csrc/allreduce
    */
-
   m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
   m.def("register_graph_buffers", &register_graph_buffers);
   m.def("dispose", &dispose);
@@ -46,6 +45,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
   m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
+
   /*
    * From csrc/attention
    */
@@ -284,6 +284,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "page_size) -> ()");
   m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
 
+  /*
+   * From csrc/memory
+   */
+  m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
+  m.impl("store_kv_cache", &store_kv_cache);
+
   /*
    * From csrc/moe/cutlass_moe/w4a8
    */
@@ -390,13 +396,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, &convert_vertical_slash_indexes_mergehead);
 
   /*
-   * From XGrammar
+   * From csrc/grammar
    */
   m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
   m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
 
   /*
-   * From QServe
+   * From csrc/gemm (QServe)
    */
   m.def(
       "qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
@@ -413,12 +419,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
    */
   m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
   m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
-
-  /*
-   * From csrc/memory
-   */
-  m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
-  m.impl("store_kv_cache", &store_kv_cache);
 }
 
 REGISTER_EXTENSION(common_ops)
diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/common_extension_rocm.cc
similarity index 100%
rename from sgl-kernel/csrc/torch_extension_rocm.cc
rename to sgl-kernel/csrc/common_extension_rocm.cc
diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py
index 47f59071f..a919d8f3b 100644
--- a/sgl-kernel/setup_rocm.py
+++ b/sgl-kernel/setup_rocm.py
@@ -47,7 +47,7 @@ sources = [
     "csrc/moe/moe_align_kernel.cu",
     "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/speculative/eagle_utils.cu",
-    "csrc/torch_extension_rocm.cc",
+    "csrc/common_extension_rocm.cc",
 ]
 
 cxx_flags = ["-O3"]
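The relocated `store_kv_cache` registration keeps its schema unchanged. Going only by that schema, the semantics it suggests are a scatter of new K/V entries into the cache at `out_loc`; the reference below is an assumption for illustration, not the actual CUDA kernel:

```python
import torch

def store_kv_cache_ref(k_cache, v_cache, out_loc, k, v):
    # Assumed semantics: write each new key/value row into its cache slot.
    k_cache[out_loc] = k
    v_cache[out_loc] = v

k_cache, v_cache = torch.zeros(8, 2), torch.zeros(8, 2)
store_kv_cache_ref(k_cache, v_cache, torch.tensor([3, 5]),
                   torch.ones(2, 2), torch.full((2, 2), 2.0))
assert k_cache[3].eq(1).all() and v_cache[5].eq(2).all()
```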