Refactor attention backend (#1381)

This commit is contained in:
Lianmin Zheng
2024-09-11 11:44:26 -07:00
committed by GitHub
parent c03cece42f
commit fec185ce0c
16 changed files with 568 additions and 564 deletions

View File

@@ -55,8 +55,8 @@ class TestCreateKvIndices(unittest.TestCase):
paged_kernel_lens,
kv_indptr,
None,
req_to_token.size(1),
kv_indices_triton,
req_to_token.size(1),
)
# Check

View File

@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
other_args = []
if disable_radix_cache:
other_args.append("--disable-radix-cache")
other_args.extend(["--attention-backend", attention_backend])
if attention_backend:
other_args.extend(["--attention-backend", attention_backend])
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
other_args.extend(["--tensor-parallel-size", "2"])

View File

@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
other_args = []
if disable_radix_cache:
other_args.append("--disable-radix-cache")
other_args.extend(["--attention-backend", attention_backend])
if attention_backend:
other_args.extend(["--attention-backend", attention_backend])
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
model = DEFAULT_MODEL_NAME_FOR_TEST

View File

@@ -96,23 +96,17 @@ class TestExtendAttention(unittest.TestCase):
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend,
max_len_in_batch,
b_start_loc_extend,
max_len_extend,
)
redundant_attention(
q_extend,
k_extend,
v_extend,
o_redundant,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,