Cutlass MLA decode - fix dtype error (#5868)
This commit is contained in:
@@ -268,7 +268,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
|
||||
o = cutlass_mla_decode(
|
||||
- q_nope_and_q_pe=reshape_q,
|
||||
+ q_nope_and_q_pe=reshape_q.to(self.q_data_type),
|
||||
kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
|
||||
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
||||
page_table=self.forward_metadata.block_kv_indices,
|
||||
|
||||
Reference in New Issue
Block a user