Cleanup attention backend: flashinfer and triton (#611)

This commit is contained in:
Lianmin Zheng
2024-07-12 18:21:11 -07:00
committed by GitHub
parent af4e7910e7
commit 396a69240f
4 changed files with 180 additions and 161 deletions

View File

@@ -1,6 +1,5 @@
"""Radix attention."""
import numpy as np
import torch
from flashinfer.cascade import merge_state
from torch import nn
@@ -51,13 +50,13 @@ class RadixAttention(nn.Module):
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.start_loc,
input_metadata.triton_start_loc,
input_metadata.seq_lens,
input_metadata.prefix_lens,
input_metadata.triton_prefix_lens,
input_metadata.extend_start_loc,
input_metadata.extend_seq_lens,
input_metadata.max_seq_len,
input_metadata.max_extend_len,
input_metadata.triton_max_seq_len,
input_metadata.triton_max_extend_len,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)
@@ -75,9 +74,9 @@ class RadixAttention(nn.Module):
o.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.start_loc,
input_metadata.triton_start_loc,
input_metadata.seq_lens,
input_metadata.max_seq_len,
input_metadata.triton_max_seq_len,
input_metadata.total_num_tokens,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
@@ -95,7 +94,7 @@ class RadixAttention(nn.Module):
logits_soft_cap=self.logit_cap,
)
if input_metadata.no_prefix:
if input_metadata.extend_no_prefix:
o = o1
else:
o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(