[fix] enable FlashMLA when using draft-model P/D attention selection (#11012)

This commit is contained in:
Hank Han
2025-10-04 20:59:34 +08:00
committed by GitHub
parent d01b921482
commit 666da3d59f
3 changed files with 14 additions and 5 deletions

View File

@@ -201,9 +201,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_q_heads,
num_q_heads,
1,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_q_heads,
num_q_heads,
1,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)