From 666da3d59fa6dd78dd4619a9989d7445396d3a90 Mon Sep 17 00:00:00 2001
From: Hank Han <54751605+HanHan009527@users.noreply.github.com>
Date: Sat, 4 Oct 2025 20:59:34 +0800
Subject: [PATCH] [fix]enable flashmla when using draft model P/D attention
 select (#11012)

---
 python/sglang/srt/layers/attention/flashmla_backend.py | 6 ++++--
 python/sglang/srt/speculative/eagle_worker.py          | 7 +++++++
 test/srt/test_flashmla.py                              | 6 +++---
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py
index 134380f12..d85222806 100644
--- a/python/sglang/srt/layers/attention/flashmla_backend.py
+++ b/python/sglang/srt/layers/attention/flashmla_backend.py
@@ -201,9 +201,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                self.num_q_heads,
+                num_q_heads,
                 1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                self.num_q_heads,
+                num_q_heads,
                 1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 82bfaa276..9df1ef973 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -244,6 +244,7 @@ class EAGLEWorker(TpModelWorker):
                 if not is_blackwell()
                 else self._create_triton_prefill_backend
             ),
+            "flashmla": self._create_flashmla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
         }
@@ -383,6 +384,12 @@ class EAGLEWorker(TpModelWorker):
 
         return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
 
+    def _create_flashmla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
+
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
diff --git a/test/srt/test_flashmla.py b/test/srt/test_flashmla.py
index 681c9b8eb..cfefd9a4a 100644
--- a/test/srt/test_flashmla.py
+++ b/test/srt/test_flashmla.py
@@ -103,11 +103,11 @@ class TestFlashMLAMTP(CustomTestCase):
             "--speculative-draft-model-path",
             "lmsys/sglang-ci-dsv3-test-NextN",
             "--speculative-num-steps",
-            "1",
+            "2",
             "--speculative-eagle-topk",
             "1",
             "--speculative-num-draft-tokens",
-            "2",
+            "3",
             "--attention-backend",
             "flashmla",
         ]
@@ -146,7 +146,7 @@ class TestFlashMLAMTP(CustomTestCase):
             "avg_spec_accept_length"
         ]
         print(f"{avg_spec_accept_length=}")
-        self.assertGreater(avg_spec_accept_length, 1.8)
+        self.assertGreater(avg_spec_accept_length, 2.4)
 
 
 if __name__ == "__main__":
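
Note (editor's sketch, not part of the patch): the first two hunks scale the query-head count handed to get_mla_metadata by the number of speculative draft tokens, so the FlashMLA scheduling metadata captured and replayed under CUDA graphs covers every verified position per request, while `or 1` leaves plain decode unchanged. The helper name and the example numbers below are illustrative only; the expression mirrors the patched code.

def scaled_q_heads(num_q_heads, num_draft_tokens=None):
    # Hypothetical helper for illustration. Each request verifies
    # `num_draft_tokens` query positions per decode step, and those positions
    # are folded into the head dimension when the MLA scheduling metadata is
    # computed, hence the multiplication; `or 1` covers non-speculative decode.
    return num_q_heads * (num_draft_tokens or 1)

assert scaled_q_heads(16) == 16      # plain decode, no speculative tokens
assert scaled_q_heads(16, 3) == 48   # e.g. MTP/EAGLE with 3 draft tokens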