From 2dd7d0c533deecd9e4fea682f1d13fd8e7e9b8a2 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 4 Mar 2025 05:38:24 -0800
Subject: [PATCH] Revert "Fix nightly-test CI" (#4065)

---
 .../srt/layers/attention/flashinfer_backend.py | 15 ++-------------
 1 file changed, 2 insertions(+), 13 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 98fff770e..cf33ee257 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -427,10 +427,7 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             o2, s2 = prefill_wrapper_paged.forward_return_lse(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                self._to_dtype(
-                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    q.dtype,
-                ),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=False,
                 sm_scale=layer.scaling,
                 logits_soft_cap=layer.logit_cap,
@@ -472,9 +469,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            self._to_dtype(
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), q.dtype
-            ),
+            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
             k_scale=layer.k_scale,
@@ -483,12 +478,6 @@
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
-    def _to_dtype(self, kv_tuple, dtype):
-        if kv_tuple[0].dtype != dtype:
-            return tuple(t.to(dtype) for t in kv_tuple)
-        else:
-            return kv_tuple
-
     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
             return 0