[fix] enable flashmla when using draft model P/D attention select (#11012)

This commit is contained in:
Hank Han
2025-10-04 20:59:34 +08:00
committed by GitHub
parent d01b921482
commit 666da3d59f
3 changed files with 14 additions and 5 deletions

View File

@@ -201,9 +201,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_q_heads,
num_q_heads,
1,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_q_heads,
num_q_heads,
1,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)

View File

@@ -244,6 +244,7 @@ class EAGLEWorker(TpModelWorker):
if not is_blackwell()
else self._create_triton_prefill_backend
),
"flashmla": self._create_flashmla_prefill_backend,
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
}
@@ -383,6 +384,12 @@ class EAGLEWorker(TpModelWorker):
return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
def _create_flashmla_prefill_backend(self):
    """Placeholder factory for a FlashMLA draft-extend prefill backend.

    FlashMLA has no draft-extend prefill implementation yet, so this
    emits a warning and yields no backend (``None``) — the caller is
    expected to handle the missing backend.
    """
    logger.warning("flashmla prefill backend is not yet supported for draft extend.")
    return None
def init_cuda_graphs(self):
"""Capture cuda graphs."""
self.cuda_graph_runner = None