Update flashinfer to 0.0.5 (#554)

Lianmin Zheng
2024-06-20 20:29:06 -07:00
committed by GitHub
parent 09593e9bc9
commit b7e2f800ac
3 changed files with 76 additions and 46 deletions


@@ -12,7 +12,8 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
+        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
+        layer_id: int, logit_cap: int = -1
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
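
The constructor gains explicit type annotations but keeps the same parameter order. A hypothetical call site under the new signature could look like the sketch below; the import path and all numeric values are illustrative assumptions, not part of this commit.

```python
# Hypothetical call site for the newly annotated constructor.
# Import path and all numeric values are assumptions for illustration.
from sglang.srt.layers.radix_attention import RadixAttention

head_dim = 128
attn = RadixAttention(
    num_heads=32,
    head_dim=head_dim,
    scaling=1.0 / (head_dim ** 0.5),  # must satisfy the np.allclose check in __init__
    num_kv_heads=8,
    layer_id=0,
    logit_cap=30,  # with flashinfer enabled, only 30 (or a non-positive value) passes the assert below
)
```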
@@ -20,7 +21,6 @@ class RadixAttention(nn.Module):
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
-        self.logit_cap = logit_cap
         assert np.allclose(scaling, 1.0 / (head_dim**0.5))
@@ -30,10 +30,17 @@ class RadixAttention(nn.Module):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
+            # flashinfer only accepts a boolean logit_cap argument
+            if logit_cap > 0:
+                assert logit_cap == 30
+                self.logit_cap = True
+            else:
+                self.logit_cap = False
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap

     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
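
In this flashinfer release the wrapper's `logits_cap` argument is a boolean flag rather than a value, which is why `__init__` collapses `logit_cap` to True/False and asserts the single supported cap of 30. A minimal sketch of the effect such a flag would have on pre-softmax scores, assuming it enables a fixed tanh cap of 30 inside the kernel (that fixed value is an assumption based on the assert above, not something stated in the diff):

```python
# Minimal sketch (not the flashinfer kernel) of a tanh-style logit cap,
# assuming the boolean flag enables a fixed cap of 30 on pre-softmax scores.
import torch

def cap_logits(scores: torch.Tensor, enabled: bool, cap: float = 30.0) -> torch.Tensor:
    # Softly bound attention logits to (-cap, cap) when the cap is enabled.
    return cap * torch.tanh(scores / cap) if enabled else scores

scores = torch.randn(4, 4) * 100                       # exaggerated pre-softmax logits
print(cap_logits(scores, enabled=True).abs().max())    # magnitude stays below 30
```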
@@ -100,9 +107,10 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
-        o = input_metadata.prefill_wrapper.forward(
+        o = input_metadata.flashinfer_prefill_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )
         return o.view(-1, self.tp_q_head_num * self.head_dim)
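
Around the renamed `flashinfer_prefill_wrapper`, q is reshaped from the layer's flat (tokens, heads * head_dim) layout to (tokens, heads, head_dim) for the kernel and the output is flattened back. A rough sketch of that layout round trip, using dense `scaled_dot_product_attention` as a stand-in for the paged-KV flashinfer kernel; all sizes are made-up example values:

```python
# Layout round trip around the wrapper call, with dense SDPA standing in for
# the paged-KV flashinfer kernel. Sizes are hypothetical.
import torch
import torch.nn.functional as F

num_tokens, num_heads, head_dim = 16, 8, 64
q = torch.randn(num_tokens, num_heads * head_dim)
k = torch.randn(num_tokens, num_heads * head_dim)
v = torch.randn(num_tokens, num_heads * head_dim)

def to_kernel_layout(x: torch.Tensor) -> torch.Tensor:
    # (tokens, heads * head_dim) -> (heads, tokens, head_dim)
    return x.contiguous().view(-1, num_heads, head_dim).transpose(0, 1)

o = F.scaled_dot_product_attention(
    to_kernel_layout(q), to_kernel_layout(k), to_kernel_layout(v), is_causal=True
)
o = o.transpose(0, 1).reshape(-1, num_heads * head_dim)  # back to the layer's flat layout
assert o.shape == q.shape
```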
@@ -110,9 +118,10 @@ class RadixAttention(nn.Module):
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )
         return o.view(-1, self.tp_q_head_num * self.head_dim)