[perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (#6929)
This commit is contained in:
@@ -40,15 +40,23 @@ def ref_mla(
|
||||
@pytest.mark.parametrize("bs", [1, 2, 4])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@pytest.mark.parametrize("block_size", [1, 16, 64, 128])
|
||||
@pytest.mark.parametrize("num_heads", [16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("num_kv_splits", [-1, 1])
|
||||
def test_cutlass_mla_decode(
|
||||
dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int
|
||||
dtype: torch.dtype,
|
||||
mean_seq_len: int,
|
||||
bs: int,
|
||||
varlen: bool,
|
||||
block_size: int,
|
||||
num_heads: int,
|
||||
num_kv_splits: int,
|
||||
):
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
|
||||
d = 576
|
||||
h_q = 128
|
||||
h_q = num_heads
|
||||
dv = 512
|
||||
|
||||
q_nope_dim = 128
|
||||
@@ -67,17 +75,22 @@ def test_cutlass_mla_decode(
|
||||
pack_factor = 128 // block_size
|
||||
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
|
||||
|
||||
# Lager q values to detect split kv error
|
||||
q = torch.randn(bs, h_q, d) * 100.0
|
||||
block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)
|
||||
|
||||
kv_cache = torch.randn(block_table.numel(), block_size, d)
|
||||
|
||||
workspace_size = cutlass_mla_get_workspace_size(block_num * block_size, bs)
|
||||
workspace_size = cutlass_mla_get_workspace_size(
|
||||
block_num * block_size, bs, num_kv_splits=num_kv_splits
|
||||
)
|
||||
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
||||
|
||||
out_ref = q.new_zeros(bs, h_q, dv)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
||||
out = cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace)
|
||||
out = cutlass_mla_decode(
|
||||
q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user