diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 7063b6f4b..a3b890219 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -422,7 +422,10 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             o2, s2 = prefill_wrapper_paged.forward_return_lse(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                self._to_dtype(
+                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    q.dtype,
+                ),
                 causal=False,
                 sm_scale=layer.scaling,
                 logits_soft_cap=layer.logit_cap,
@@ -464,7 +467,9 @@ class FlashInferAttnBackend(AttentionBackend):
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+            self._to_dtype(
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), q.dtype
+            ),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
             k_scale=layer.k_scale,
@@ -473,6 +478,12 @@ class FlashInferAttnBackend(AttentionBackend):
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
+    def _to_dtype(self, kv_tuple, dtype):
+        if kv_tuple[0].dtype != dtype:
+            return tuple(t.to(dtype) for t in kv_tuple)
+        else:
+            return kv_tuple
+
     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
             return 0
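
For context, the new `_to_dtype` helper upcasts the KV buffers to the query dtype before they reach the FlashInfer wrapper, so a KV cache stored in a lower-precision dtype (e.g. fp8) can be consumed by a half-precision attention kernel. Below is a minimal standalone sketch of that cast-on-mismatch behavior, assuming `get_kv_buffer` returns a `(k_buffer, v_buffer)` tuple of tensors; the shapes and the fp8 dtype here are illustrative, not taken from the patch:

```python
import torch


def to_dtype(kv_tuple, dtype):
    # Cast every buffer in the tuple only when the stored dtype differs
    # from the requested one; otherwise return the tuple unchanged so
    # the common same-dtype path performs no allocation.
    if kv_tuple[0].dtype != dtype:
        return tuple(t.to(dtype) for t in kv_tuple)
    return kv_tuple


# Hypothetical buffer shape: (num_tokens, num_kv_heads, head_dim).
k = torch.zeros(8, 2, 64, dtype=torch.float8_e4m3fn)
v = torch.zeros(8, 2, 64, dtype=torch.float8_e4m3fn)

# An fp8 KV cache paired with a half-precision q gets upcast...
k16, v16 = to_dtype((k, v), torch.float16)
assert k16.dtype == v16.dtype == torch.float16

# ...while a matching dtype passes the original buffers through untouched.
k8, v8 = to_dtype((k, v), torch.float8_e4m3fn)
assert k8 is k and v8 is v
```

Note that when the dtypes do differ, `.to(dtype)` materializes a full copy of both buffers on every forward call; the early-return branch keeps the matching-dtype path free of that cost.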