Cleanup attention backend: flashinfer and triton (#611)
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
"""Radix attention."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from flashinfer.cascade import merge_state
|
||||
from torch import nn
|
||||
@@ -51,13 +50,13 @@ class RadixAttention(nn.Module):
|
||||
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
|
||||
input_metadata.req_to_token_pool.req_to_token,
|
||||
input_metadata.req_pool_indices,
|
||||
input_metadata.start_loc,
|
||||
input_metadata.triton_start_loc,
|
||||
input_metadata.seq_lens,
|
||||
input_metadata.prefix_lens,
|
||||
input_metadata.triton_prefix_lens,
|
||||
input_metadata.extend_start_loc,
|
||||
input_metadata.extend_seq_lens,
|
||||
input_metadata.max_seq_len,
|
||||
input_metadata.max_extend_len,
|
||||
input_metadata.triton_max_seq_len,
|
||||
input_metadata.triton_max_extend_len,
|
||||
sm_scale=self.scaling,
|
||||
logit_cap=self.logit_cap,
|
||||
)
|
||||
@@ -75,9 +74,9 @@ class RadixAttention(nn.Module):
|
||||
o.view(-1, self.tp_q_head_num, self.head_dim),
|
||||
input_metadata.req_to_token_pool.req_to_token,
|
||||
input_metadata.req_pool_indices,
|
||||
input_metadata.start_loc,
|
||||
input_metadata.triton_start_loc,
|
||||
input_metadata.seq_lens,
|
||||
input_metadata.max_seq_len,
|
||||
input_metadata.triton_max_seq_len,
|
||||
input_metadata.total_num_tokens,
|
||||
sm_scale=self.scaling,
|
||||
logit_cap=self.logit_cap,
|
||||
@@ -95,7 +94,7 @@ class RadixAttention(nn.Module):
|
||||
logits_soft_cap=self.logit_cap,
|
||||
)
|
||||
|
||||
if input_metadata.no_prefix:
|
||||
if input_metadata.extend_no_prefix:
|
||||
o = o1
|
||||
else:
|
||||
o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
|
||||
|
||||
Reference in New Issue
Block a user