[sgl-kernel] support flashmla libtorch (#11717)
This commit is contained in:
518
sgl-kernel/tests/test_flashmla.py
Normal file
518
sgl-kernel/tests/test_flashmla.py
Normal file
@@ -0,0 +1,518 @@
|
||||
import math
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel.flash_mla import (
|
||||
flash_mla_sparse_fwd,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
|
||||
# True when the tests should be skipped: either no CUDA device is usable or the
# current device reports a compute capability below (10, 0). The availability
# check comes first so that merely importing this module on a CUDA-less machine
# does not raise from the capability query.
# NOTE(review): no test visible in this file consumes this flag — confirm intent.
skip_condition = (
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (10, 0)
)
|
||||
|
||||
# ================ prefill usage ================ #
# Query sequence lengths used to parametrize the prefill test.
S_Q_PREFILL = [1, 62]
# (s_kv, topk) pairs for the prefill test: the first entry is used as the KV
# sequence length, the second as the number of indices per query token.
KV_TOPK_PREFILL = [
    # Regular shapes
    (128, 128),
    (256, 256),
    (512, 512),
    # Irregular shapes
    (592, 128),
    (1840, 256),
    (1592, 384),
    (1521, 512),
    # Irregular shapes with OOB TopK (topk exceeds s_kv)
    (95, 128),
    (153, 256),
    (114, 384),
]

# ================= decode usage ================= #
# Batch sizes for the decode test.
B_DECODE = [1, 2, 6, 64]
# Query sequence lengths for the decode test.
S_Q_DECODE = [1, 2, 4]
# Nominal KV sequence lengths for the decode test.
S_K_DECODE = [20, 140, 4096]
# Whether per-sequence KV lengths are randomized around the nominal length.
IS_VARLEN = [False, True]
# (is_causal, topk) pairs; topk is None when no index selection is used.
CAUSAL_TOPK = [(True, None), (False, None), (False, 128), (False, 2048)]
# Dtypes the decode test is parametrized over.
DTYPE = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
def quantize_k_cache(
|
||||
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
|
||||
dv: int,
|
||||
tile_size: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Quantize the k-cache
|
||||
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
|
||||
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
|
||||
"""
|
||||
assert dv % tile_size == 0
|
||||
num_tiles = dv // tile_size
|
||||
num_blocks, block_size, h_k, d = input_k_cache.shape
|
||||
assert h_k == 1
|
||||
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
|
||||
input_elem_size = input_k_cache.element_size()
|
||||
|
||||
result = torch.empty(
|
||||
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
device=input_k_cache.device,
|
||||
)
|
||||
result_k_nope_part = result[..., :dv]
|
||||
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
|
||||
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
|
||||
result_k_rope_part[:] = input_k_cache[..., dv:]
|
||||
|
||||
for tile_idx in range(0, num_tiles):
|
||||
cur_scale_factors_inv = (
|
||||
torch.abs(
|
||||
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
|
||||
)
|
||||
.max(dim=-1)
|
||||
.values
|
||||
/ 448.0
|
||||
) # [num_blocks, block_size]
|
||||
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
|
||||
|
||||
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
|
||||
cur_quantized_nope = (
|
||||
input_k_cache[
|
||||
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
|
||||
].float()
|
||||
/ cur_scale_factors_inv.float()
|
||||
).to(torch.float8_e4m3fn)
|
||||
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
|
||||
cur_quantized_nope
|
||||
)
|
||||
|
||||
result = result.view(num_blocks, block_size, 1, -1)
|
||||
return result
|
||||
|
||||
|
||||
def dequantize_k_cache(
|
||||
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
|
||||
dv: int = 512,
|
||||
tile_size: int = 128,
|
||||
d: int = 576,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
De-quantize the k-cache
|
||||
"""
|
||||
assert dv % tile_size == 0
|
||||
num_tiles = dv // tile_size
|
||||
num_blocks, block_size, h_k, _ = quant_k_cache.shape
|
||||
assert h_k == 1
|
||||
result = torch.empty(
|
||||
(num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
|
||||
)
|
||||
|
||||
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
|
||||
|
||||
input_nope = quant_k_cache[..., :dv]
|
||||
input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
|
||||
input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
|
||||
result[..., dv:] = input_rope
|
||||
|
||||
for tile_idx in range(0, num_tiles):
|
||||
cur_nope = input_nope[
|
||||
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
|
||||
].to(torch.float32)
|
||||
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
|
||||
result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
|
||||
cur_nope * cur_scales
|
||||
)
|
||||
|
||||
result = result.view(num_blocks, block_size, 1, d)
|
||||
return result
|
||||
|
||||
|
||||
def cdiv(x: int, y: int):
    """Return ``x / y`` rounded up (for positive ``y``) using add-then-floor division."""
    padded = x + y - 1
    return padded // y
|
||||
|
||||
|
||||
def get_window_size(causal, window):
    """
    Translate a window length into a (left, right) pair.

    A non-positive ``window`` yields (-1, -1). Otherwise the left extent is
    ``window - 1`` and the right extent is 0 when ``causal`` is truthy, or
    ``window - 1`` when it is not.
    """
    if not window > 0:
        return (-1, -1)
    left = window - 1
    right = 0 if causal else window - 1
    return (left, right)
|
||||
|
||||
|
||||
def get_attn_bias(s_q, s_k, causal, window, device="cuda"):
    """
    Build an additive attention bias of shape (s_q, s_k) and dtype float32.

    Entries are 0 where attention is allowed and -inf where it is masked.
    Query row i is aligned to key column ``i + (s_k - s_q)``.

    Args:
        s_q: number of query positions.
        s_k: number of key positions.
        causal: when truthy, mask keys to the right of the aligned diagonal.
        window: when > 0, additionally mask keys ``window`` or more positions
            away from the aligned diagonal on either side.
        device: device on which the bias is created. Defaults to "cuda",
            which was previously hard-coded.

    Returns:
        The (s_q, s_k) float32 bias tensor.
    """
    neg_inf = float("-inf")
    offset = s_k - s_q
    attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32, device=device)
    # `tril` is out-of-place, so one all-True base mask serves every band below.
    ones = torch.ones(s_q, s_k, dtype=torch.bool, device=device)
    if causal:
        attn_bias.masked_fill_(ones.tril(diagonal=offset).logical_not(), neg_inf)
    if window > 0:
        # Keys too far to the left of the diagonal.
        attn_bias.masked_fill_(ones.tril(diagonal=offset - window), neg_inf)
        # Keys too far to the right of the diagonal.
        attn_bias.masked_fill_(
            ones.tril(diagonal=offset + window - 1).logical_not(), neg_inf
        )
    return attn_bias
|
||||
|
||||
|
||||
def sdpa(query, key, value, attn_bias, softmax_scale=None):
    """
    Reference scaled-dot-product attention computed in float32.

    Inputs are laid out as [..., seq, heads, dim]; they are transposed to
    [..., heads, seq, dim] internally. Key/value heads are repeated so that
    their head count matches the query's.

    Args:
        query: [..., s_q, h, d] tensor.
        key: [..., s_k, h_k, d] tensor; ``h`` must be a multiple of ``h_k``.
        value: [..., s_k, h_k, d_v] tensor.
        attn_bias: additive bias broadcastable to [..., h, s_q, s_k].
        softmax_scale: score multiplier; defaults to ``d ** -0.5``.

    Returns:
        (output, lse): output is [..., h, s_q, d_v]; lse is the log-sum-exp of
        the biased, scaled scores with shape [..., h, s_q].
    """
    query = query.float().transpose(-3, -2)
    key = key.float().transpose(-3, -2)
    value = value.float().transpose(-3, -2)
    # Head counts come from the tensors themselves; the original referenced
    # `h` / `h_k` names that are not defined anywhere in this function.
    h = query.size(-3)
    h_k = key.size(-3)
    key = key.repeat_interleave(h // h_k, dim=-3)
    value = value.repeat_interleave(h // h_k, dim=-3)
    if softmax_scale is None:
        softmax_scale = query.shape[-1] ** (-0.5)
    attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale
    attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight.to(query.dtype) @ value, lse
|
||||
|
||||
|
||||
def sdpa_checkpoint(*args, **kwargs):
    """Run ``sdpa`` through ``torch.utils.checkpoint.checkpoint`` (non-reentrant mode).

    All positional and keyword arguments are forwarded to ``sdpa``.
    """
    # Imported locally because the module-level imports visible in this file
    # do not bring `checkpoint` into scope.
    from torch.utils.checkpoint import checkpoint

    return checkpoint(sdpa, *args, use_reentrant=False, **kwargs)
|
||||
|
||||
|
||||
def reference_torch_prefill(
    s_q, s_kv, topk, indices, q, kv, sm_scale: float, d_v: int = 512
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    PyTorch reference for sparse (top-k indexed) prefill attention.

    Only batch element 0 and KV head 0 are used. Indices outside [0, s_kv)
    are treated as invalid and receive zero attention weight.

    Args:
        s_q: number of query tokens.
        s_kv: number of KV tokens.
        topk: number of indices per query token.
        indices: integer tensor indexed as [0, :, 0, :] -> [s_q, topk].
        q: tensor indexed as [0] -> [s_q, h_q, d_qk].
        kv: tensor indexed as [0, :, 0, :] -> [s_kv, d_qk].
        sm_scale: softmax scale applied to the raw scores.
        d_v: number of leading KV channels used as values. Defaults to 512,
            which was previously hard-coded.

    Returns:
        (max_logits, lse, result): max_logits and lse are [s_q, h_q] and are
        expressed in base 2 (scores are multiplied by log2(e)); result is
        [s_q, h_q, d_v].
    """

    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
        # log2(sum(2 ** a)) expressed via the natural-log logsumexp.
        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)

    indices = indices[0, :, 0, :]  # [s_q, topk]
    invalid_indices_mask = (indices < 0) | (indices >= s_kv)
    qs = q[0, :, :, :].float()  # [s_q, h_q, d_qk]
    kvs = kv[0, :, 0, :].float()  # [s_kv, d_qk]
    d_qk = kvs.shape[-1]  # inferred instead of assuming 576

    # Invalid indices are redirected to row 0 for the gather and masked out below.
    gather_index = indices.masked_fill(invalid_indices_mask, 0).flatten()
    kvs = torch.index_select(kvs, 0, gather_index).view(
        s_q, topk, d_qk
    )  # [s_q, topk, d_qk]
    attn_score = qs @ kvs.transpose(1, 2)  # [s_q, h_q, topk]
    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
    attn_score *= sm_scale * math.log2(math.e)
    max_logits = torch.max(attn_score, dim=-1)[0]  # [s_q, h_q]
    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    result = attn_score @ kvs[:, :, :d_v]
    return (max_logits, lse, result)
|
||||
|
||||
|
||||
def reference_torch_decode(
    cache_seqlens: torch.Tensor,  # [batch_size]
    block_table: torch.Tensor,  # [batch_size, ?]
    q: torch.Tensor,  # [batch_size, s_q, h_q, d]
    blocked_k: torch.Tensor,  # [?, block_size, h_kv, d]
    dv: int,
    is_causal: bool,
    indices: Optional[torch.Tensor] = None,  # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation of paged-KV decode attention in PyTorch.

    For each batch element the KV sequence is gathered from ``blocked_k`` via
    ``block_table`` and truncated to ``cache_seqlens[i]``; attention is then
    computed in float32 with the first ``dv`` KV channels acting as values.
    NaNs in the gathered KV are replaced by zero before use.

    Args:
        cache_seqlens: per-sequence KV lengths.
        block_table: per-sequence block ids into ``blocked_k``.
        q: queries.
        blocked_k: paged KV cache.
        dv: number of leading KV channels used as values.
        is_causal: apply a causal mask (only has an effect when s_q > 1);
            must not be combined with ``indices``.
        indices: optional per-query positions to attend to; -1 marks an
            unused slot.

    Returns:
        (out_ref, lse_ref): out_ref is [batch_size, s_q, h_q, dv] in bfloat16;
        lse_ref is [batch_size, h_q, s_q] in float32. Query rows with nothing
        to attend to get a zero output and an lse of +inf.
    """
    # Work tensors follow the query's device rather than a hard-coded "cuda".
    device = q.device

    def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
        # True at (i, j) when query row i lists position j among its indices.
        mask = torch.zeros(s_q, s_k, dtype=torch.bool, device=device)
        for i in range(s_q):
            cur_indices = indices[i]
            valid_indices = cur_indices[cur_indices != -1]
            mask[i, valid_indices.long()] = True
        return mask

    def scaled_dot_product_attention(
        query: torch.Tensor,  # [h_q, s_q, d]
        kv: torch.Tensor,  # [h_kv, s_k, d]
        dv: int,
        is_causal,
        indices: Optional[torch.Tensor],  # [s_q, topk]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        h_q = query.size(0)
        h_kv = kv.size(0)
        s_q = query.shape[-2]
        s_k = kv.shape[-2]
        query = query.float()
        kv = kv.float()
        if h_kv != 1:
            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
        # Zero out NaN entries so they cannot poison the matmuls.
        kv = kv.masked_fill(kv.isnan(), 0.0)
        attn_weight = query @ kv.transpose(-2, -1)  # [h_q, s_q, s_k]
        if (is_causal and s_q > 1) or indices is not None:
            mask = torch.ones(s_q, s_k, dtype=torch.bool, device=device)
            if is_causal:
                assert indices is None
                mask = mask.tril(diagonal=s_k - s_q)
            if indices is not None:
                mask &= get_topk_attn_mask(s_q, s_k, indices)
            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device=device)
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
            attn_weight += attn_bias
        attn_weight /= math.sqrt(query.size(-1))
        lse = attn_weight.logsumexp(dim=-1)  # [h_q, s_q]
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
        output = attn_weight @ kv[..., :dv]  # [h_q, s_q, dv]
        # Correct for q tokens which have no attendable k: softmax over an
        # all -inf row is NaN, so force the output to 0 and the lse to +inf.
        lonely_q_mask = lse == float("-inf")
        output.masked_fill_(lonely_q_mask.unsqueeze(-1), 0.0)
        lse[lonely_q_mask] = float("+inf")

        return output, lse

    b, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)
    cache_seqlens_cpu = cache_seqlens.cpu()
    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=device)
    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32, device=device)
    for i in range(b):
        cur_len = cache_seqlens_cpu[i].item()
        cur_num_blocks = cdiv(cur_len, block_size)
        cur_block_indices = block_table[i][0:cur_num_blocks]
        # Gather this sequence's blocks and flatten them to [cur_len, h_kv, d].
        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
        cur_out, cur_lse = scaled_dot_product_attention(
            q[i].transpose(0, 1),
            cur_kv.transpose(0, 1),
            dv,
            is_causal,
            indices[i] if indices is not None else None,
        )
        out_ref[i] = cur_out.transpose(0, 1)
        lse_ref[i] = cur_lse
    out_ref = out_ref.to(torch.bfloat16)
    return out_ref, lse_ref
|
||||
|
||||
|
||||
@pytest.mark.parametrize("s_q", S_Q_PREFILL)
@pytest.mark.parametrize("kv_topk", KV_TOPK_PREFILL)
@torch.inference_mode()
def test_flashmla_prefill(
    s_q: int,
    kv_topk: Tuple[int, int],
):
    """
    Compare ``flash_mla_sparse_fwd`` against ``reference_torch_prefill``.

    Random bfloat16 q / kv tensors and random per-query index lists are
    generated on the GPU; the output, max-logits and lse returned by the
    kernel must match the reference within the tolerances asserted below.
    ``kv_topk`` is an (s_kv, topk) pair.
    """
    s_kv, topk = kv_topk

    torch.cuda.empty_cache()

    q = torch.randn((1, s_q, 128, 576), dtype=torch.bfloat16, device="cuda") / 10
    kv = torch.randn((1, s_kv, 1, 576), dtype=torch.bfloat16, device="cuda") / 10

    q.clamp_(-10, 10)
    kv.clamp_(-10, 10)

    # Pre-filled with s_kv, which is itself out of range for the KV sequence.
    indices = torch.full((1, s_q, 1, topk), s_kv, dtype=torch.int32, device="cuda")
    num_real = min(topk, s_kv)
    for row in range(s_q):
        # NOTE: Indices are generated so that most of them lie within
        # [s_kv-20000, s_kv), which is more realistic for sparse attention.
        near_mask = torch.randint(0, 32, (num_real,), device="cuda") < 31
        row_indices = torch.randperm(s_kv, device="cuda")[:topk]
        row_indices[near_mask] = torch.randint(
            max(0, s_kv - 20000),
            s_kv - 1,
            (near_mask.sum().item(),),
            device="cuda",
        )
        missing = topk - len(row_indices)
        if missing > 0:
            # Pad short rows with a large out-of-range index.
            padding = torch.full((missing,), 2147480000, device="cuda")
            row_indices = torch.cat([row_indices, padding])
        # Shuffle so that padding is not always at the tail.
        row_indices = row_indices[torch.randperm(topk, device="cuda")]
        indices[0, row, 0] = row_indices
    indices = indices.to(q.device)

    sm_scale = 1 / math.sqrt(576)
    torch.cuda.synchronize()

    kernel_out, kernel_max_logits, kernel_lse = flash_mla_sparse_fwd(
        q.squeeze(0), kv.squeeze(0), indices.squeeze(0), sm_scale=sm_scale
    )
    ans_out = kernel_out.float()
    ans_max_logits = kernel_max_logits.float()
    ans_lse = kernel_lse.float()

    torch.cuda.synchronize()
    ref_max_logits, ref_lse, ref_out = reference_torch_prefill(
        s_q, s_kv, topk, indices, q, kv, sm_scale
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(ans_out, ref_out, atol=8e-4, rtol=2.01 / 128)
    torch.testing.assert_close(
        ans_max_logits,
        ref_max_logits,
        atol=1e-6,
        rtol=2.01 / 65536,
    )
    torch.testing.assert_close(ans_lse, ref_lse, atol=1e-6, rtol=2.01 / 65536)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("b", B_DECODE)
@pytest.mark.parametrize("s_q", S_Q_DECODE)
@pytest.mark.parametrize("s_k", S_K_DECODE)
@pytest.mark.parametrize("is_varlen", IS_VARLEN)
@pytest.mark.parametrize("causal_topk", CAUSAL_TOPK)
@pytest.mark.parametrize("dtype", DTYPE)
@torch.inference_mode()
def test_flash_mla_decode(
    b: int,
    s_q: int,
    s_k: int,
    is_varlen: bool,
    causal_topk: Tuple[bool, Optional[int]],
    dtype: torch.dtype,
):
    """
    Compare ``flash_mla_with_kvcache`` against ``reference_torch_decode``.

    A paged KV cache with a shuffled block table is generated; every KV entry
    the kernel must not read is overwritten with NaN. When ``topk`` is given,
    the FP8 path is exercised: the cache is quantized for the kernel and the
    reference runs on the de-quantized copy.

    NOTE(review): ``dtype`` is not referenced in the body — q and the KV cache
    are always created as bfloat16. Confirm whether that is intended.
    """
    d = 576
    dv = 512
    block_size = 64
    h_q = 128
    h_kv = 1
    is_causal, topk = causal_topk[0], causal_topk[1]

    # Generating test data
    torch.cuda.synchronize()

    cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device="cpu")
    if is_varlen:
        for seq in range(b):
            # Random length around s_k, never shorter than the query length.
            cache_seqlens_cpu[seq] = max(random.normalvariate(s_k, s_k / 2), s_q)

    max_seqlen = cache_seqlens_cpu.max().item()
    max_seqlen_pad = cdiv(max_seqlen, 256) * 256
    cache_seqlens = cache_seqlens_cpu.cuda()

    q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
    q.clamp_(min=-1.0, max=1.0)

    total_blocks = b * max_seqlen_pad // block_size
    block_table = torch.arange(total_blocks, dtype=torch.int32, device="cuda").view(
        b, max_seqlen_pad // block_size
    )
    # Shuffle the block ids so that sequences are not laid out contiguously.
    shuffle = torch.randperm(block_table.numel())
    block_table = block_table.view(-1)[shuffle].view(b, -1)
    blocked_k = (
        torch.randn(
            block_table.numel(),
            block_size,
            h_kv,
            d,
            dtype=torch.bfloat16,
            device="cuda",
        )
        / 10
    )
    blocked_k.clamp_(min=-1.0, max=1.0)

    if topk is None:
        for seq in range(b):
            seq_len = cache_seqlens_cpu[seq].item()
            used_blocks = cdiv(seq_len, block_size)
            # Blocks past the sequence end must never be read.
            blocked_k[block_table[seq][used_blocks:]] = float("nan")
            tail = seq_len % block_size
            if tail != 0:
                # Neither must the unused tail of the last, partially filled block.
                blocked_k[block_table[seq][used_blocks - 1]][tail:] = float("nan")
            block_table[seq][used_blocks:] = 2147480000
        abs_indices = None
        indices_in_kvcache = None
    else:
        block_table_cpu = block_table.cpu()
        abs_indices = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu")
        indices_in_kvcache = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu")
        for seq in range(b):
            # Generate indices
            for tok in range(s_q):
                # Positions within the sequence ...
                picked_abs = torch.randperm(
                    int(cache_seqlens_cpu[seq].item()), device="cpu"
                )[:topk]
                # ... and the same positions translated through the block table.
                picked_blocked = block_table_cpu[
                    seq, picked_abs // block_size
                ] * block_size + (picked_abs % block_size)
                if len(picked_abs) < topk:
                    pad = torch.full((topk - len(picked_abs),), -1, device="cpu")
                    picked_abs = torch.cat([picked_abs, pad])
                    picked_blocked = torch.cat([picked_blocked, pad])

                # Apply one shared shuffle to both index lists.
                perm = torch.randperm(topk, device="cpu")
                abs_indices[seq, tok, :] = picked_abs[perm]
                indices_in_kvcache[seq, tok, :] = picked_blocked[perm]

        # Mask nonused KV as NaN
        used = set(indices_in_kvcache.flatten().tolist())
        used.discard(-1)
        all_indices = torch.tensor(list(used), dtype=torch.int32, device="cpu")

        blocked_k = blocked_k.view(-1, h_kv, d)
        nonused_indices_mask = torch.ones(
            blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device="cpu"
        )
        nonused_indices_mask[all_indices] = False
        blocked_k[nonused_indices_mask, :, :] = float("nan")
        blocked_k = blocked_k.view(-1, block_size, h_kv, d)

        abs_indices = abs_indices.to(q.device)
        indices_in_kvcache = indices_in_kvcache.to(q.device)

    is_fp8 = topk is not None
    kernel_kv = blocked_k
    if is_fp8:
        # The quantization error may be too large to be distinguished from wrong kernels
        # So we quantize and de-quantize kv-cache here to mitigate quantization error
        kernel_kv = quantize_k_cache(blocked_k, dv, 128)
        blocked_k = dequantize_k_cache(kernel_kv)

    # Get schedule metadata
    torch.cuda.synchronize()
    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, is_fp8, topk
    )
    torch.cuda.synchronize()

    out_ans, lse_ans = flash_mla_with_kvcache(
        q,
        kernel_kv,
        block_table,
        cache_seqlens,
        dv,
        tile_scheduler_metadata,
        num_splits,
        causal=is_causal,
        is_fp8_kvcache=is_fp8,
        indices=indices_in_kvcache,
    )

    out_ref, lse_ref = reference_torch_decode(
        cache_seqlens, block_table, q, blocked_k, dv, is_causal, abs_indices
    )
    torch.testing.assert_close(out_ans, out_ref, atol=8e-4, rtol=2.01 / 128)
    torch.testing.assert_close(lse_ans, lse_ref, atol=1e-6, rtol=8.01 / 65536)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file as a script
    # exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__]))
|
||||
Reference in New Issue
Block a user