Update test_flashinfer (#560)
This commit is contained in:
@@ -1,18 +1,24 @@
|
|||||||
import flashinfer
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from flashinfer import (
|
||||||
|
BatchDecodeWithPagedKVCacheWrapper,
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper,
|
||||||
|
)
|
||||||
|
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||||
|
|
||||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention
|
||||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||||
|
|
||||||
|
flashinfer_prefill_wrapper = None
|
||||||
|
flashinfer_decode_wrapper = None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size", [12, 37, 67])
|
@pytest.mark.parametrize("batch_size", [12, 37, 67])
|
||||||
@pytest.mark.parametrize("kv_len", [54, 97])
|
@pytest.mark.parametrize("kv_len", [54, 97])
|
||||||
@pytest.mark.parametrize("qo_len", [37, 17])
|
@pytest.mark.parametrize("qo_len", [37, 17])
|
||||||
@pytest.mark.parametrize("num_kv_heads", [4])
|
@pytest.mark.parametrize("num_kv_heads", [4])
|
||||||
@pytest.mark.parametrize("num_qo_heads", [4, 32])
|
@pytest.mark.parametrize("num_qo_heads", [32, 4])
|
||||||
@pytest.mark.parametrize("head_dim", [128])
|
@pytest.mark.parametrize("head_dim", [128])
|
||||||
@pytest.mark.parametrize("use_wrapper", [True, False])
|
|
||||||
def test_batch_prefill_with_paged_kv_cache(
|
def test_batch_prefill_with_paged_kv_cache(
|
||||||
batch_size,
|
batch_size,
|
||||||
kv_len,
|
kv_len,
|
||||||
@@ -20,12 +26,13 @@ def test_batch_prefill_with_paged_kv_cache(
|
|||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
num_qo_heads,
|
num_qo_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
use_wrapper,
|
|
||||||
):
|
):
|
||||||
|
init_flashinfer(num_qo_heads, num_kv_heads)
|
||||||
|
|
||||||
q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
|
q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
|
||||||
q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
|
qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
|
||||||
total_tokens = kv_len * batch_size
|
total_tokens = kv_len * batch_size
|
||||||
kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half()
|
kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
|
||||||
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
|
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
|
||||||
kv_indices = torch.arange(0, total_tokens).to(0).int()
|
kv_indices = torch.arange(0, total_tokens).to(0).int()
|
||||||
kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
|
kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
|
||||||
@@ -70,21 +77,44 @@ def test_batch_prefill_with_paged_kv_cache(
|
|||||||
max_len_extend,
|
max_len_extend,
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_wrapper:
|
o_redundant = torch.empty_like(q)
|
||||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper()
|
b_start_loc = torch.zeros((batch_size,), dtype=torch.int32).to(0)
|
||||||
wrapper.begin_forward(q_indptr, batch_size, num_qo_heads, num_kv_heads)
|
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0)
|
||||||
o = wrapper.forward(
|
b_seq_len_prefix = b_seq_len - b_seq_len_extend
|
||||||
q, q_indptr, kv_data, kv_indptr, kv_indices, kv_last_page_len
|
|
||||||
)
|
redundant_attention(
|
||||||
else:
|
q,
|
||||||
o = flashinfer.batch_prefill_with_paged_kv_cache(
|
k_extend,
|
||||||
q,
|
v_extend,
|
||||||
q_indptr,
|
o_redundant,
|
||||||
kv_data,
|
k_buffer,
|
||||||
kv_indptr,
|
v_buffer,
|
||||||
kv_indices,
|
req_to_token,
|
||||||
kv_last_page_len,
|
b_req_idx,
|
||||||
)
|
b_start_loc,
|
||||||
|
b_seq_len,
|
||||||
|
b_seq_len_prefix,
|
||||||
|
max_len_in_batch,
|
||||||
|
)
|
||||||
|
print("Mean: ", torch.mean(torch.abs(o_redundant - o_triton)))
|
||||||
|
print("Max: ", torch.max(torch.abs(o_redundant - o_triton)))
|
||||||
|
assert torch.allclose(o_redundant, o_triton, rtol=1e-2, atol=1e-3)
|
||||||
|
|
||||||
|
flashinfer_prefill_wrapper.end_forward()
|
||||||
|
|
||||||
|
flashinfer_prefill_wrapper.begin_forward(
|
||||||
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_last_page_len,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
o = flashinfer_prefill_wrapper.forward(
|
||||||
|
q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
|
||||||
|
)
|
||||||
|
|
||||||
print("Mean: ", torch.mean(torch.abs(o - o_triton)))
|
print("Mean: ", torch.mean(torch.abs(o - o_triton)))
|
||||||
print("Max: ", torch.max(torch.abs(o - o_triton)))
|
print("Max: ", torch.max(torch.abs(o - o_triton)))
|
||||||
@@ -105,10 +135,11 @@ def test_batch_decode_with_paged_kv_cache(
|
|||||||
):
|
):
|
||||||
# note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache
|
# note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache
|
||||||
# to test different shape of decode, change the parameters in the __main__, and run decode only once
|
# to test different shape of decode, change the parameters in the __main__, and run decode only once
|
||||||
|
init_flashinfer(num_qo_heads, num_kv_heads)
|
||||||
|
|
||||||
q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
|
q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
|
||||||
total_tokens = kv_len * batch_size
|
total_tokens = kv_len * batch_size
|
||||||
kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half()
|
kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
|
||||||
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
|
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
|
||||||
kv_indices = torch.arange(0, total_tokens).to(0).int()
|
kv_indices = torch.arange(0, total_tokens).to(0).int()
|
||||||
kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
|
kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
|
||||||
@@ -139,26 +170,46 @@ def test_batch_decode_with_paged_kv_cache(
|
|||||||
total_tokens,
|
total_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper()
|
flashinfer_decode_wrapper.end_forward()
|
||||||
wrapper.begin_forward(
|
flashinfer_decode_wrapper.begin_forward(
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
kv_last_page_len,
|
kv_last_page_len,
|
||||||
batch_size,
|
|
||||||
num_qo_heads,
|
num_qo_heads,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
1,
|
1,
|
||||||
"NONE",
|
pos_encoding_mode="NONE",
|
||||||
"float16",
|
data_type="float16",
|
||||||
|
)
|
||||||
|
o = flashinfer_decode_wrapper.forward(
|
||||||
|
q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
|
||||||
)
|
)
|
||||||
o = wrapper.forward(q, kv_data, kv_indptr, kv_indices, kv_last_page_len)
|
|
||||||
|
|
||||||
print("Mean: ", torch.mean(torch.abs(o - o_triton)))
|
print("Mean: ", torch.mean(torch.abs(o - o_triton)))
|
||||||
print("Max: ", torch.max(torch.abs(o - o_triton)))
|
print("Max: ", torch.max(torch.abs(o - o_triton)))
|
||||||
assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
|
assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def init_flashinfer(num_attention_heads, num_kv_heads):
|
||||||
|
if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
|
||||||
|
use_tensor_cores = True
|
||||||
|
else:
|
||||||
|
use_tensor_cores = False
|
||||||
|
|
||||||
|
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
||||||
|
|
||||||
|
global flashinfer_prefill_wrapper, flashinfer_decode_wrapper
|
||||||
|
|
||||||
|
flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffer, "NHD"
|
||||||
|
)
|
||||||
|
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128, False)
|
test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128)
|
||||||
test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128, True)
|
test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128)
|
||||||
test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)
|
test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)
|
||||||
|
|||||||
Reference in New Issue
Block a user