Fix flashinfer version (#576)
@@ -30,12 +30,8 @@ class RadixAttention(nn.Module):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer only accepts a boolean logit_cap argument
-            if logit_cap > 0:
-                assert logit_cap == 30
-                self.logit_cap = True
-            else:
-                self.logit_cap = False
+            # flashinfer now accepts float logit_cap argument
+            self.logit_cap = logit_cap if logit_cap > 0 else 0
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
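The hunk above drops the old boolean flag (which only supported the hard-coded cap of 30) and passes the float cap straight through to flashinfer. For illustration only, and not part of this commit: a float logits soft cap of this kind conventionally means tanh-style capping of the attention logits, as in the sketch below, assuming the standard `cap * tanh(x / cap)` form.

import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Illustrative sketch of tanh soft capping (assumed form, not
    # flashinfer's actual kernel): squashes logits into (-cap, cap).
    # A non-positive cap means "disabled", matching the
    # `logit_cap if logit_cap > 0 else 0` convention above.
    if cap <= 0:
        return logits
    return cap * torch.tanh(logits / cap)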
@@ -110,7 +106,7 @@ class RadixAttention(nn.Module):
         o = input_metadata.flashinfer_prefill_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-            logits_cap=self.logit_cap,
+            logits_soft_cap=self.logit_cap,
         )

         return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -121,7 +117,7 @@ class RadixAttention(nn.Module):
         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-            logits_cap=self.logit_cap,
+            logits_soft_cap=self.logit_cap,
         )

         return o.view(-1, self.tp_q_head_num * self.head_dim)
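Both call sites rename the keyword from `logits_cap` to `logits_soft_cap`, so the code only runs against a flashinfer build exposing the new name, hence the version pin in this commit. As a hedged sketch (not part of the commit; `pick_soft_cap_kwarg` is a hypothetical helper), one could instead detect which keyword the installed wrapper accepts:

import inspect

def pick_soft_cap_kwarg(forward_fn) -> str:
    # Hypothetical compatibility shim: return whichever keyword name
    # the installed flashinfer wrapper's forward() actually accepts.
    params = inspect.signature(forward_fn).parameters
    return "logits_soft_cap" if "logits_soft_cap" in params else "logits_cap"

# Usage sketch:
#   kwarg = pick_soft_cap_kwarg(input_metadata.flashinfer_prefill_wrapper.forward)
#   o = input_metadata.flashinfer_prefill_wrapper.forward(
#       q_view, kv_data, **{kwarg: self.logit_cap})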