From 38af4f68a90cd0df6f77e81fae24829cced36f88 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 14 Jun 2025 22:49:41 -0700
Subject: [PATCH] Fix grammar abort & Minor style fixes (#7204)

---
 .../srt/layers/attention/flashinfer_mla_backend.py    | 11 +++++------
 .../sglang/srt/layers/attention/flashmla_backend.py   |  4 +---
 python/sglang/srt/layers/attention/triton_backend.py  |  9 +++++----
 .../layers/attention/triton_ops/decode_attention.py   |  4 ++--
 python/sglang/srt/layers/radix_attention.py           |  5 ++---
 python/sglang/srt/managers/scheduler.py               |  3 ++-
 python/sglang/srt/mem_cache/memory_pool.py            |  3 ---
 .../eagle_draft_extend_cuda_graph_runner.py           |  4 ++--
 8 files changed, 19 insertions(+), 24 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 275518b6c..c4192a715 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -15,7 +15,6 @@ from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
-import triton
 
 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
     import logging
@@ -33,7 +32,7 @@ from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -756,7 +755,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
         if topk > 1:
             raise ValueError(
-                f"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
+                "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
             )
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -815,9 +814,9 @@
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-            triton.next_power_of_2(num_seqs),
-            triton.next_power_of_2(self.speculative_num_steps),
-            triton.next_power_of_2(bs),
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
         )
 
         assert forward_batch.spec_info is not None
diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py
index 1198ddda4..b74e03c6e 100644
--- a/python/sglang/srt/layers/attention/flashmla_backend.py
+++ b/python/sglang/srt/layers/attention/flashmla_backend.py
@@ -464,11 +464,9 @@ class FlashMLAMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
-
         if topk > 1:
             raise ValueError(
-                f"Currently FlashMLA only supports topk=1 for speculative decoding"
+                "Currently FlashMLA only supports topk=1 for speculative decoding"
             )
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 38c239177..0a3aef9c3 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -12,7 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import get_bool_env_var, get_device_core_count
+from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -766,6 +766,7 @@ class TritonMultiStepDraftBackend:
         self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
 
     def common_template(
         self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
@@ -788,9 +789,9 @@
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-            triton.next_power_of_2(num_seqs),
-            triton.next_power_of_2(self.speculative_num_steps),
-            triton.next_power_of_2(bs),
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
         )
 
         for i in range(self.speculative_num_steps):
diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 1f9212834..b334d851f 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -708,7 +708,7 @@ def decode_attention_fwd(
             num_kv_splits,
             max_kv_splits,
             sm_scale,
-            logit_cap,
+            logit_cap=logit_cap,
         )
     else:
         # GQA/MQA/MLA
@@ -724,5 +724,5 @@
             num_kv_splits,
             max_kv_splits,
             sm_scale,
-            logit_cap,
+            logit_cap=logit_cap,
         )
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index c57ae2aae..322704ca9 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -18,7 +18,6 @@ from typing import Optional
 
 from torch import nn
 
-from sglang.srt.layers.linear import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -52,9 +51,9 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
-        attn_type=AttentionType.DECODER,
-        prefix: str = "",
+        attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index e327f7bae..c8852f0be 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -2108,7 +2108,8 @@ class Scheduler(
             # In this case, we change the input_ids to be only one token to make this prefill cheap.
             if req.rid.startswith(recv_req.rid):
                 logger.debug(f"Abort grammar queue request. {req.rid=}")
-                req.grammar.cancel()
+                if req.grammar:
+                    req.grammar.cancel()
                 req.set_finish_with_abort("Aborted by AbortReq.")
 
         # Delete requests in the running batch
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index b6dd8dcdd..bac310da3 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -141,15 +141,12 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def get_flat_data(self, indices):
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def transfer(self, indices, flat_data):
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def transfer_per_layer(self, indices, flat_data, layer_id):
         raise NotImplementedError()
 
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index ee1d2098d..b8fb11974 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -86,8 +86,8 @@ class EAGLEDraftExtendCudaGraphRunner:
 
         self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
         self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
-        self.accept_length = (
-            torch.ones((self.max_bs,), dtype=torch.int32) * self.num_tokens_per_bs
+        self.accept_length = torch.full(
+            (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
         )
 
         # Capture