Cutlass MLA decode - fix dtype error (#5868)

This commit is contained in:
Trevor Morris
2025-04-28 21:12:58 -07:00
committed by GitHub
parent 26fc32d168
commit 8d463fe351

View File

@@ -268,7 +268,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
o = cutlass_mla_decode(
q_nope_and_q_pe=reshape_q,
q_nope_and_q_pe=reshape_q.to(self.q_data_type),
kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
seq_lens=forward_batch.seq_lens.to(torch.int32),
page_table=self.forward_metadata.block_kv_indices,