Clean up the usage of flashinfer (#610)
@@ -31,21 +31,13 @@ class RadixAttention(nn.Module):
         self.layer_id = layer_id
 
         if not global_server_args_dict.get("disable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap if logit_cap is not None else 0
 
-    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
-        # See the extend_forward_xxx functions.
-        raise NotImplementedError()
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
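
Note: the per-branch logit_cap lines collapse into the single expression kept above, normalizing None and non-positive values to 0 for both backends. A minimal standalone check of that expression (illustrative, not part of the commit):

    def normalize_logit_cap(logit_cap):
        # None and non-positive caps mean "no cap" and become 0;
        # positive caps pass through unchanged.
        return logit_cap if logit_cap is not None and logit_cap > 0 else 0

    assert normalize_logit_cap(None) == 0
    assert normalize_logit_cap(-1) == 0
    assert normalize_logit_cap(30.0) == 30.0
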
@@ -86,7 +78,6 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
-            input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
@@ -94,7 +85,7 @@ class RadixAttention(nn.Module):
 
         return o
 
-    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
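
Note: the renamed extend path returns the attention output o1 together with its log-sum-exp s1, so the result over the new (ragged) tokens can later be merged with attention over the cached prefix. The merge itself is outside this hunk; below is a pure-PyTorch toy of LSE-weighted merging, an illustration of the idea rather than flashinfer's actual kernel:

    import torch

    def merge_states(o1, s1, o2, s2):
        # o*: [tokens, heads, head_dim] partial outputs over disjoint key sets;
        # s*: [tokens, heads] log-sum-exp of the corresponding attention scores.
        s_max = torch.maximum(s1, s2)
        w1 = torch.exp(s1 - s_max).unsqueeze(-1)
        w2 = torch.exp(s2 - s_max).unsqueeze(-1)
        # Weight each partial output by its share of the total softmax mass.
        return (o1 * w1 + o2 * w2) / (w1 + w2)
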
@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
         + cur_batch_req_idx * stride_req_to_token_b
         + (start_n + offs_n),
         mask=(start_n + offs_n) < cur_batch_seq_len,
-        other=other_kv_index,
+        other=0,
     )
 
     qk = tl.load(
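
Note: a masked tl.load still materializes its fill value, so the out-of-range lanes used to fill with other_kv_index, a known-live KV slot, to keep NaNs from uninitialized memory out of the accumulation. Filling with index 0 achieves the same thing, assuming slot 0 of the pool always holds finite data (this diff relies on that but does not show it). A toy reproduction of the failure mode:

    import torch

    kv = torch.tensor([1.0, 2.0, 3.0, float("nan")])  # pretend slot 3 is uninitialized
    idx = torch.tensor([0, 1, 3])                     # lane 2 points past the sequence end
    mask = torch.tensor([True, True, False])

    bad = kv[idx] * mask                           # nan * 0 is still nan: the "NAN issue"
    safe = kv[torch.where(mask, idx, 0)] * mask    # fill with index 0, like other=0
    assert not torch.isfinite(bad).all()
    assert torch.isfinite(safe).all()
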
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
    b_req_idx,
    b_start_loc,
    b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
             o.stride(0),
             o.stride(1),
             req_to_tokens.stride(0),
-            other_kv_index,
         )
         return
 
@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=v_buffer.shape[-1],
         BLOCK_N=BLOCK,
@@ -315,7 +311,6 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
     sm_scale=None,
     logit_cap=-1,
@@ -347,5 +342,4 @@ def token_attention_fwd(
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )
@@ -729,7 +729,6 @@ class InputMetadata:
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
 
-    other_kv_index: torch.Tensor = None
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
@@ -743,24 +742,19 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
     def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if (
-            self.forward_mode == ForwardMode.EXTEND
-        ):
+        if self.forward_mode == ForwardMode.DECODE:
+            paged_kernel_lens = self.seq_lens
+        else:
             paged_kernel_lens = self.prefix_lens
             self.no_prefix = torch.all(self.prefix_lens == 0)
-        else:
-            paged_kernel_lens = self.seq_lens
 
-        self.kv_indptr = torch.zeros(
+        kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
+        kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
                     req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
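
Note: kv_indptr and kv_indices become locals rather than attributes stashed on self, since they are only needed to call begin_forward below. The layout is CSR-style: request i's KV slots are kv_indices[kv_indptr[i]:kv_indptr[i + 1]]. A worked example on CPU (illustrative):

    import torch

    paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
    kv_indptr = torch.zeros(len(paged_kernel_lens) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    # kv_indptr is now [0, 3, 8, 10]: request i owns
    # kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
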
@@ -769,18 +763,34 @@ class InputMetadata:
             ],
             dim=0,
         ).contiguous()
+        kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
 
-        if self.forward_mode == ForwardMode.EXTEND:
+        if self.forward_mode == ForwardMode.DECODE:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
+            )
+        else:
             # extend part
-            self.qo_indptr = torch.zeros(
+            qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
+            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
 
             self.flashinfer_prefill_wrapper_ragged.end_forward()
             self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                self.qo_indptr,
-                self.qo_indptr.clone(),
+                qo_indptr,
+                qo_indptr,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
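
Note: the branches are reordered so the simple decode case comes first, and kv_last_page_len is a vector of ones because the wrappers run with page size 1 (the literal 1 passed to begin_forward): with one token per page, every request's last page is always full. A quick check of that identity:

    import torch

    page_size = 1
    seq_lens = torch.tensor([3, 5, 2])
    last_page_len = (seq_lens - 1) % page_size + 1  # tokens in each request's final page
    assert torch.equal(last_page_len, torch.ones_like(seq_lens))
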
@@ -789,28 +799,15 @@ class InputMetadata:
             # cached part
             self.flashinfer_prefill_wrapper_paged.end_forward()
             self.flashinfer_prefill_wrapper_paged.begin_forward(
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
                 1,
             )
-        else:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
 
     def init_extend_args(self):
         self.extend_seq_lens = self.seq_lens - self.prefix_lens
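
Note: with the duplicated decode branch deleted, each forward mode drives exactly one long-lived wrapper through the same begin_forward / forward / end_forward cycle per batch. A hypothetical stand-in that models just that contract (the real wrappers come from flashinfer):

    class WrapperLifecycle:
        # Models the per-batch protocol the code above follows.
        def __init__(self):
            self.planned = False

        def begin_forward(self, *batch_layout):
            # Plan the batch once; forgetting end_forward() leaks the old plan.
            assert not self.planned, "end_forward() missing for the previous batch"
            self.planned = True

        def forward(self, *tensors):
            assert self.planned, "begin_forward() must run once per batch"

        def end_forward(self):
            self.planned = False
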
@@ -822,7 +819,6 @@ class InputMetadata:
     def create(
         cls,
         model_runner,
-        tp_size,
         forward_mode,
         req_pool_indices,
         seq_lens,
@@ -833,9 +829,6 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
-        flashinfer_prefill_wrapper_ragged=None,
-        flashinfer_prefill_wrapper_paged=None,
-        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -845,9 +838,6 @@ class InputMetadata:
 
         if forward_mode == ForwardMode.DECODE:
             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
         else:
             seq_lens_cpu = seq_lens.cpu().numpy()
             prefix_lens_cpu = prefix_lens.cpu().numpy()
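
Note: the deleted lines picked an arbitrary known-live KV slot, the last token of the first request, to serve as the masked-load fill index in the decode kernel; with other=0 in the kernel it is no longer needed. What they computed, on toy data (illustrative):

    import torch

    req_to_token = torch.arange(12).reshape(2, 6)  # toy pool: 2 requests x 6 slots
    req_pool_indices = torch.tensor([0, 1])
    seq_lens = torch.tensor([4, 6])
    other_kv_index = req_to_token[req_pool_indices[0], seq_lens[0] - 1].item()  # -> 3
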
@@ -865,7 +855,6 @@ class InputMetadata:
                 ),
                 device="cuda",
             )
-            other_kv_index = None
 
         ret = cls(
             forward_mode=forward_mode,
@@ -882,12 +871,11 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
         )
 
         if forward_mode == ForwardMode.EXTEND:
@@ -895,8 +883,8 @@ class InputMetadata:
 
         if not global_server_args_dict.get("disable_flashinfer", False):
             ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
                 model_runner.model_config.head_dim,
             )
 
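
Note: create() now reads tp_size off model_runner instead of taking it as a parameter. The division is there because attention heads are sharded across tensor-parallel ranks, so flashinfer must be planned with per-rank head counts. Illustrative numbers (get_num_kv_heads is assumed here to floor-divide with a minimum of 1):

    num_attention_heads, num_kv_heads, tp_size = 32, 8, 4
    num_qo_heads_per_rank = num_attention_heads // tp_size   # 8
    num_kv_heads_per_rank = max(num_kv_heads // tp_size, 1)  # 2
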
@@ -221,7 +221,6 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
@@ -229,9 +228,6 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -242,7 +238,6 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.DECODE,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
@@ -252,9 +247,6 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -53,6 +53,7 @@ class ServerArgs:
     disable_flashinfer: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
+    disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
@@ -294,6 +295,11 @@ class ServerArgs:
             action="store_true",
             help="Disable regex jump-forward",
         )
+        parser.add_argument(
+            "--disable-cuda-graph",
+            action="store_true",
+            help="Disable cuda graph.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
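
Note: the new flag mirrors the existing disable_* switches; the dataclass field above supplies the default and the parser entry exposes it on the CLI. A minimal check that the argument parses as expected, assuming it is wired exactly as shown (illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--disable-cuda-graph",
        action="store_true",
        help="Disable cuda graph.",
    )
    args = parser.parse_args(["--disable-cuda-graph"])
    assert args.disable_cuda_graph is True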