Fix flashinfer (#700)
@@ -85,7 +85,19 @@ class RadixAttention(nn.Module):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
+        self.store_kv_cache(k, v, input_metadata)
+
+        if input_metadata.total_num_tokens <= 4096:
+            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                causal=True,
+                sm_scale=self.scaling,
+                logits_soft_cap=self.logit_cap,
+            )
+        else:
+            o1, s1 = (
+                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
                     q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                     k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
                     v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
@@ -93,22 +105,23 @@ class RadixAttention(nn.Module):
                     sm_scale=self.scaling,
                     logits_soft_cap=self.logit_cap,
                 )
+            )
 
             if input_metadata.extend_no_prefix:
                 o = o1
             else:
-            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                o2, s2 = (
+                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
                         q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                         input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                         causal=False,
                         sm_scale=self.scaling,
                         logits_soft_cap=self.logit_cap,
                     )
+                )
 
                 o, _ = merge_state(o1, s1, o2, s2)
 
-        self.store_kv_cache(k, v, input_metadata)
-
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()
 
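In the large-batch path that survives in the else branch, prefill attention is still computed in two pieces: the ragged wrapper covers the newly extended tokens, the paged wrapper covers the cached prefix (causal=False), and each call returns its output together with a log-sum-exp via forward_return_lse. merge_state(o1, s1, o2, s2) then folds the two partial results into the attention output over the full sequence. A minimal sketch of that combination, assuming [tokens, heads, head_dim] outputs and natural-log LSE values over [tokens, heads] (the kernel's exact convention may differ), is:

import torch

def merge_state_reference(o1, s1, o2, s2):
    # Sketch of merging two partial attention states (hypothetical helper,
    # not FlashInfer's fused kernel). o1/o2: partial outputs over the
    # new-token and cached-prefix KV segments; s1/s2: log-sum-exp of the
    # attention scores over each segment, natural-log base assumed.
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)  # un-normalized weight of the new-token segment
    w2 = torch.exp(s2 - s_max)  # un-normalized weight of the cached-prefix segment
    o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / (w1 + w2).unsqueeze(-1)
    s = s_max + torch.log(w1 + w2)  # merged log-sum-exp
    return o, s

The new small-batch path (total_num_tokens <= 4096) sidesteps this two-kernel split: it writes the new KV into the cache first, then runs a single causal pass of the paged wrapper over the full sequences, which is why store_kv_cache moves to the top of the method.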
@@ -829,8 +829,9 @@ def init_flashinfer_args(
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
+    total_num_tokens = int(torch.sum(seq_lens))
 
-    if forward_mode == ForwardMode.DECODE:
+    if forward_mode == ForwardMode.DECODE or total_num_tokens <= 4096:
         paged_kernel_lens = seq_lens
     else:
         paged_kernel_lens = prefix_lens
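This hunk keeps init_flashinfer_args consistent with that dispatch: whenever the paged wrapper is the only prefill kernel (decode, or extend batches at or below the new 4096-token cutoff), its page indices must span the full per-request sequence length; in the split scheme it only needs the cached prefix. A condensed sketch of the selection, with PAGED_ONLY_THRESHOLD as a hypothetical name for the hard-coded 4096 and a plain string standing in for ForwardMode.DECODE:

PAGED_ONLY_THRESHOLD = 4096  # hypothetical name for the hard-coded limit

def pick_paged_kernel_lens(forward_mode, seq_lens, prefix_lens, total_num_tokens):
    # Decode, and now also small extend batches, run entirely on the paged
    # wrapper, so it must index the whole cached sequence per request.
    if forward_mode == "DECODE" or total_num_tokens <= PAGED_ONLY_THRESHOLD:
        return seq_lens
    # Large extend batches keep the split scheme: the paged wrapper only
    # covers the cached prefix; the ragged wrapper handles the new tokens.
    return prefix_lens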