diff --git a/.gitmodules b/.gitmodules
index 649301455..e69de29bb 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +0,0 @@
-[submodule "3rdparty/flashinfer"]
-	path = 3rdparty/flashinfer
-	url = https://github.com/flashinfer-ai/flashinfer.git
diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
deleted file mode 160000
index 88b9496e1..000000000
--- a/3rdparty/flashinfer
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 88b9496e1a726ddb353eb42887cfc0ab32c99460
diff --git a/docs/flashinfer.md b/docs/flashinfer.md
index c929ba088..7ea1e1efc 100644
--- a/docs/flashinfer.md
+++ b/docs/flashinfer.md
@@ -5,13 +5,15 @@ It can be used in SGLang runtime to accelerate attention computation.
 
 ### Install flashinfer
 
-Note: The compilation can take a very long time.
+You can install flashinfer via pip as follows for CUDA 12.1.
 
 ```bash
-git submodule update --init --recursive
-pip install 3rdparty/flashinfer/python
+pip install flashinfer -i https://flashinfer.ai/whl/cu121/
 ```
 
+You can look for other CUDA versions at https://github.com/flashinfer-ai/flashinfer?tab=readme-ov-file#installation. If there is no desired version for your environment,
+please build it from source (the compilation takes a long time).
+
 ### Run a Server With Flashinfer Mode
 
 Add `--model-mode flashinfer` argument to enable flashinfer when launching a server.
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 9857ff4d3..9b9525ac1 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -98,12 +98,7 @@ class RadixAttention(nn.Module):
 
         o = input_metadata.prefill_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.qo_indptr,
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-            input_metadata.kv_indptr,
-            input_metadata.kv_indices,
-            input_metadata.kv_last_page_len,
-            allow_fp16_qk_reduction=True,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -114,9 +109,6 @@ class RadixAttention(nn.Module):
         o = input_metadata.decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-            input_metadata.kv_indptr,
-            input_metadata.kv_indices,
-            input_metadata.kv_last_page_len,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index a68e99b1c..8fffbd958 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -90,6 +90,11 @@ class InputMetadata:
     decode_wrapper = None
 
     def init_flashinfer_args(self, tp_size):
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+        )
+
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -107,11 +112,7 @@ class InputMetadata:
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
 
-        from flashinfer.ops import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
-
+        workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
@@ -120,19 +121,21 @@
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
+            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
             self.prefill_wrapper.begin_forward(
                 self.qo_indptr,
-                self.batch_size,
+                self.kv_indptr,
+                self.kv_indices,
+                self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
             )
         else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
+            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
             self.decode_wrapper.begin_forward(
                 self.kv_indptr,
+                self.kv_indices,
                 self.kv_last_page_len,
-                self.batch_size,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
                 self.model_runner.model_config.head_dim,
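For context, these hunks move from flashinfer's old stateless calls (page-table tensors passed to every `forward`) to the stateful wrapper API: each wrapper is constructed once with a shared workspace buffer, `begin_forward` caches the page-table metadata for the current batch, and `forward` then only takes `q` and the paged KV data. Below is a minimal decode-side sketch of that lifecycle. It assumes the flashinfer 0.0.2-era API; all sizes and tensor values are illustrative (with `page_size = 1`, matching SGLang's token-granularity KV pool), not taken from the diff.

```python
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

# Illustrative sizes only.
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 32, 128, 1
batch_size, num_pages = 2, 8

# One reusable workspace buffer, shared by the wrapper across batches.
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# CSR-style page table: request i owns pages
# kv_indices[kv_indptr[i] : kv_indptr[i + 1]]. With page_size == 1 every
# page is full, so kv_last_page_len is all ones.
kv_indptr = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(num_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")

# begin_forward caches the batch's page-table metadata once...
wrapper.begin_forward(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
)

# ...so forward only needs q (one token per request during decode) and the
# packed KV cache in NHD layout: [num_pages, 2, page_size, num_kv_heads, head_dim].
q = torch.randn(
    batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda"
)
kv_data = torch.randn(
    num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda",
)
o = wrapper.forward(q, kv_data)  # [batch_size, num_qo_heads, head_dim]
wrapper.end_forward()  # release the cached metadata for this batch
```

The prefill/extend path in the diff follows the same pattern, with an additional `qo_indptr` passed to `begin_forward` because each request contributes a variable number of query tokens.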