From 21615cc3fe7cb61a9030e302548e8b2835fcebd5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 16 Jun 2025 01:03:13 -0700 Subject: [PATCH] Minor style and doc fix (#7228) --- docs/backend/attention_backend.md | 15 +++++++++------ .../srt/layers/attention/cutlass_mla_backend.py | 3 --- .../layers/attention/flashattention_backend.py | 1 - .../srt/layers/attention/flashmla_backend.py | 5 +---- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/docs/backend/attention_backend.md b/docs/backend/attention_backend.md index 48e08d847..ad5ddfde9 100644 --- a/docs/backend/attention_backend.md +++ b/docs/backend/attention_backend.md @@ -3,13 +3,16 @@ ## Supporting matrix for different attention backends | **Backend** | **Page Size > 1** | **Spec Decoding** | **MLA** | **Sliding Window** | **MultiModal** | -|--------------------------|-------------------|-------------------|--------|--------------------|------------| -| **FlashInfer** | ✅ | ✅ | ✅ | ✅ | ✅ | -| **FA3** | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Triton** | ❌ | ✅ | ✅ | ❌ | ❌ | -| **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ | -| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ | +|--------------------------|-------------------|-------------------|---------|--------------------|----------------| +| **FlashInfer** | ❌ | ✅ | ✅ | ✅ | ✅ | +| **FA3** | ✅ | ✅ | ✅ | ✅ | ✅ | +| **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ | +| **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ | +| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ | +Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`. +This is because a page size of 16 can be converted to a page size of 1 in the kernel backend. +The "❌" and "✅" symbols in the table above under "Page Size > 1" indicate whether the kernel actually operates with a page size greater than 1, rather than treating a page size of 16 as a page size of 1. ## User guide diff --git a/python/sglang/srt/layers/attention/cutlass_mla_backend.py b/python/sglang/srt/layers/attention/cutlass_mla_backend.py index 8b3d18602..65fff548e 100644 --- a/python/sglang/srt/layers/attention/cutlass_mla_backend.py +++ b/python/sglang/srt/layers/attention/cutlass_mla_backend.py @@ -11,8 +11,6 @@ from typing import TYPE_CHECKING, Optional, Union import torch import triton -from sglang.global_config import global_config -from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size @@ -22,7 +20,6 @@ from sglang.srt.utils import is_cuda if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpecInfo _is_cuda = is_cuda() diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 8fec69f12..97eead3af 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -11,7 +11,6 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput -from sglang.srt.utils import get_compiler_backend if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py index b74e03c6e..cad4c1950 100644 --- a/python/sglang/srt/layers/attention/flashmla_backend.py +++ b/python/sglang/srt/layers/attention/flashmla_backend.py @@ -14,8 +14,6 @@ import torch import triton from flash_mla import flash_mla_with_kvcache, get_mla_metadata -from sglang.global_config import global_config -from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size @@ -24,7 +22,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpecInfo @@ -330,7 +327,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend): ) def get_cuda_graph_seq_len_fill_value(self): - return 1024 + return 1 def forward_decode( self,