[feat] Support different attention backends for prefill and decode (#6338)

Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
Qiaolin Yu
2025-07-27 20:42:29 -07:00
committed by GitHub
parent fe6a445d1e
commit 2810338401
9 changed files with 350 additions and 29 deletions

View File

@@ -151,6 +151,8 @@ class ServerArgs:
# Kernel backend
attention_backend: Optional[str] = None
decode_attention_backend: Optional[str] = None
prefill_attention_backend: Optional[str] = None
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = None
mm_attention_backend: Optional[str] = None
@@ -387,13 +389,19 @@ class ServerArgs:
)
self.page_size = 128
if self.attention_backend == "flashmla":
if (
self.attention_backend == "flashmla"
or self.decode_attention_backend == "flashmla"
):
logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64."
)
self.page_size = 64
if self.attention_backend == "cutlass_mla":
if (
self.attention_backend == "cutlass_mla"
or self.decode_attention_backend == "cutlass_mla"
):
logger.warning(
"Cutlass MLA only supports a page_size of 128, change page_size to 128."
)
@@ -1213,6 +1221,35 @@ class ServerArgs:
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
parser.add_argument(
"--decode-attention-backend",
type=str,
choices=[
"flashinfer",
"triton",
"torch_native",
"fa3",
"flashmla",
"cutlass_mla",
],
default=ServerArgs.decode_attention_backend,
help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
)
parser.add_argument(
"--prefill-attention-backend",
type=str,
choices=[
"flashinfer",
"triton",
"torch_native",
"fa3",
"flashmla",
"cutlass_mla",
],
default=ServerArgs.prefill_attention_backend,
help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
)
parser.add_argument(
"--sampling-backend",
type=str,