Simplify flashinfer dispatch (#1552)
@@ -32,9 +32,10 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        sliding_window_size: int = -1,
         logit_cap: float = 0.0,
         v_head_dim: int = -1,
+        sliding_window_size: int = -1,
+        is_cross_attention: bool = False,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
         self.layer_id = layer_id
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
+        self.is_cross_attention = is_cross_attention
 
     def forward(self, q, k, v, forward_batch: ForwardBatch):
         if k is not None:
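The per-layer flags are what the simplified flashinfer dispatch can key on: with `sliding_window_size` and `is_cross_attention` stored on the layer itself, the attention backend can pick a wrapper from the layer object alone instead of threading extra arguments through each forward call. The sketch below shows that dispatch shape; `WrapperDispatch` and `dispatch` are illustrative names, not identifiers from this commit.

from enum import Enum, auto


class WrapperDispatch(Enum):
    DEFAULT = auto()
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


def dispatch(layer) -> WrapperDispatch:
    # `layer` is a RadixAttention module; both attributes are set
    # unconditionally in __init__ per the diff above, so no fallback
    # attribute checks are needed here.
    if layer.is_cross_attention:
        return WrapperDispatch.CROSS_ATTENTION
    if layer.sliding_window_size != -1:
        return WrapperDispatch.SLIDING_WINDOW
    return WrapperDispatch.DEFAULT

Normalizing the constructor argument with `sliding_window_size or -1` is what keeps a single `!= -1` check sufficient: a caller passing `None` or `0` still leaves the layer with one well-defined "disabled" sentinel.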