Fix Cutlass MLA backend producing almost zero accuracy (#6998)
This commit is contained in:
@@ -157,7 +157,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
|||||||
):
|
):
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
|
max_seqlen_pad = self.cuda_graph_kv_indices.shape[1]
|
||||||
|
|
||||||
create_flashmla_kv_indices_triton[(bs,)](
|
create_flashmla_kv_indices_triton[(bs,)](
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
@@ -169,12 +169,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.cuda_graph_kv_indices.stride(0),
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
PAGE_SIZE,
|
PAGE_SIZE,
|
||||||
)
|
)
|
||||||
workspace_size = cutlass_mla_get_workspace_size(
|
|
||||||
max_seqlen_pad * PAGE_SIZE, bs
|
|
||||||
)
|
|
||||||
self.cuda_graph_mla_workspace = torch.empty(
|
|
||||||
workspace_size, device="cuda", dtype=torch.uint8
|
|
||||||
)
|
|
||||||
self.forward_metadata = CutlassMLADecodeMetadata(
|
self.forward_metadata = CutlassMLADecodeMetadata(
|
||||||
self.cuda_graph_mla_workspace,
|
self.cuda_graph_mla_workspace,
|
||||||
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
|
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
|
||||||
@@ -205,8 +199,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
|||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
assert seq_lens_cpu is not None
|
assert seq_lens_cpu is not None
|
||||||
seq_lens = seq_lens[:bs]
|
seq_lens = seq_lens[:bs]
|
||||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
|
||||||
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
|
|
||||||
create_flashmla_kv_indices_triton[(bs,)](
|
create_flashmla_kv_indices_triton[(bs,)](
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
req_pool_indices[:bs],
|
req_pool_indices[:bs],
|
||||||
@@ -217,16 +210,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.cuda_graph_kv_indices.stride(0),
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
PAGE_SIZE,
|
PAGE_SIZE,
|
||||||
)
|
)
|
||||||
workspace_size = cutlass_mla_get_workspace_size(
|
|
||||||
max_seqlen_pad * PAGE_SIZE, bs
|
|
||||||
)
|
|
||||||
self.cuda_graph_mla_workspace = torch.empty(
|
|
||||||
workspace_size, device="cuda", dtype=torch.uint8
|
|
||||||
)
|
|
||||||
self.forward_metadata.workspace = self.cuda_graph_mla_workspace
|
|
||||||
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
|
|
||||||
:bs, :max_seqlen_pad
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
super().init_forward_metadata_replay_cuda_graph(
|
super().init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
|
|||||||
Reference in New Issue
Block a user