[fix] enable flashmla when using draft model P/D attention select (#11012)
@@ -201,9 +201,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             self.req_to_token.stride(0),
             self.cuda_graph_kv_indices.stride(0),
         )
+        num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            self.num_q_heads,
+            num_q_heads,
             1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             self.req_to_token.stride(0),
             self.cuda_graph_kv_indices.stride(0),
         )
+        num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            self.num_q_heads,
+            num_q_heads,
             1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
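Both hunks make the same fix at the two call sites that refresh self.cuda_graph_mla_metadata: when a draft model is attached for speculative decoding, each decode step carries num_draft_tokens query tokens per request, so the per-KV-head query count passed to get_mla_metadata must be scaled by that factor; passing self.num_q_heads alone sized the FlashMLA tile-scheduler metadata for a single query token per request. The `(self.num_draft_tokens or 1)` fallback leaves plain decoding (no draft model, num_draft_tokens is None) unchanged. A minimal sketch of the scaled call, assuming the flash_mla package is installed and using illustrative head/draft-token counts not taken from this PR:

import torch
from flash_mla import get_mla_metadata  # FlashMLA's scheduler-metadata helper

num_q_heads = 16        # query heads per latent KV head (illustrative)
num_draft_tokens = 4    # draft length; None when no draft model is attached
seq_lens = torch.tensor([512, 1024], dtype=torch.int32, device="cuda")

# Scale the per-KV-head query count by the draft length; `or 1` keeps the
# non-speculative path (num_draft_tokens is None) on the original value.
num_q_heads_scaled = num_q_heads * (num_draft_tokens or 1)

# MLA attends over a single latent KV head, hence the trailing 1.
mla_metadata, num_splits = get_mla_metadata(seq_lens, num_q_heads_scaled, 1)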