From d9ac639202fdc97f42fe41ff75a604089a7cac37 Mon Sep 17 00:00:00 2001 From: Yueyang Pan Date: Tue, 2 Jul 2024 07:08:39 +0200 Subject: [PATCH] Fix flashinfer version (#576) --- python/sglang/srt/layers/radix_attention.py | 12 ++++-------- python/sglang/srt/server.py | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 983eff0e3..824fee8bc 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -30,12 +30,8 @@ class RadixAttention(nn.Module): self.prefill_forward = self.prefill_forward_flashinfer self.extend_forward = self.prefill_forward_flashinfer self.decode_forward = self.decode_forward_flashinfer - # flashinfer only accepts a boolean logit_cap argument - if logit_cap > 0: - assert logit_cap == 30 - self.logit_cap = True - else: - self.logit_cap = False + # flashinfer now accepts float logit_cap argument + self.logit_cap = logit_cap if logit_cap > 0 else 0 else: self.prefill_forward = self.prefill_forward_triton self.extend_forward = self.extend_forward_triton @@ -110,7 +106,7 @@ class RadixAttention(nn.Module): o = input_metadata.flashinfer_prefill_wrapper.forward( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), input_metadata.token_to_kv_pool.kv_data[self.layer_id], - logits_cap=self.logit_cap, + logits_soft_cap=self.logit_cap, ) return o.view(-1, self.tp_q_head_num * self.head_dim) @@ -121,7 +117,7 @@ class RadixAttention(nn.Module): o = input_metadata.flashinfer_decode_wrapper.forward( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), input_metadata.token_to_kv_pool.kv_data[self.layer_id], - logits_cap=self.logit_cap, + logits_soft_cap=self.logit_cap, ) return o.view(-1, self.tp_q_head_num * self.head_dim) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index fb19a0348..3f410887c 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -152,7 +152,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg if server_args.disable_disk_cache: disable_cache() if server_args.enable_flashinfer: - assert_pkg_version("flashinfer", "0.0.5") + assert_pkg_version("flashinfer", "0.0.7") if server_args.chat_template: # TODO: replace this with huggingface transformers template load_chat_template_for_openai_api(server_args.chat_template)