Support FlashMLA backend (#4472)

Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
lukec
2025-03-17 00:07:06 +08:00
committed by GitHub
parent 1b859295f4
commit a53fe428f9
6 changed files with 209 additions and 1 deletion

View File

@@ -173,6 +173,7 @@ class ServerArgs:
tool_call_parser: str = None
enable_hierarchical_cache: bool = False
enable_flashinfer_mla: bool = False
enable_flashmla: bool = False
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
@@ -227,6 +228,8 @@ class ServerArgs:
assert self.chunked_prefill_size % self.page_size == 0
if self.enable_flashmla is True:
assert self.page_size == 64, "FlashMLA only support page_size=64"
# Set cuda graph max batch size
if self.cuda_graph_max_bs is None:
# Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -753,6 +756,11 @@ class ServerArgs:
action="store_true",
help="Enable FlashInfer MLA optimization",
)
parser.add_argument(
"--enable-flashmla",
action="store_true",
help="Enable FlashMLA decode optimization",
)
parser.add_argument(
"--flashinfer-mla-disable-ragged",
action="store_true",