Refactor attention backend (#1381)
This commit is contained in:
@@ -55,8 +55,8 @@ class TestCreateKvIndices(unittest.TestCase):
|
||||
paged_kernel_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
req_to_token.size(1),
|
||||
kv_indices_triton,
|
||||
req_to_token.size(1),
|
||||
)
|
||||
|
||||
# Check
|
||||
|
||||
@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
|
||||
other_args = []
|
||||
if disable_radix_cache:
|
||||
other_args.append("--disable-radix-cache")
|
||||
other_args.extend(["--attention-backend", attention_backend])
|
||||
if attention_backend:
|
||||
other_args.extend(["--attention-backend", attention_backend])
|
||||
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
|
||||
other_args.extend(["--tensor-parallel-size", "2"])
|
||||
|
||||
|
||||
@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
|
||||
other_args = []
|
||||
if disable_radix_cache:
|
||||
other_args.append("--disable-radix-cache")
|
||||
other_args.extend(["--attention-backend", attention_backend])
|
||||
if attention_backend:
|
||||
other_args.extend(["--attention-backend", attention_backend])
|
||||
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
|
||||
|
||||
model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
|
||||
@@ -96,23 +96,17 @@ class TestExtendAttention(unittest.TestCase):
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_seq_len_prefix,
|
||||
b_start_loc_extend,
|
||||
b_seq_len_extend,
|
||||
max_len_in_batch,
|
||||
b_start_loc_extend,
|
||||
max_len_extend,
|
||||
)
|
||||
|
||||
redundant_attention(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_redundant,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
|
||||
Reference in New Issue
Block a user