# sglang/sgl-kernel/tests/test_flashmla.py
import math
import random
from typing import Optional, Tuple
import pytest
import torch
from torch.utils.checkpoint import checkpoint  # used by sdpa_checkpoint below
import triton
from sgl_kernel.flash_mla import (
flash_mla_sparse_fwd,
flash_mla_with_kvcache,
get_mla_metadata,
)
# These tests require a GPU with compute capability >= 10.0.
skip_condition = torch.cuda.get_device_capability() < (10, 0)
# ================ prefill usage ================ #
S_Q_PREFILL = [1, 62]
KV_TOPK_PREFILL = [
# Regular shapes
(128, 128),
(256, 256),
(512, 512),
# Irregular shapes
(592, 128),
(1840, 256),
(1592, 384),
(1521, 512),
# Irregular shapes with OOB TopK
(95, 128),
(153, 256),
(114, 384),
]
# ================= decode usage ================= #
B_DECODE = [1, 2, 6, 64]
S_Q_DECODE = [1, 2, 4]
S_K_DECODE = [20, 140, 4096]
IS_VARLEN = [False, True]
CAUSAL_TOPK = [(True, None), (False, None), (False, 128), (False, 2048)]
DTYPE = [torch.float16, torch.bfloat16]
def quantize_k_cache(
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
dv: int,
tile_size: int = 128,
) -> torch.Tensor:
"""
Quantize the k-cache
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, d = input_k_cache.shape
assert h_k == 1
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()
result = torch.empty(
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
dtype=torch.float8_e4m3fn,
device=input_k_cache.device,
)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = (
torch.abs(
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
)
.max(dim=-1)
.values
/ 448.0
) # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (
input_k_cache[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].float()
/ cur_scale_factors_inv.float()
).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_quantized_nope
)
result = result.view(num_blocks, block_size, 1, -1)
return result
def dequantize_k_cache(
    quant_k_cache: torch.Tensor,  # (num_blocks, block_size, 1, bytes_per_token)
    dv: int = 512,
    tile_size: int = 128,
    d: int = 576,
    dtype: torch.dtype = torch.bfloat16,  # dtype in which the RoPE part was stored
) -> torch.Tensor:
    """
    De-quantize a K-cache produced by quantize_k_cache back to `dtype`.
    """
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, _ = quant_k_cache.shape
assert h_k == 1
    result = torch.empty(
        (num_blocks, block_size, d), dtype=dtype, device=quant_k_cache.device
    )
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
    input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(dtype)
result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_nope * cur_scales
)
result = result.view(num_blocks, block_size, 1, d)
return result
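# A minimal, self-contained sanity sketch for the two helpers above (an illustrative
# addition, not collected by pytest; run it manually if desired). It assumes a bf16
# cache with the default dv=512 / d=576 / tile_size=128 layout: the RoPE tail should
# round-trip losslessly, while the NoPE part matches only up to FP8 (e4m3) error.
def _kv_cache_quant_roundtrip_sketch() -> None:
    torch.manual_seed(0)
    kv = torch.randn(2, 64, 1, 576, dtype=torch.bfloat16, device="cuda") / 10
    quant = quantize_k_cache(kv, dv=512, tile_size=128)
    dequant = dequantize_k_cache(quant, dv=512, tile_size=128, d=576)
    # RoPE part (last d - dv values) is stored verbatim in the input dtype.
    torch.testing.assert_close(dequant[..., 512:], kv[..., 512:])
    # NoPE part is quantized per 128-element tile; allow a loose tolerance.
    torch.testing.assert_close(
        dequant[..., :512].float(), kv[..., :512].float(), atol=2e-2, rtol=1.25e-1
    )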
def cdiv(x: int, y: int):
return (x + y - 1) // y
def get_window_size(causal, window):
if window > 0:
window_size = (window - 1, 0) if causal else (window - 1, window - 1)
else:
window_size = (-1, -1)
return window_size
def get_attn_bias(s_q, s_k, causal, window):
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32, device="cuda")
if causal:
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
diagonal=s_k - s_q
)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
if window > 0:
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
diagonal=s_k - s_q - window
)
attn_bias.masked_fill_(temp_mask, float("-inf"))
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
diagonal=s_k - s_q + window - 1
)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
return attn_bias
def sdpa(query, key, value, attn_bias, h, h_k, softmax_scale=None):
    # Cast to float and move the head dim in front of the sequence dim.
    query = query.float().transpose(-3, -2)
    key = key.float().transpose(-3, -2)
    value = value.float().transpose(-3, -2)
    # Broadcast the h_k KV heads across the h query heads (MQA/GQA).
    key = key.repeat_interleave(h // h_k, dim=-3)
    value = value.repeat_interleave(h // h_k, dim=-3)
if softmax_scale is None:
softmax_scale = query.shape[-1] ** (-0.5)
attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight.to(query.dtype) @ value, lse
def sdpa_checkpoint(*args, **kwargs):
return checkpoint(sdpa, *args, use_reentrant=False, **kwargs)
def reference_torch_prefill(
s_q, s_kv, topk, indices, q, kv, sm_scale: float
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
indices = indices[0, :, 0, :] # [s_q, topk]
invalid_indices_mask = (indices < 0) | (indices >= s_kv)
qs = q[0, :, :, :].float() # [s_q, h_q, d_qk]
kvs = kv[0, :, 0, :].float() # [s_kv, d_qk]
kvs = torch.index_select(
kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()
).view(
s_q, topk, 576
) # [s_q, topk, d_qk]
attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk]
attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
attn_score *= sm_scale * math.log2(math.e)
max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q]
lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
result = attn_score @ kvs[:, :, :512]
return (max_logits, lse, result)
def reference_torch_decode(
cache_seqlens: torch.Tensor, # [batch_size]
block_table: torch.Tensor, # [batch_size, ?]
q: torch.Tensor, # [batch_size, s_q, h_q, d]
blocked_k: torch.Tensor, # [?, block_size, h_kv, d]
dv: int,
is_causal: bool,
indices: Optional[torch.Tensor] = None, # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
A reference implementation in PyTorch
"""
def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
mask = torch.zeros(s_q, s_k, dtype=torch.bool, device="cuda")
for i in range(s_q):
cur_indices = indices[i]
valid_indices = cur_indices[cur_indices != -1]
mask[i, valid_indices] = True
return mask
def scaled_dot_product_attention(
batch_idx: int,
query: torch.Tensor, # [h_q, s_q, d]
kv: torch.Tensor, # [h_kv, s_k, d]
dv: int,
is_causal,
indices: Optional[torch.Tensor], # [s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
h_q = query.size(0)
h_kv = kv.size(0)
s_q = query.shape[-2]
s_k = kv.shape[-2]
query = query.float()
kv = kv.float()
if h_kv != 1:
kv = kv.repeat_interleave(h_q // h_kv, dim=0)
        kv[kv != kv] = 0.0  # zero out NaN entries (unused cache slots) so they don't propagate
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
if (is_causal and query.size(1) > 1) or indices is not None:
mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda")
if is_causal:
assert indices is None
mask = mask.tril(diagonal=s_k - s_q)
if indices is not None:
mask &= get_topk_attn_mask(s_q, s_k, indices)
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device="cuda")
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
attn_weight += attn_bias.to(q.dtype)
attn_weight /= math.sqrt(query.size(-1))
lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q]
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv]
        # Correct q tokens which have no attendable k: zero their output, set lse to +inf
lonely_q_mask = lse == float("-inf")
output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
lse[lonely_q_mask] = float("+inf")
return output, lse
b, s_q, h_q, d = q.size()
block_size = blocked_k.size(1)
h_kv = blocked_k.size(2)
cache_seqlens_cpu = cache_seqlens.cpu()
out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device="cuda")
lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32, device="cuda")
for i in range(b):
cur_len = cache_seqlens_cpu[i].item()
cur_num_blocks = cdiv(cur_len, block_size)
cur_block_indices = block_table[i][0:cur_num_blocks]
cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
cur_out, cur_lse = scaled_dot_product_attention(
i,
q[i].transpose(0, 1),
cur_kv.transpose(0, 1),
dv,
is_causal,
indices[i] if indices is not None else None,
)
out_ref[i] = cur_out.transpose(0, 1)
lse_ref[i] = cur_lse
    out_ref = out_ref.to(q.dtype)
return out_ref, lse_ref
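# Illustrative sketch (not used by the tests below) of the paged-KV lookup convention
# assumed above: token p of request i lives at
# blocked_k[block_table[i, p // block_size], p % block_size].
def _gather_kv_for_request(
    blocked_k: torch.Tensor,  # [num_blocks, block_size, h_kv, d]
    block_table: torch.Tensor,  # [batch_size, max_num_blocks]
    i: int,
    seqlen: int,
) -> torch.Tensor:
    block_size = blocked_k.size(1)
    num_blocks = cdiv(seqlen, block_size)
    # Gather this request's blocks, flatten to token granularity, drop block padding.
    return blocked_k[block_table[i, :num_blocks]].flatten(0, 1)[:seqlen]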
@pytest.mark.parametrize("s_q", S_Q_PREFILL)
@pytest.mark.parametrize("kv_topk", KV_TOPK_PREFILL)
@torch.inference_mode()
def test_flashmla_prefill(
s_q: int,
kv_topk: Tuple[int, int],
):
torch.cuda.empty_cache()
q = torch.randn((1, s_q, 128, 576), dtype=torch.bfloat16, device="cuda") / 10
kv = torch.randn((1, kv_topk[0], 1, 576), dtype=torch.bfloat16, device="cuda") / 10
q.clamp_(-10, 10)
kv.clamp_(-10, 10)
indices = torch.full(
(1, s_q, 1, kv_topk[1]), kv_topk[0], dtype=torch.int32, device="cuda"
)
for s in range(s_q):
        # NOTE: We generate indices this way so that most of them lie within
        # [s_kv - 20000, s_kv), which is more realistic for sparse attention.
near_mask = (
torch.randint(0, 32, (min(kv_topk[1], kv_topk[0]),), device="cuda") < 31
)
cur_indices = torch.randperm(kv_topk[0], device="cuda")[: kv_topk[1]]
cur_indices[near_mask] = torch.randint(
max(0, kv_topk[0] - 20000),
kv_topk[0] - 1,
(near_mask.sum().item(),),
device="cuda",
)
if len(cur_indices) < kv_topk[1]:
cur_indices = torch.cat(
[
cur_indices,
torch.full(
(kv_topk[1] - len(cur_indices),), 2147480000, device="cuda"
),
]
)
cur_indices = cur_indices[torch.randperm(kv_topk[1], device="cuda")]
indices[0, s, 0] = cur_indices
indices = indices.to(q.device)
sm_scale = 1 / math.sqrt(576)
torch.cuda.synchronize()
ans_out, ans_max_logits, ans_lse = flash_mla_sparse_fwd(
q.squeeze(0), kv.squeeze(0), indices.squeeze(0), sm_scale=sm_scale
)
ans_out, ans_max_logits, ans_lse = (
ans_out.float(),
ans_max_logits.float(),
ans_lse.float(),
)
torch.cuda.synchronize()
ref_max_logits, ref_lse, ref_out = reference_torch_prefill(
s_q, kv_topk[0], kv_topk[1], indices, q, kv, sm_scale
)
torch.cuda.synchronize()
torch.testing.assert_close(ans_out, ref_out, atol=8e-4, rtol=2.01 / 128)
torch.testing.assert_close(
ans_max_logits,
ref_max_logits,
atol=1e-6,
rtol=2.01 / 65536,
)
torch.testing.assert_close(ans_lse, ref_lse, atol=1e-6, rtol=2.01 / 65536)
@pytest.mark.parametrize("b", B_DECODE)
@pytest.mark.parametrize("s_q", S_Q_DECODE)
@pytest.mark.parametrize("s_k", S_K_DECODE)
@pytest.mark.parametrize("is_varlen", IS_VARLEN)
@pytest.mark.parametrize("causal_topk", CAUSAL_TOPK)
@pytest.mark.parametrize("dtype", DTYPE)
@torch.inference_mode()
def test_flash_mla_decode(
b: int,
s_q: int,
s_k: int,
is_varlen: bool,
causal_topk: Tuple[bool, Optional[int]],
dtype: torch.dtype,
):
d = 576
dv = 512
block_size = 64
h_q = 128
h_kv = 1
is_causal = causal_topk[0]
topk = causal_topk[1]
# Generating test data
torch.cuda.synchronize()
cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device="cpu")
if is_varlen:
for i in range(b):
cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), s_q)
max_seqlen = cache_seqlens_cpu.max().item()
max_seqlen_pad = cdiv(max_seqlen, 256) * 256
cache_seqlens = cache_seqlens_cpu.cuda()
    q = torch.randn(b, s_q, h_q, d, dtype=dtype, device="cuda")
q.clamp_(min=-1.0, max=1.0)
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32, device="cuda"
).view(b, max_seqlen_pad // block_size)
block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1)
    blocked_k = (
        torch.randn(
            block_table.numel(),
            block_size,
            h_kv,
            d,
            dtype=dtype,
            device="cuda",
        )
        / 10
    )
blocked_k.clamp_(min=-1.0, max=1.0)
if topk is None:
for i in range(b):
cur_len = cache_seqlens_cpu[i].item()
cur_num_blocks = cdiv(cur_len, block_size)
blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
if cur_len % block_size != 0:
blocked_k[block_table[i][cur_num_blocks - 1]][
cur_len % block_size :
] = float("nan")
block_table[i][cur_num_blocks:] = 2147480000
abs_indices = None
indices_in_kvcache = None
else:
block_table_cpu = block_table.cpu()
abs_indices = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu")
indices_in_kvcache = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu")
for i in range(b):
# Generate indices
for j in range(s_q):
cur_abs_indices = torch.randperm(
int(cache_seqlens_cpu[i].item()), device="cpu"
)[:topk]
cur_blocked_indices = block_table_cpu[
i, cur_abs_indices // block_size
] * block_size + (cur_abs_indices % block_size)
if len(cur_abs_indices) < topk:
pad_len = topk - len(cur_abs_indices)
cur_abs_indices = torch.cat(
[cur_abs_indices, torch.full((pad_len,), -1, device="cpu")]
)
cur_blocked_indices = torch.cat(
[cur_blocked_indices, torch.full((pad_len,), -1, device="cpu")]
)
            # Shuffle the indices (including the -1 padding) into a random order
perm = torch.randperm(topk, device="cpu")
cur_abs_indices = cur_abs_indices[perm]
cur_blocked_indices = cur_blocked_indices[perm]
abs_indices[i, j, :] = cur_abs_indices
indices_in_kvcache[i, j, :] = cur_blocked_indices
    # Mask unused KV entries as NaN so stale cache slots cannot silently match the reference
all_indices = indices_in_kvcache.flatten().tolist()
all_indices = list(set(all_indices))
if -1 in all_indices:
all_indices.remove(-1)
all_indices = torch.tensor(all_indices, dtype=torch.int32, device="cpu")
blocked_k = blocked_k.view(-1, h_kv, d)
nonused_indices_mask = torch.ones(
blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device="cpu"
)
nonused_indices_mask[all_indices] = False
blocked_k[nonused_indices_mask, :, :] = float("nan")
blocked_k = blocked_k.view(-1, block_size, h_kv, d)
abs_indices = abs_indices.to(q.device)
indices_in_kvcache = indices_in_kvcache.to(q.device)
    is_fp8 = topk is not None  # in these tests, the sparse (top-k) path always uses an FP8 KV cache
if is_fp8:
        # Raw quantization error could be too large to distinguish a correct kernel from a
        # wrong one, so we quantize and then de-quantize the KV cache and compute the
        # reference on the de-quantized values, removing quantization error from the diff.
blocked_k_quantized = quantize_k_cache(blocked_k, dv, 128)
        blocked_k_dequantized = dequantize_k_cache(blocked_k_quantized, dtype=q.dtype)
blocked_k = blocked_k_dequantized
# Get schedule metadata
torch.cuda.synchronize()
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, is_fp8, topk
)
torch.cuda.synchronize()
out_ans, lse_ans = flash_mla_with_kvcache(
q,
blocked_k if not is_fp8 else blocked_k_quantized, # type: ignore
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=is_causal,
is_fp8_kvcache=is_fp8,
indices=indices_in_kvcache,
)
out_ref, lse_ref = reference_torch_decode(
cache_seqlens, block_table, q, blocked_k, dv, is_causal, abs_indices
)
torch.testing.assert_close(out_ans, out_ref, atol=8e-4, rtol=2.01 / 128)
torch.testing.assert_close(lse_ans, lse_ref, atol=1e-6, rtol=8.01 / 65536)
if __name__ == "__main__":
pytest.main([__file__])