Cutlass MLA decode - fix dtype error (#5868)
This commit is contained in:
@@ -268,7 +268,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
|||||||
reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
|
||||||
o = cutlass_mla_decode(
|
o = cutlass_mla_decode(
|
||||||
q_nope_and_q_pe=reshape_q,
|
q_nope_and_q_pe=reshape_q.to(self.q_data_type),
|
||||||
kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
|
kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
|
||||||
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
||||||
page_table=self.forward_metadata.block_kv_indices,
|
page_table=self.forward_metadata.block_kv_indices,
|
||||||
|
|||||||
Reference in New Issue
Block a user