[fix] enable flashmla when using draft model P/D attention select (#11012)
@@ -201,9 +201,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                self.num_q_heads,
+                num_q_heads,
                 1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
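The scaling above follows FlashMLA's documented calling convention: get_mla_metadata expects the number of query rows per KV head (the upstream README passes s_q * h_q // h_kv), and under multi-token prediction each request carries num_draft_tokens query tokens per decode step. A minimal sketch with illustrative values (the head count, token count, and sequence lengths below are assumptions, not taken from this repo):

    import torch
    from flash_mla import get_mla_metadata

    num_q_heads = 16       # h_q: query heads on this device (illustrative)
    num_draft_tokens = 2   # s_q: query tokens per request per step; None/0 -> 1
    seq_lens = torch.tensor([512, 1024], dtype=torch.int32, device="cuda")

    # The tile scheduler partitions work by query rows per KV head, so the
    # head count must be scaled by how many query tokens each request sends.
    mla_metadata, num_splits = get_mla_metadata(
        seq_lens,
        num_q_heads * (num_draft_tokens or 1),
        1,  # h_kv: MLA attends through a single absorbed KV head
    )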
@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                self.num_q_heads,
+                num_q_heads,
                 1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
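This second hunk applies the identical scaling in what appears to be the CUDA-graph replay counterpart of the capture path above, so the metadata recomputed at replay time stays consistent with what the graph was captured with.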
@@ -244,6 +244,7 @@ class EAGLEWorker(TpModelWorker):
                 if not is_blackwell()
                 else self._create_triton_prefill_backend
             ),
+            "flashmla": self._create_flashmla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
         }
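Registering the "flashmla" key is what lets --attention-backend flashmla resolve at all when a draft model is configured; previously the lookup had no entry for it. A sketch of the dispatch pattern this table implements (function and variable names here are hypothetical, not from this repo):

    from typing import Callable, Dict, Optional

    def resolve_prefill_backend(
        name: str, factories: Dict[str, Callable[[], Optional[object]]]
    ) -> Optional[object]:
        # Unknown names fail loudly; known names may still yield None when
        # the backend has no draft-extend prefill implementation yet.
        if name not in factories:
            raise ValueError(f"unsupported attention backend: {name}")
        return factories[name]()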
@@ -383,6 +384,12 @@ class EAGLEWorker(TpModelWorker):
 
         return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
 
+    def _create_flashmla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
+
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
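Note that the new factory warns and returns None instead of raising, so selecting flashmla with a draft model no longer fails at backend construction; a None result presumably signals the caller to run draft extend without a dedicated prefill backend rather than abort.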
@@ -103,11 +103,11 @@ class TestFlashMLAMTP(CustomTestCase):
             "--speculative-draft-model-path",
             "lmsys/sglang-ci-dsv3-test-NextN",
             "--speculative-num-steps",
-            "1",
+            "2",
             "--speculative-eagle-topk",
             "1",
             "--speculative-num-draft-tokens",
-            "2",
+            "3",
             "--attention-backend",
             "flashmla",
         ]
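The test now drives two draft steps instead of one. For EAGLE-style speculation these flags are typically tied together as

    num_draft_tokens = num_steps * topk + 1 = 2 * 1 + 1 = 3

(the drafted tokens plus the one token the verify pass always emits), which is why both values change in lockstep.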
@@ -146,7 +146,7 @@ class TestFlashMLAMTP(CustomTestCase):
             "avg_spec_accept_length"
         ]
         print(f"{avg_spec_accept_length=}")
-        self.assertGreater(avg_spec_accept_length, 1.8)
+        self.assertGreater(avg_spec_accept_length, 2.4)
 
 
 if __name__ == "__main__":
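The tighter bound tracks the new configuration: avg_spec_accept_length counts tokens emitted per verify step and is capped by --speculative-num-draft-tokens, so the old setup (2 draft tokens) could never exceed 2.0 and 1.8 sat near that ceiling. With a cap of 3.0, requiring more than 2.4 means on average at least 1.4 of the 2 drafted tokens must be accepted per step.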