[Submodule] Change FlashInfer to import (#156)
This commit is contained in:
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +0,0 @@
|
|||||||
[submodule "3rdparty/flashinfer"]
|
|
||||||
path = 3rdparty/flashinfer
|
|
||||||
url = https://github.com/flashinfer-ai/flashinfer.git
|
|
||||||
|
|||||||
1
3rdparty/flashinfer
vendored
1
3rdparty/flashinfer
vendored
Submodule 3rdparty/flashinfer deleted from 88b9496e1a
@@ -5,13 +5,15 @@ It can be used in SGLang runtime to accelerate attention computation.
|
|||||||
|
|
||||||
### Install flashinfer
|
### Install flashinfer
|
||||||
|
|
||||||
Note: The compilation can take a very long time.
|
You can install flashinfer via pip as follows for CUDA 12.1.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git submodule update --init --recursive
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/
|
||||||
pip install 3rdparty/flashinfer/python
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can look for other CUDA versions in https://github.com/flashinfer-ai/flashinfer?tab=readme-ov-file#installation. If there is no desire version for your environment,
|
||||||
|
please build it from source (the compilation takes a long time).
|
||||||
|
|
||||||
### Run a Server With Flashinfer Mode
|
### Run a Server With Flashinfer Mode
|
||||||
|
|
||||||
Add `--model-mode flashinfer` argument to enable flashinfer when launching a server.
|
Add `--model-mode flashinfer` argument to enable flashinfer when launching a server.
|
||||||
|
|||||||
@@ -98,12 +98,7 @@ class RadixAttention(nn.Module):
|
|||||||
|
|
||||||
o = input_metadata.prefill_wrapper.forward(
|
o = input_metadata.prefill_wrapper.forward(
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||||
input_metadata.qo_indptr,
|
|
||||||
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
|
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
|
||||||
input_metadata.kv_indptr,
|
|
||||||
input_metadata.kv_indices,
|
|
||||||
input_metadata.kv_last_page_len,
|
|
||||||
allow_fp16_qk_reduction=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
||||||
@@ -114,9 +109,6 @@ class RadixAttention(nn.Module):
|
|||||||
o = input_metadata.decode_wrapper.forward(
|
o = input_metadata.decode_wrapper.forward(
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||||
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
|
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
|
||||||
input_metadata.kv_indptr,
|
|
||||||
input_metadata.kv_indices,
|
|
||||||
input_metadata.kv_last_page_len,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
||||||
|
|||||||
@@ -90,6 +90,11 @@ class InputMetadata:
|
|||||||
decode_wrapper = None
|
decode_wrapper = None
|
||||||
|
|
||||||
def init_flashinfer_args(self, tp_size):
|
def init_flashinfer_args(self, tp_size):
|
||||||
|
from flashinfer import (
|
||||||
|
BatchDecodeWithPagedKVCacheWrapper,
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper,
|
||||||
|
)
|
||||||
|
|
||||||
self.kv_indptr = torch.zeros(
|
self.kv_indptr = torch.zeros(
|
||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
@@ -107,11 +112,7 @@ class InputMetadata:
|
|||||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
from flashinfer.ops import (
|
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
||||||
BatchDecodeWithPagedKVCacheWrapper,
|
|
||||||
BatchPrefillWithPagedKVCacheWrapper,
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.forward_mode == ForwardMode.PREFILL
|
self.forward_mode == ForwardMode.PREFILL
|
||||||
or self.forward_mode == ForwardMode.EXTEND
|
or self.forward_mode == ForwardMode.EXTEND
|
||||||
@@ -120,19 +121,21 @@ class InputMetadata:
|
|||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
||||||
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
|
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
||||||
self.prefill_wrapper.begin_forward(
|
self.prefill_wrapper.begin_forward(
|
||||||
self.qo_indptr,
|
self.qo_indptr,
|
||||||
self.batch_size,
|
self.kv_indptr,
|
||||||
|
self.kv_indices,
|
||||||
|
self.kv_last_page_len,
|
||||||
self.model_runner.model_config.num_attention_heads // tp_size,
|
self.model_runner.model_config.num_attention_heads // tp_size,
|
||||||
self.model_runner.model_config.num_key_value_heads // tp_size,
|
self.model_runner.model_config.num_key_value_heads // tp_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
|
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
||||||
self.decode_wrapper.begin_forward(
|
self.decode_wrapper.begin_forward(
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
|
self.kv_indices,
|
||||||
self.kv_last_page_len,
|
self.kv_last_page_len,
|
||||||
self.batch_size,
|
|
||||||
self.model_runner.model_config.num_attention_heads // tp_size,
|
self.model_runner.model_config.num_attention_heads // tp_size,
|
||||||
self.model_runner.model_config.num_key_value_heads // tp_size,
|
self.model_runner.model_config.num_key_value_heads // tp_size,
|
||||||
self.model_runner.model_config.head_dim,
|
self.model_runner.model_config.head_dim,
|
||||||
|
|||||||
Reference in New Issue
Block a user