diff --git a/python/sglang/srt/layers/attention/cutlass_mla_backend.py b/python/sglang/srt/layers/attention/cutlass_mla_backend.py index 9ff5dfabf..8b0923ce3 100644 --- a/python/sglang/srt/layers/attention/cutlass_mla_backend.py +++ b/python/sglang/srt/layers/attention/cutlass_mla_backend.py @@ -268,7 +268,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim) o = cutlass_mla_decode( - q_nope_and_q_pe=reshape_q, + q_nope_and_q_pe=reshape_q.to(self.q_data_type), kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim), seq_lens=forward_batch.seq_lens.to(torch.int32), page_table=self.forward_metadata.block_kv_indices,