Fix cutlass MLA gets almost zero accuracy (#6998)
This commit is contained in:
@@ -157,7 +157,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
if spec_info is None:
|
||||
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
|
||||
max_seqlen_pad = self.cuda_graph_kv_indices.shape[1]
|
||||
|
||||
create_flashmla_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
@@ -169,12 +169,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
self.cuda_graph_kv_indices.stride(0),
|
||||
PAGE_SIZE,
|
||||
)
|
||||
workspace_size = cutlass_mla_get_workspace_size(
|
||||
max_seqlen_pad * PAGE_SIZE, bs
|
||||
)
|
||||
self.cuda_graph_mla_workspace = torch.empty(
|
||||
workspace_size, device="cuda", dtype=torch.uint8
|
||||
)
|
||||
self.forward_metadata = CutlassMLADecodeMetadata(
|
||||
self.cuda_graph_mla_workspace,
|
||||
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
|
||||
@@ -205,8 +199,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
assert seq_lens_cpu is not None
|
||||
seq_lens = seq_lens[:bs]
|
||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
|
||||
|
||||
create_flashmla_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices[:bs],
|
||||
@@ -217,16 +210,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
self.cuda_graph_kv_indices.stride(0),
|
||||
PAGE_SIZE,
|
||||
)
|
||||
workspace_size = cutlass_mla_get_workspace_size(
|
||||
max_seqlen_pad * PAGE_SIZE, bs
|
||||
)
|
||||
self.cuda_graph_mla_workspace = torch.empty(
|
||||
workspace_size, device="cuda", dtype=torch.uint8
|
||||
)
|
||||
self.forward_metadata.workspace = self.cuda_graph_mla_workspace
|
||||
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
|
||||
:bs, :max_seqlen_pad
|
||||
]
|
||||
else:
|
||||
super().init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
|
||||
Reference in New Issue
Block a user