sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
25
sgl-kernel/tests/spatial/test_greenctx_stream.py
Normal file
25
sgl-kernel/tests/spatial/test_greenctx_stream.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import create_greenctx_stream_by_value, get_sm_available
|
||||
|
||||
|
||||
def test_green_ctx():
    """Smoke-test green-context streams: matmuls launched on two
    SM-partitioned streams must match a plain CUDA matmul result."""
    lhs = torch.randn(5120, 5120).cuda()
    rhs = torch.randn(5120, 5120).cuda()
    expected = torch.matmul(lhs, rhs)

    sm_counts = get_sm_available(0)
    # Split the device's SMs evenly between the two streams.
    stream_group = create_greenctx_stream_by_value(sm_counts // 2, sm_counts // 2, 0)

    with torch.cuda.stream(stream_group[0]):
        for _ in range(100):
            result_0 = torch.matmul(lhs, rhs)
    with torch.cuda.stream(stream_group[1]):
        for _ in range(100):
            result_1 = torch.matmul(lhs, rhs)
    torch.cuda.synchronize()

    assert torch.allclose(result_0, expected)
    assert torch.allclose(result_1, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly by delegating to pytest.
    pytest.main([__file__])
|
||||
87
sgl-kernel/tests/speculative/test_eagle_utils.py
Normal file
87
sgl-kernel/tests/speculative/test_eagle_utils.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import verify_tree_greedy
|
||||
|
||||
|
||||
def test_verify_tree_greedy():
    """Greedy tree verification on a hand-built two-request speculative
    tree; checks predicted tokens, accepted indices and accept counts."""
    device = "cuda"

    def _i64(rows):
        # All tree metadata tensors share dtype/device.
        return torch.tensor(rows, dtype=torch.int64, device=device)

    candidates = _i64([[0, 1, 2, 3, 4, 5], [7, 8, 9, 10, 11, 12]])
    retrive_index = _i64([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
    retrive_next_token = _i64([[1, 2, -1, 4, 5, -1], [4, 2, 3, -1, 5, -1]])
    retrive_next_sibling = _i64([[-1, 3, -1, -1, -1, -1], [-1, -1, -1, -1, 1, -1]])

    # Uniform logits with spikes marking the "target" tokens.
    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device=device)
    for b, pos, tok in [(0, 0, 3), (0, 3, 4), (0, 4, 5), (1, 0, 11), (1, 4, 12)]:
        target_logits[b, pos, tok] = 10
    # Positions without an explicit spike fall back to token 18.
    for b in range(target_logits.shape[0]):
        for pos in range(target_logits.shape[1]):
            if torch.max(target_logits[b, pos]) < 10:
                target_logits[b, pos, 18] = 10

    target_predict = torch.argmax(target_logits, dim=-1)
    predict_shape = (12,)

    bs = candidates.shape[0]
    num_spec_step = 4

    # Output buffers, written in place by the kernel.
    predicts = torch.full(predict_shape, -1, dtype=torch.int32, device=device)
    accept_index = torch.full(
        (bs, num_spec_step), -1, dtype=torch.int32, device=device
    )
    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device=device)

    verify_tree_greedy(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        target_predict=target_predict,
    )

    # Expected output for the tree constructed above.
    assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
    assert accept_index.tolist() == [
        [0, 3, 4, 5],
        [6, 10, 11, -1],
    ]
    assert accept_token_num.tolist() == [3, 2]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly by delegating to pytest.
    pytest.main([__file__])
|
||||
129
sgl-kernel/tests/speculative/test_speculative_sampling.py
Normal file
129
sgl-kernel/tests/speculative/test_speculative_sampling.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import tree_speculative_sampling_target_only
|
||||
|
||||
# Each case is:
# (threshold_single, threshold_acc,
#  expected_predicts, expected_accept_index, expected_accept_token_num)
test_cases = [
    (
        1,
        1,
        [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],
        [[0, 3, 4, 5], [6, 10, 11, -1]],
        [3, 2],
    ),
    (
        0,  # threshold_single
        0,  # threshold_acc
        [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
        [[0, 1, 2, -1], [6, 10, 11, -1]],
        [2, 2],
    ),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "threshold_single, threshold_acc, expected_predicts, expected_accept_index, expected_accept_token_num",
    test_cases,
)
def test_tree_speculative_sampling_target_only(
    threshold_single,
    threshold_acc,
    expected_predicts,
    expected_accept_index,
    expected_accept_token_num,
):
    """Check tree speculative sampling against precomputed expectations
    for each acceptance-threshold setting in ``test_cases``."""
    device = "cuda"

    def _i64(rows):
        # All tree metadata tensors share dtype/device.
        return torch.tensor(rows, dtype=torch.int64, device=device)

    candidates = _i64([[0, 1, 2, 3, 4, 5], [7, 8, 9, 10, 11, 12]])
    retrive_index = _i64([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
    retrive_next_token = _i64([[1, 2, -1, 4, 5, -1], [4, 2, 3, -1, 5, -1]])
    retrive_next_sibling = _i64([[-1, 3, -1, -1, -1, -1], [-1, -1, -1, -1, 1, -1]])

    # Uniform logits with spikes on the "target" tokens; token 18 is the
    # fallback wherever no spike was set.
    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device=device)
    for b, pos, tok in [(0, 0, 3), (0, 3, 4), (0, 4, 5), (1, 0, 11), (1, 4, 12)]:
        target_logits[b, pos, tok] = 10
    for b in range(target_logits.shape[0]):
        for pos in range(target_logits.shape[1]):
            if torch.max(target_logits[b, pos]) < 10:
                target_logits[b, pos, 18] = 10

    temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device=device)
    bs, num_draft_tokens = candidates.shape
    num_spec_step = len(expected_accept_index[0])
    predict_shape = (len(expected_predicts),)

    # Output buffers, written in place by the kernel.
    predicts = torch.full(predict_shape, -1, dtype=torch.int32, device=device)
    accept_index = torch.full((bs, num_spec_step), -1, dtype=torch.int32, device=device)
    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device=device)

    # Near-zero temperature makes the target distribution essentially
    # one-hot, so the expected outputs are deterministic.
    expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1)
    target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
    draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
    coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)
    coins_for_final_sampling = torch.rand(bs, device=device).to(torch.float32)

    tree_speculative_sampling_target_only(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        uniform_samples=coins,
        uniform_samples_for_final_sampling=coins_for_final_sampling,
        target_probs=target_probs,
        draft_probs=draft_probs,
        threshold_single=threshold_single,
        threshold_acc=threshold_acc,
        deterministic=True,
    )

    assert (
        predicts.tolist() == expected_predicts
    ), f"Predicts mismatch for thresholds ({threshold_single}, {threshold_acc})"
    assert (
        accept_index.tolist() == expected_accept_index
    ), f"Accept index mismatch for thresholds ({threshold_single}, {threshold_acc})"
    assert (
        accept_token_num.tolist() == expected_accept_token_num
    ), f"Accept token num mismatch for thresholds ({threshold_single}, {threshold_acc})"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly by delegating to pytest.
    pytest.main([__file__])
|
||||
39
sgl-kernel/tests/test_activation.py
Normal file
39
sgl-kernel/tests/test_activation.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py
|
||||
|
||||
import pytest
|
||||
import sgl_kernel
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_silu_mul(dim, batch_size, seq_len):
    """Fused silu-and-mul kernel must match the eager reference."""
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    # Reference: second half gated by silu of the first half.
    expected = x[..., dim:] * torch.nn.functional.silu(x[..., :dim])
    actual = sgl_kernel.silu_and_mul(x)
    torch.testing.assert_close(expected, actual, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
    """Fused tanh-approx gelu-and-mul kernel must match the eager reference."""
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    # Reference: second half gated by tanh-approximated gelu of the first.
    expected = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
    actual = sgl_kernel.gelu_tanh_and_mul(x)
    torch.testing.assert_close(expected, actual, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_mul(dim, batch_size, seq_len):
    """Fused exact gelu-and-mul kernel must match the eager reference."""
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    # Reference: second half gated by exact (erf) gelu of the first half.
    expected = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none")
    actual = sgl_kernel.gelu_and_mul(x)
    torch.testing.assert_close(expected, actual, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly by delegating to pytest.
    pytest.main([__file__])
|
||||
23
sgl-kernel/tests/test_apply_token_bitmask_inplace.py
Normal file
23
sgl-kernel/tests/test_apply_token_bitmask_inplace.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import apply_token_bitmask_inplace_cuda
|
||||
|
||||
|
||||
def test_apply_token_bitmask_inplace_kernel():
    """Masked-out vocab entries (bitmask bit == 0) must become -inf,
    matching torch.where on the equivalent boolean mask."""
    keep = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
    logits = torch.tensor(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], dtype=torch.float32
    )
    expected = torch.where(keep, logits, float("-inf"))

    logits_gpu = logits.to("cuda")
    # Bit i of the packed int32 word corresponds to vocab index i.
    bitmask = torch.tensor([0b1010101010], dtype=torch.int32).to("cuda")
    apply_token_bitmask_inplace_cuda(logits_gpu, bitmask)
    torch.cuda.synchronize()
    torch.testing.assert_close(logits_gpu, expected.to("cuda"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fix: the test was previously also called directly before pytest.main,
    # which executed it twice when running this file as a script.
    # Run it once, under pytest only, like the other test files.
    pytest.main([__file__])
|
||||
115
sgl-kernel/tests/test_awq_dequant.py
Normal file
115
sgl-kernel/tests/test_awq_dequant.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import awq_dequantize
|
||||
|
||||
|
||||
def reverse_awq_order(t: torch.Tensor):
    """Undo AWQ's interleaved nibble packing along the last axis.

    Within every group of eight 4-bit lanes, AWQ stores values in the
    order [0, 4, 1, 5, 2, 6, 3, 7]; gathering columns by that permutation
    restores logical order. The result is also masked to 4 bits.
    """
    bits = 4
    awq_perm = [0, 4, 1, 5, 2, 6, 3, 7]
    idx = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    # Apply the permutation within each group of 32 // bits == 8 lanes.
    idx = idx.view(-1, 32 // bits)[:, awq_perm].view(-1)
    return t[:, idx] & 0xF
|
||||
|
||||
|
||||
# qweights - [R , C // 8], int32
|
||||
# scales - [R // G, C ], float16
|
||||
# zeros - [R // G, C // 8], int32
|
||||
def awq_dequantize_torch(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
) -> torch.Tensor:
    """Reference AWQ dequantization in pure torch.

    qweight: [R, C // 8] int32 — eight 4-bit values packed per int32
    scales:  [R // G, C] float
    qzeros:  [R // G, C // 8] int32 — packed 4-bit zero points
    group_size == -1 means a single group spanning all R rows.
    Returns the dequantized [R, C] tensor, (w - z) * s.
    """
    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    def _unpack(packed: torch.Tensor) -> torch.Tensor:
        # Split each int32 into eight nibbles (LSB first), undo AWQ's
        # interleaved lane order, and mask to 4 bits.
        vals = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :])
        vals = vals.to(torch.int8).view(packed.shape[0], -1)
        order = torch.arange(vals.shape[-1], dtype=torch.int32, device=vals.device)
        order = order.view(-1, 32 // bits)[:, [0, 4, 1, 5, 2, 6, 3, 7]].view(-1)
        return vals[:, order] & 0xF

    iweights = _unpack(qweight)
    zeros = _unpack(qzeros)

    # Broadcast per-group scales/zeros over every row of their group.
    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales
|
||||
|
||||
|
||||
def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Thin wrapper around the sgl_kernel CUDA AWQ dequantize op."""
    return awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "qweight_row,qweight_col,is_bf16_act",
    list(
        itertools.product(
            [3584, 18944, 128, 256, 512, 1024, 1536],
            [448, 576, 4736, 16, 32, 64, 128, 72],
            [True, False],
        )
    ),
)
def test_awq_dequant_compare_implementations(
    qweight_row: int, qweight_col: int, is_bf16_act: bool
):
    """CUDA AWQ dequantize must match the pure-torch reference on random
    packed weights/zeros for both bf16 and fp16 scales."""
    device = torch.device("cuda")
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    # One quantization group spanning all rows.
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8

    act_dtype = torch.bfloat16 if is_bf16_act else torch.float16
    scales = torch.rand(scales_row, scales_col, dtype=act_dtype, device=device)

    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    # Run both implementations.
    torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    # Compare in float32 to avoid accumulating half-precision error.
    torch.testing.assert_close(
        torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly by delegating to pytest.
    pytest.main([__file__])
|
||||
43
sgl-kernel/tests/test_bmm_fp8.py
Normal file
43
sgl-kernel/tests/test_bmm_fp8.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import bmm_fp8
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax
|
||||
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype):
    """Batched fp8 matmul must stay directionally close (cosine > 0.99)
    to the bf16 reference."""
    if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2:
        pytest.skip("Invalid combination: both input and mat2 are e5m2")

    input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
    input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)

    # mat2 row major -> column major
    mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(
        -2, -1
    )
    mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype)

    res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype)
    bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res)

    # fp8 is too coarse for elementwise closeness; check direction instead.
    reference = torch.bmm(input, mat2)
    cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
    assert cos_sim > 0.99
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly by delegating to pytest.
    pytest.main([__file__])
|
||||
489
sgl-kernel/tests/test_causal_conv1d.py
Normal file
489
sgl-kernel/tests/test_causal_conv1d.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/mamba/test_causal_conv1d.py
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel import causal_conv1d_fwd
|
||||
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    query_start_loc: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """Causal depthwise conv1d forward via the sgl_kernel op (in place on *x*).

    x: (batch, dim, seqlen), or (dim, cu_seq_len) for varlen where the
        sequences are concatenated left to right.
    weight: (dim, width); bias: (dim,).
    query_start_loc: (batch + 1) int32 — cumulative sequence lengths,
        prepended by 0, e.g. [0, 10, 16, 17] for x of shape (dim, 17).
    cache_indices: (batch) int32 — maps each sequence to its state row:
        conv_state = conv_states[cache_indices[batch_id]]. Entries equal
        to pad_slot_id mark padded sequences the kernel must skip.
    has_initial_state: (batch) bool — whether to seed from conv_states.
    conv_states: (..., dim, width - 1) — updated in place if provided.
    activation: None, "silu", or "swish".
    Returns x, the (batch, dim, seqlen) output written in place.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # The kernel requires unit stride on the sequence dimension.
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    causal_conv1d_fwd(
        x,
        weight,
        bias,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        activation in ["silu", "swish"],
        pad_slot_id,
    )
    return x
|
||||
|
||||
|
||||
def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
):
    """Decode-step causal conv1d state update via the sgl_kernel op.

    x: (batch, dim) or (batch, dim, seqlen) — updated in place and returned
        in its original shape.
    conv_state: (batch, dim, state_len), state_len >= width - 1; treated as
        a circular buffer when cache_seqlens is given (x is copied in
        starting at cache_seqlens % state_len).
    weight: (dim, width); bias: (dim,).
    conv_state_indices: (batch,) int32 — selects rows of a larger state
        tensor (continuous batching); entries equal to pad_slot_id are
        skipped by the kernel.
    activation: None, "silu", or "swish".
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError(
            f"activation must be None, silu, or swish, actual: {activation}"
        )
    activation_val = activation in ["silu", "swish"]
    # The kernel expects a trailing seqlen axis; add one for 2-D inputs
    # and strip it again before returning.
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    causal_conv1d_update_kernel(
        x,
        conv_state,
        weight,
        bias,
        activation_val,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
    if unsqueeze:
        x = x.squeeze(-1)
    return x
|
||||
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """Pure-torch reference for the causal depthwise conv1d.

    x: (batch, dim, seqlen); weight: (dim, width); bias: (dim,).
    initial_states / final_states_out: (batch, dim, width - 1).
    Returns (out, final_states_out) when return_final_states, else (out, None).
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        # Zero left-padding keeps each output causal (only past inputs).
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        # Prepend the carried-over state instead of zero padding.
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        # The last width-1 inputs become the next call's initial state
        # (negative pad trims from the left when x is long enough).
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)
    if return_final_states:
        return out, final_states_out
    return out, None
|
||||
|
||||
|
||||
def causal_conv1d_update_ref(
    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
):
    """Pure-torch reference for the decode-step conv state update.

    x: (batch, dim) or (batch, dim, seqlen).
    conv_state: (batch, dim, state_len), state_len >= width - 1 — mutated
        in place. When cache_seqlens (batch,) int32 is given, the state is
        a circular buffer and x is written in starting at
        cache_seqlens % state_len before the convolution.
    Returns the activated conv output in x's original shape and dtype.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Linear buffer: history followed by the new tokens; keep the
        # trailing state_len entries as the updated state.
        x_new = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype
        )  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # Circular buffer: gather the width-1 most recent entries...
        width_idx = torch.arange(
            -(width - 1), 0, dtype=torch.long, device=x.device
        ).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = (
            torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        )
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        # ...then scatter the new tokens back in at cache_seqlens.
        copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(
            0
        ) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
        :, :, -seqlen:
    ]
    if unsqueeze:
        out = out.squeeze(-1)
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    "seqlen", [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096]
)
@pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize("batch", [1])
def test_causal_conv1d(
    batch, dim, seqlen, width, has_bias, silu_activation, has_initial_state, itype
):
    """Kernel forward must match the pure-torch reference, including the
    final states written back into the (in-place) state buffer."""
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype).contiguous()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    if has_initial_state:
        initial_states = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
        has_initial_state_tensor = torch.ones(batch, dtype=torch.bool, device=x.device)
    else:
        initial_states = None
        has_initial_state_tensor = None
    # Clone everything before the in-place kernel call so the reference
    # sees identical inputs.
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    initial_states_ref = initial_states.clone() if initial_states is not None else None
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_fn(
        x,
        weight,
        bias,
        activation=activation,
        conv_states=initial_states,
        has_initial_state=has_initial_state_tensor,
    )
    out_ref, final_states_ref = causal_conv1d_ref(
        x_ref,
        weight_ref,
        bias_ref,
        initial_states=initial_states_ref,
        return_final_states=True,
        activation=activation,
    )
    if has_initial_state:
        # The kernel overwrites `initial_states` with the final states.
        assert initial_states is not None and final_states_ref is not None
        assert torch.allclose(initial_states, final_states_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
    """Decode-step kernel must match the reference on both the output and
    the in-place conv_state update."""
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    batch = 2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
    x_ref = x.clone()
    conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    # Separate state copy: both implementations mutate their state in place.
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
    out_ref = causal_conv1d_update_ref(
        x_ref, conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(
    with_padding, dim, width, seqlen, has_bias, silu_activation, itype
):
    """Decode-step kernel with conv_state_indices gathering: selected
    state rows must match the reference, padded slots must be skipped,
    and unselected rows must stay untouched."""
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2

    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size

    x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
    x_ref = x.clone()

    # Pick random, distinct rows of the big state tensor for the real
    # sequences; everything else must remain untouched.
    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[conv_state_indices] = False
    # Trailing PAD_SLOT_ID entries mark padded sequences the kernel skips.
    padded_state_indices = torch.concat(
        [
            conv_state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )
    conv_state = torch.randn(total_entries, dim, width - 1, device=device, dtype=itype)
    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation=activation,
        conv_state_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
    )
    out_ref = causal_conv1d_update_ref(
        x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
    # Rows that were never selected must be byte-identical to before.
    assert torch.equal(
        conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    "seqlen", [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]
)
@pytest.mark.parametrize("dim", [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_varlen(
    with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
    """Check the varlen causal_conv1d kernel against a per-sequence reference.

    A single length-``seqlen`` stream is randomly split into ``batch_size``
    (+ optional padding) sub-sequences; the fused kernel processes them in one
    call while the reference runs each sub-sequence independently. Both the
    outputs and the written-back final conv states must match.
    """
    device = "cuda"
    torch.cuda.empty_cache()
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        # bf16 accumulates more rounding error; widen the tolerance.
        rtol, atol = 1e-2, 5e-2

    seqlens = []
    batch_size = 4
    if seqlen < 10:
        # Too short to split into 4+ non-empty pieces.
        batch_size = 1
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    # Random EOS positions partition [0, seqlen) into nsplits+1 non-empty runs.
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens.append(
        torch.diff(
            torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
        ).tolist()
    )
    assert sum(seqlens[-1]) == seqlen
    assert all(s > 0 for s in seqlens[-1])

    # State cache is over-provisioned so the kernel must honor state_indices.
    total_entries = batch_size * 10
    cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
    # Slice out of a wider buffer so x is intentionally non-contiguous.
    x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, dtype=itype)[
        :, 4096 : 4096 + dim, :
    ]
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = None if not silu_activation else "silu"
    final_states = torch.randn(
        total_entries, dim, width - 1, device=x.device, dtype=x.dtype
    )
    final_states_ref = final_states.clone()
    # Randomly decide per sub-sequence whether it continues from a prior state.
    has_initial_states = torch.randint(
        0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
    )
    state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
        :batch_size
    ]
    # Padding slots are marked with PAD_SLOT_ID; the kernel must skip them.
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=-1,
    )

    out = causal_conv1d_fn(
        x.squeeze(0),
        weight,
        bias,
        cumsum.cuda(),
        padded_state_indices,
        has_initial_states,
        final_states,
        activation,
        PAD_SLOT_ID,
    )
    out_ref = []
    out_ref_b = []

    # Iterating (x_ref) walks dim 0 (batch of 1), yielding the (dim, seqlen)
    # plane, which split() then cuts into the per-sub-sequence chunks.
    splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
    for i in range(len(seqlens[0])):
        x_s = [v[i].unsqueeze(0) for v in splits][0]
        if padded_state_indices[i] == PAD_SLOT_ID:
            # Padding slot: no reference output, kernel must leave state alone.
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight_ref,
                bias_ref,
                activation=activation,
                return_final_states=True,
                final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
                initial_states=(
                    final_states_ref[padded_state_indices[i]].unsqueeze(0)
                    if has_initial_states[i]
                    else None
                ),
            )
        )
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
    out_ref_tensor = torch.cat(out_ref, dim=0)

    # Compare only the non-padded prefix of the fused output.
    unpadded_out = out[:, : out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
    assert torch.allclose(
        final_states[state_indices],
        final_states_ref[state_indices],
        rtol=rtol,
        atol=atol,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
185
sgl-kernel/tests/test_custom_allreduce.py
Normal file
185
sgl-kernel/tests/test_custom_allreduce.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import ctypes
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import socket
|
||||
import unittest
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import sgl_kernel.allreduce as custom_ops
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
|
||||
|
||||
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    """Worker-process body: validate the custom all-reduce against NCCL.

    Joins a NCCL process group as `rank`, sets up the custom all-reduce's
    shared IPC buffers, then for every element count in `test_sizes` and
    several dtypes compares the custom kernel's output with
    `dist.all_reduce`. Any mismatch raises, making the process exit
    non-zero, which the parent test asserts on.
    """
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD

    # Pre-initialize the handles released in `finally`: if setup fails
    # part-way through, referencing a never-assigned name there would raise
    # NameError and mask the original exception.
    meta_ptrs = None
    buffer_ptrs = None
    custom_ptr = None
    try:
        device = torch.device(f"cuda:{rank}")
        max_size = 8192 * 1024
        meta_ptrs = TestCustomAllReduce.create_shared_buffer(
            custom_ops.meta_size() + max_size, group=group
        )

        rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)

        custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
        custom_ops.register_buffer(custom_ptr, buffer_ptrs)

        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(test_loop):
                    # Small integer values keep the sum exactly representable
                    # in every tested dtype.
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)

                    custom_ops.all_reduce(
                        custom_ptr, inp1, out1, buffer_ptrs[rank], max_size
                    )

                    # NCCL reference result (reduces in place).
                    dist.all_reduce(inp1_ref, group=group)

                    torch.testing.assert_close(out1, inp1_ref)

    finally:
        # All ranks must rendezvous before tearing down shared IPC memory.
        dist.barrier(group=group)
        if custom_ptr is not None:
            custom_ops.dispose(custom_ptr)
        if buffer_ptrs:
            TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
        if meta_ptrs:
            TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)

        dist.destroy_process_group(group=group)
|
||||
|
||||
|
||||
def get_open_port() -> int:
    """Return a TCP port that was free at the time of the call.

    Binds port 0 on the IPv4 loopback and reads back the kernel-assigned
    port number; if IPv4 loopback is unavailable, retries on the IPv6
    loopback instead.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("127.0.0.1", 0))
            _, port = probe.getsockname()[:2]
            return port
    except OSError:
        # IPv4 loopback unavailable (e.g. IPv6-only host) — try "::1".
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as probe:
            probe.bind(("::1", 0))
            return probe.getsockname()[1]
|
||||
|
||||
|
||||
def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    """Fan ``test_target`` out across ``world_size`` spawned processes.

    Each worker is invoked as ``test_target(world_size, rank, port, *target_args)``
    where ``port`` is a freshly allocated rendezvous port shared by all ranks.
    Raises via assert if any worker exits with a non-zero code.
    """
    mp.set_start_method("spawn", force=True)

    rendezvous_port = get_open_port()
    workers = []
    for rank in range(world_size):
        worker = mp.Process(
            target=test_target,
            args=(world_size, rank, rendezvous_port) + target_args,
            name=f"Worker-{rank}",
        )
        worker.start()
        workers.append(worker)

    # Join every worker before asserting so no process is left behind.
    for rank, worker in enumerate(workers):
        worker.join()
        assert (
            worker.exitcode == 0
        ), f"Process {rank} failed with exit code {worker.exitcode}"
|
||||
|
||||
|
||||
class TestCustomAllReduce(unittest.TestCase):
    """Multi-GPU correctness tests for the sgl_kernel custom all-reduce.

    Spawns one worker process per rank and shares CUDA buffers between them
    via CUDA IPC memory handles.
    """

    # Element counts exercised by each worker (mix of small/odd/large sizes).
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
        1048576,
        2097152,
    ]
    # Tensor-parallel sizes to try; skipped when not enough GPUs are present.
    world_sizes = [2, 4, 8]

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        """Allocate a device buffer and share it with all ranks via CUDA IPC.

        Returns one device pointer per rank: this rank's own allocation plus
        the IPC-opened view of every peer's allocation. Collective: every
        rank in `group` must call it.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        if group is None:
            group = dist.group.WORLD
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)

        # Serialize the ctypes IPC handle to raw bytes so it can travel
        # through an all_gather as a byte tensor.
        handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
        input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
        gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, input_tensor, group=group)

        # Rebuild each peer's handle object from the gathered bytes.
        handles = []
        handle_type = type(handle)
        for tensor in gathered_tensors:
            bytes_list = tensor.cpu().tolist()
            bytes_data = bytes(bytes_list)
            handle_obj = handle_type()
            ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
            handles.append(handle_obj)

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                # Own allocation: use the local pointer directly.
                pointers.append(pointer.value)
            else:
                try:
                    opened_ptr = lib.cudaIpcOpenMemHandle(h)
                    pointers.append(opened_ptr.value)
                except Exception as e:
                    print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
                    raise

        # Make sure every rank has opened all handles before anyone uses them.
        dist.barrier(group=group)
        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
        """Free this rank's slot of a buffer created by create_shared_buffer.

        Collective: barriers so no rank frees while peers may still be using
        the shared memory.
        """
        if group is None:
            group = dist.group.WORLD
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        if pointers and len(pointers) > rank and pointers[rank] is not None:
            lib.cudaFree(ctypes.c_void_p(pointers[rank]))
        dist.barrier(group=group)

    def test_correctness(self):
        """Run the correctness worker for every world size that fits the host."""
        for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
                )
                continue

            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"custom allreduce tp = {world_size}: OK")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
104
sgl-kernel/tests/test_cutlass_mla.py
Normal file
104
sgl-kernel/tests/test_cutlass_mla.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
|
||||
from torch import Tensor
|
||||
|
||||
# Disable tests on SM103 until the accuracy issues are fixed.
|
||||
if torch.cuda.get_device_capability() != (10, 0):
|
||||
pytest.skip(
|
||||
reason="Cutlass MLA Requires compute capability of 10.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def ref_mla(
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
):
|
||||
bs, num_heads, v_head_dim = out.shape
|
||||
head_dim = query.shape[2]
|
||||
|
||||
for i in range(bs):
|
||||
# gather and flatten KV-cache
|
||||
kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim)
|
||||
v = kv[:, :, :v_head_dim]
|
||||
|
||||
q = query[i].view(num_heads, 1, head_dim)
|
||||
o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
|
||||
out[i] = o.view(num_heads, v_head_dim)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [1, 16, 64, 128])
@pytest.mark.parametrize("num_heads", [16, 32, 64, 128])
@pytest.mark.parametrize("num_kv_splits", [-1, 1])
def test_cutlass_mla_decode(
    dtype: torch.dtype,
    mean_seq_len: int,
    bs: int,
    varlen: bool,
    block_size: int,
    num_heads: int,
    num_kv_splits: int,
):
    """Compare cutlass_mla_decode against the pure-PyTorch ref_mla reference."""
    torch.set_default_dtype(dtype)
    torch.set_default_device("cuda")
    torch.manual_seed(42)

    # MLA geometry: 576 = 512 "nope" + 64 rope channels; values use 512.
    d = 576
    h_q = num_heads
    dv = 512

    q_nope_dim = 128
    q_pe_dim = 64
    scale = (q_nope_dim + q_pe_dim) ** (-0.5)
    if varlen:
        # Per-request lengths drawn around the mean; clip keeps them >= 2.
        seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
        seq_lens = seq_lens.clip(2).to(torch.int32)
    else:
        seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32)
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size

    # Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
    # One 128-wide tile can hold (128 // block_size) small blocks.
    pack_factor = 128 // block_size
    block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

    # Larger q values make split-KV accumulation errors visible.
    q = torch.randn(bs, h_q, d) * 100.0
    block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)

    kv_cache = torch.randn(block_table.numel(), block_size, d)

    workspace_size = cutlass_mla_get_workspace_size(
        block_num * block_size, bs, num_kv_splits=num_kv_splits
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    # Build q_nope with a transposed (head-major) memory layout on purpose,
    # then copy the data in — exercises the kernel's stride handling.
    q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
    q_nope.copy_(q[:, :, :dv])
    q_pe = q[:, :, dv:].clone()

    out_ref = q.new_zeros(bs, h_q, dv)
    ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
    out = cutlass_mla_decode(
        q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, scale, num_kv_splits
    )

    torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
283
sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
Normal file
283
sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
Normal file
@@ -0,0 +1,283 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import cutlass_w4a8_moe_mm, sgl_per_tensor_quant_fp8
|
||||
from utils import is_hopper
|
||||
|
||||
|
||||
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
    """Pack pairs of int4 values (one per element) into single int8 bytes.

    Along the last dimension, element ``2j`` becomes the low nibble and
    element ``2j + 1`` the high nibble of output element ``j``, halving the
    last-dim size.

    Raises:
        ValueError: if the last dimension has odd length (values must pair up).
    """
    if int4_values_interleaved.shape[-1] % 2 != 0:
        raise ValueError(
            "the last dim size of int4_values_interleaved tensor must be even."
        )

    as_int8 = int4_values_interleaved.to(torch.int8)
    # Even positions -> low nibble, odd positions -> high nibble.
    low = as_int8[..., 0::2]
    high = as_int8[..., 1::2]
    packed = (high << 4) | (low & 0x0F)
    return packed.to(torch.int8)
|
||||
|
||||
|
||||
def pack_interleave(num_experts, ref_weight, ref_scale):
    """Convert reference int4 weights and scales into the kernel's layout.

    Args:
        num_experts: number of experts E.
        ref_weight: (E, N, K) int4 values stored one per int8 element.
        ref_scale: (E, N, G) per-group scales.

    Returns:
        (w_q, w_scale): nibble-packed int8 weights of shape (E, N, K/2) and
        scales rearranged from (E, N, G) to (E, G/align, N*align).
    """
    n, k = ref_weight.shape[1], ref_weight.shape[2]

    # Nibble-pack on CPU, then move the result back to the GPU.
    packed = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
    w_q = packed.view((num_experts, n, k // 2)).view(torch.int8).contiguous()

    # NOTE(review): scales are interleaved in chunks of 4 only when K is a
    # multiple of 512 — presumably matching the kernel's tile layout; confirm.
    alignment = 4 if k % 512 == 0 else 1
    e, rows, groups = ref_scale.shape
    interleaved = ref_scale.reshape(
        e, rows, groups // alignment, alignment
    )  # [E, N, G/align, align]
    interleaved = interleaved.permute(0, 2, 1, 3)  # [E, G/align, N, align]
    w_scale = interleaved.reshape(
        e, groups // alignment, rows * alignment
    ).contiguous()  # [E, G/align, N*align]

    return w_q, w_scale
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_hopper(),
    reason="cutlass_w4a8_moe_mm is only supported on sm90",
)
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
def test_int4_fp8_grouped_gemm_single_expert(batch_size):
    """Single-expert w4a8 grouped GEMM vs. the dense float32 reference."""
    # Test parameters
    num_experts = 1
    m = batch_size  # batch size
    k = 512  # input dimension
    n = 1024  # output dimension
    torch.manual_seed(0)
    dtype = torch.bfloat16
    device = "cuda"
    # Flip to True for all-ones inputs when debugging numerics by hand.
    debug = False

    print(f"\nTesting with batch_size={batch_size}")

    # Create input tensors with ones
    if debug:
        a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
        ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
        ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
    else:
        a = torch.randn(m, k, dtype=dtype, device=device)
        # int4 weight range [-8, 8).
        ref_w = torch.randint(
            -8, 8, (num_experts, n, k), dtype=torch.int8, device=device
        )
        affine_coeff = 0.005
        # One scale per 128-wide group along K.
        ref_w_scale = (
            torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
            * affine_coeff
        )

    w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)

    # Create expert offsets and problem sizes
    expert_offsets = torch.tensor([0, m], dtype=torch.int32, device=device)
    problem_sizes = torch.tensor([[n, m, k]], dtype=torch.int32, device=device)

    a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
    c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
    b_strides = a_strides
    s_strides = c_strides

    # Quantize input
    a_q, a_scale = _per_tensor_quant_fp8(a)

    # Create output tensor
    c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
    cutlass_w4a8_moe_mm(
        c,
        a_q,
        w,
        a_scale,
        w_scale,
        expert_offsets[:-1],
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        s_strides,
        128,
        8,
    )
    c = c.to(dtype)

    # Reference implementation: every token routed to expert 0.
    # NOTE(review): this selection tensor lives on CPU while the data is on
    # CUDA; torch.where/indexing in ref_grouped_gemm tolerates it — confirm.
    experts_selection_result = torch.full((m,), 0)
    c_ref = ref_grouped_gemm(
        c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
    )

    # Compare results
    try:
        torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
    except AssertionError as e:
        # torch.set_printoptions(threshold=10_000)
        print(f" FAILURE: tensors are NOT close.")
        print(f" Ref tensor: {c_ref.flatten()}")
        print(f" Cutlass tensor: {c.flatten()}")
        print(
            f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(
            f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(f" AssertionError: {e}")
        raise
|
||||
|
||||
|
||||
def _per_tensor_quant_fp8(
    x: torch.Tensor,
    dtype: torch.dtype = torch.float8_e4m3fn,
):
    """Quantize ``x`` to fp8 with a single per-tensor scale.

    Returns:
        (x_q, x_s): the fp8 tensor and its float32 scale of shape (1,).
    """
    assert x.is_contiguous(), "`x` is not contiguous"

    quantized = torch.empty_like(x, device=x.device, dtype=dtype)
    scale = torch.empty(1, device=x.device, dtype=torch.float32)
    # is_static=False — presumably the kernel derives the scale from x's
    # value range at runtime; verify against the kernel's docs.
    sgl_per_tensor_quant_fp8(x, quantized, scale, is_static=False)
    return quantized, scale
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_hopper(),
    reason="cutlass_w4a8_moe_mm is only supported on sm90",
)
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32])
@pytest.mark.parametrize("k", [512, 1024, 2048, 4096, 7168])
@pytest.mark.parametrize("n", [256, 512, 1024, 2048])
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
    """Multi-expert w4a8 grouped GEMM: random routing vs. dense reference.

    Tokens are randomly assigned to experts, permuted into expert-contiguous
    order for the grouped kernel, then un-permuted before comparison.
    """
    torch.manual_seed(0)
    dtype = torch.bfloat16
    device = "cuda"
    # Flip to True for all-ones inputs when debugging numerics by hand.
    debug = False

    print(
        f"\nTesting with batch_size={batch_size}, k={k}, n={n}, num_experts={num_experts}"
    )

    if debug:
        a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
        ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
        ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
    else:
        a = torch.randn(batch_size, k, dtype=dtype, device=device)
        # int4 weight range [-8, 8); one scale per 128-wide K group.
        ref_w = torch.randint(
            -8, 8, (num_experts, n, k), dtype=torch.int8, device=device
        )
        affine_coeff = 0.005
        ref_w_scale = (
            torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
            * affine_coeff
        )

    w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)

    # random select experts
    experts_selection_result = torch.randint(
        0, num_experts, (batch_size,), device=device
    )
    # argsort groups tokens of the same expert contiguously.
    permutation = torch.argsort(experts_selection_result)
    expert_token_counts = torch.bincount(
        experts_selection_result, minlength=num_experts
    )

    # Create problem sizes and offsets for active experts
    problem_sizes = []
    for i in range(num_experts):
        problem_sizes.append([n, expert_token_counts[i].item(), k])
    problem_sizes = torch.tensor(problem_sizes, dtype=torch.int32, device=device)

    # Running prefix sum of token counts -> each expert's start row.
    expert_offsets = []
    offset = 0
    for i in range(num_experts):
        expert_offsets.append(offset)
        offset += problem_sizes[i][1].item()
    expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)

    # Permute input and quantize
    a_q, a_scale = _per_tensor_quant_fp8(a)
    a_q_perm = a_q[permutation]

    # Create stride tensors
    a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
    c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
    b_strides = a_strides
    s_strides = c_strides

    c_perm = torch.empty((batch_size, n), dtype=torch.bfloat16, device=device)
    cutlass_w4a8_moe_mm(
        c_perm,
        a_q_perm,
        w,
        a_scale,
        w_scale,
        expert_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        s_strides,
        128,
        8,
    )

    # Un-permute the result back into original token order.
    c = torch.empty_like(c_perm)
    c[permutation] = c_perm
    c = c.to(dtype)

    c_ref = ref_grouped_gemm(
        c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
    )

    # Compare results
    try:
        torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
    except AssertionError as e:
        print(f" FAILURE: tensors are NOT close.")
        print(
            f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(
            f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(f" AssertionError: {e}")
        raise
|
||||
|
||||
|
||||
def ref_grouped_gemm(
    c, a_q, a_scale, w, w_scale, num_experts, experts_selection_result
):
    """Dense float32 reference for the w4a8 grouped GEMM.

    For every expert, dequantizes its weights (per-group scales broadcast
    over 128-wide groups along K) and multiplies the rows routed to it,
    applying the per-tensor activation scale. Rows of experts that received
    no tokens stay zero. Returns a bfloat16 tensor shaped like ``c``.
    """
    out_dtype = torch.bfloat16
    c_ref = torch.zeros_like(c)
    for expert in range(num_experts):
        (rows,) = torch.where(experts_selection_result == expert)
        if rows.numel() == 0:
            continue

        # Dequantize this expert's weights to dense float32.
        scales = w_scale[expert].repeat_interleave(128, dim=1).to(torch.float32)
        dense_w = w[expert].to(torch.float32) * scales
        acts = a_q[rows].to(torch.float32)
        c_ref[rows] = (torch.matmul(acts, dense_w.t()) * a_scale).to(out_dtype)

    return c_ref
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
32
sgl-kernel/tests/test_dsv3_fused_a_gemm.py
Normal file
32
sgl-kernel/tests/test_dsv3_fused_a_gemm.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import dsv3_fused_a_gemm
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
def test_dsv3_fused_a_gemm(num_tokens):
    """Compare dsv3_fused_a_gemm against an F.linear reference in bf16.

    Uses the DeepSeek-V3 "A" projection shape (7168 -> 2112) with mat_b
    stored transposed, as the fused kernel expects.
    """
    kHdIn = 7168
    kHdOut = 2112

    mat_a = torch.randn(
        (num_tokens, kHdIn), dtype=torch.bfloat16, device="cuda"
    ).contiguous()
    # Kernel consumes B as (kHdIn, kHdOut), i.e. a transposed weight view.
    mat_b = torch.randn((kHdOut, kHdIn), dtype=torch.bfloat16, device="cuda").transpose(
        0, 1
    )

    ref = F.linear(mat_a, mat_b.T)

    # The kernel allocates its own output; the previous pre-allocated
    # `output` buffer was dead code and has been removed.
    output = dsv3_fused_a_gemm(mat_a, mat_b)

    assert torch.allclose(
        output, ref, rtol=1e-2, atol=1e-3
    ), "Fused GEMM output mismatch with torch.nn.functional.linear reference"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
35
sgl-kernel/tests/test_dsv3_router_gemm.py
Normal file
35
sgl-kernel/tests/test_dsv3_router_gemm.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sgl_kernel import dsv3_router_gemm
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
@pytest.mark.parametrize("num_experts", [256, 384])
def test_dsv3_router_gemm(num_tokens, num_experts):
    """Check dsv3_router_gemm against F.linear for both output dtypes."""
    hidden_dim = 7168

    hidden_states = torch.randn(
        (num_tokens, hidden_dim), dtype=torch.bfloat16, device="cuda"
    ).contiguous()
    router_weight = torch.randn(
        (num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda"
    ).contiguous()

    # References: bf16 matmul, and the same logits upcast to float32.
    bf16_ref = F.linear(hidden_states, router_weight)
    float_ref = bf16_ref.to(torch.float32)

    bf16_output = dsv3_router_gemm(
        hidden_states, router_weight, out_dtype=torch.bfloat16
    )
    float_output = dsv3_router_gemm(
        hidden_states, router_weight, out_dtype=torch.float32
    )

    assert torch.allclose(
        bf16_output, bf16_ref, rtol=1e-2, atol=1e-3
    ), "Router GEMM output in bf16 dtype mismatch with torch.nn.functional.linear reference"

    assert torch.allclose(
        float_output, float_ref, rtol=1e-2, atol=1e-3
    ), "Router GEMM output in float32 dtype mismatch with torch.nn.functional.linear reference"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
1368
sgl-kernel/tests/test_flash_attention.py
Normal file
1368
sgl-kernel/tests/test_flash_attention.py
Normal file
File diff suppressed because it is too large
Load Diff
877
sgl-kernel/tests/test_flash_attention_4.py
Normal file
877
sgl-kernel/tests/test_flash_attention_4.py
Normal file
@@ -0,0 +1,877 @@
|
||||
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/b31ae1e4cd22cf5f820a2995b74b7cd3bd54355a/tests/cute/test_flash_attn.py
|
||||
|
||||
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||
from utils import is_hopper
|
||||
|
||||
flash_attn_varlen_func = partial(flash_attn_varlen_func, ver=4)
|
||||
|
||||
|
||||
def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """Strip padding tokens from a batched tensor.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    if unused_mask is None:
        all_masks = attention_mask
    else:
        all_masks = attention_mask + unused_mask
    selected_per_row = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = selected_per_row.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(selected_per_row, dim=0, dtype=torch.int32), (1, 0)
    )
    # Index with integer indices instead of a bool mask: expanding the mask
    # and re-deriving indices under the hood wastes memory and is slower.
    flattened = rearrange(hidden_states, "b s ... -> (b s) ...")
    return (
        flattened[indices],
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )
|
||||
|
||||
|
||||
def pad_input(hidden_states, indices, batch, seqlen):
    """Inverse of unpad_input: scatter flat tokens back into a padded batch.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    trailing_dims = hidden_states.shape[1:]
    # Zero-filled flat buffer; padding positions stay zero after the scatter.
    padded = torch.zeros(
        (batch * seqlen),
        *trailing_dims,
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    padded[indices] = hidden_states
    return rearrange(padded, "(b s) ... -> b s ...", b=batch)
|
||||
|
||||
|
||||
def generate_random_padding_mask(
    max_seqlen, batch_size, device, mode="random", zero_lengths=False
):
    """Build a (batch_size, max_seqlen) bool mask from per-row lengths.

    ``mode`` picks the length distribution: "full" uses max_seqlen for every
    row, "random" draws from [max_seqlen - 20, max_seqlen], and "third" draws
    from [max_seqlen // 3, max_seqlen]. With ``zero_lengths``, every fifth
    row and the last row are forced to length 0.
    """
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full(
            (batch_size, 1), max_seqlen, device=device, dtype=torch.int32
        )
    elif mode == "random":
        low = max(0 if zero_lengths else 1, max_seqlen - 20)
        lengths = torch.randint(low, max_seqlen + 1, (batch_size, 1), device=device)
    else:  # mode == "third"
        lengths = torch.randint(
            max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device
        )

    if zero_lengths:
        # Generate zero-lengths every 5 batches and the last batch.
        for row in range(batch_size):
            if row % 5 == 0:
                lengths[row] = 0
        lengths[-1] = 0
    positions = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size)
    return positions < lengths
|
||||
|
||||
|
||||
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    qv=None,
    kvpacked=False,
    qkvpacked=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """
    Build unpadded (varlen) views of q/k/v plus the bookkeeping tensors that
    varlen attention kernels need (cu_seqlens, max_seqlen, seqused), and
    re-padding closures for outputs and gradients.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d_v)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool

    Returns one of three tuples depending on the packing flags:
        qkvpacked: (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn, dqkv_pad_fn)
        kvpacked:  (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                    max_seqlen_k, q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn)
        otherwise: the fully-unpacked 17-tuple (see the final return below).
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    d_v = v.shape[-1]
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
    if query_unused_mask is not None or key_unused_mask is not None:
        # "unused" slots are only supported in the fully-unpacked layout.
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        # unpad_input is defined earlier in this file; it returns the packed
        # rows, their flat indices, cumulative seqlens, max seqlen, and the
        # per-batch used-seqlen counts.
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q, query_padding_mask, query_unused_mask
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
        qv_unpad = (
            rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
        )
    else:
        # No padding: every sequence is full-length, so cu_seqlens is a
        # uniform arithmetic progression and unpadding is a pure reshape.
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        # v shares k's mask, so only the packed rows are needed here.
        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        # Packed QKV requires identical masks and head counts.
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(
                dqkv_unpad, indices_q, batch_size, seqlen_q
            )
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(
                dkv_unpad, indices_k, batch_size, seqlen_k
            )
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(
                dk_unpad, indices_k, batch_size, seqlen_k
            )
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(
                dk_unpad, "(b s) h d -> b s h d", b=batch_size
            )
        # Fully-unpacked layout: both padded and unpadded tensors are returned
        # as fresh leaves so the caller can run independent backward passes.
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            qv_unpad.detach() if qv is not None else None,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            qv.detach() if qv is not None else None,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
|
||||
|
||||
|
||||
def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(None, None),
    sink_token_length=0,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    """
    Build a boolean mask (True = masked OUT) for sliding-window attention,
    broadcastable to (batch, 1, seqlen_q, seqlen_k). Row/column positions are
    aligned to the *end* of each sequence via sk - sq, so padded sequences of
    different lengths line up.
    """
    row_idx = rearrange(
        torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1"
    )
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        # Shift columns so index 0 is the first real (non-left-padded) key;
        # left-padded positions get a huge index so they always mask out.
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    # Effective per-batch key/query lengths (scalar when no padding mask).
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] is None:
        # No left window: only mask columns beyond the right window edge.
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        # Mask if beyond the right edge, OR before the left edge — except that
        # the first sink_token_length columns are always kept visible.
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            torch.logical_and(
                col_idx < row_idx + sk - sq - window_size[0],
                col_idx >= sink_token_length,
            ),
        )
|
||||
|
||||
|
||||
def construct_chunk_mask(
    seqlen_q,
    seqlen_k,
    attention_chunk,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    """
    Build a boolean mask (True = masked OUT) for chunked attention: each query
    may only attend to keys inside its own chunk of size `attention_chunk`.
    Broadcastable to (batch, 1, seqlen_q, seqlen_k); positions are aligned to
    the end of each sequence via sk - sq, as in construct_local_mask.
    """
    row_idx = rearrange(
        torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1"
    )
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        # Shift columns past the left padding; padded positions get a huge
        # index so they always fall outside any chunk.
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    # Effective per-batch key/query lengths (scalar when no padding mask).
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
    # Subtract remainder instead of divide and then multiply to take care of negative values
    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk
    # Keep only columns in [col_limit_left_chunk, col_limit_left_chunk + attention_chunk).
    return torch.logical_or(
        col_idx < col_limit_left_chunk,
        col_idx >= col_limit_left_chunk + attention_chunk,
    )
|
||||
|
||||
|
||||
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(None, None),
    attention_chunk=0,
    sink_token_length=0,
    learnable_sink=None,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """
    Pure-PyTorch reference implementation of (varlen/masked/windowed) attention.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads, head_dim)
        v: (batch_size, seqlen_k, nheads, head_dim_v)
        qv: (batch_size, seqlen_q, nheads, head_dim_v)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim_v)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        # Causal attention is sliding-window attention with right window 0.
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
        qv = qv.float() if qv is not None else None
    if q_descale is not None:
        # Per-(batch, kv-head) descale, expanded over the GQA group.
        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
        q = (q.float() * q_descale).to(q.dtype)
        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
    if k_descale is not None:
        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
    if v_descale is not None:
        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # Expand KV heads over the GQA/MQA group so all einsums are per-head.
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    dv = v.shape[-1]
    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    else:
        # Mathematically identical; different rounding — used to estimate
        # the baseline numerical error of low-precision evaluation.
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if qv is not None:
        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
        )
    local_mask = None
    if window_size[0] is not None or window_size[1] is not None:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            sink_token_length,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
    if attention_chunk > 0:
        chunk_mask = construct_chunk_mask(
            seqlen_q,
            seqlen_k,
            attention_chunk,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
        local_mask = (
            torch.logical_or(local_mask, chunk_mask)
            if local_mask is not None
            else chunk_mask
        )
    if local_mask is not None:
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    if learnable_sink is None:
        attention = torch.softmax(scores, dim=-1).to(v.dtype)
    else:
        # Attention-sink softmax: the sink logit joins the normalization but
        # contributes no value, so rows sum to less than 1.
        scores_fp32 = scores.to(torch.float32)
        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
        learnable_sink = rearrange(learnable_sink, "h -> h 1 1")
        logits_or_sinks_max = torch.maximum(learnable_sink, logits_max)
        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
            learnable_sink - logits_or_sinks_max
        )
        attention = (unnormalized_scores / normalizer).to(v.dtype)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
        )
    # Without this we might get NaN in dv
    if key_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0
        )
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if local_mask is not None:
        attention = attention.masked_fill(
            torch.all(local_mask, dim=-1, keepdim=True), 0.0
        )
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    if intermediate_dtype is not None:
        # Round-trip through the low-precision dtype to mimic a kernel that
        # stores the attention matrix in that dtype.
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    is_hopper(),
    reason="skip on hopper",
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mqa"])
@pytest.mark.parametrize("has_learnable_sink", [False, True])
# @pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
@pytest.mark.parametrize("add_unused_qkv", [False])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
@pytest.mark.parametrize("d", [128, 192])
# @pytest.mark.parametrize("d", [192])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        # (1, 1),
        # (1, 3),
        # (2, 1),
        (511, 1),
        (3, 513),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (307, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
def test_flash_attn_varlen_output(
    seqlen_q,
    seqlen_k,
    d,
    add_unused_qkv,
    causal,
    local,
    softcap,
    deterministic,
    has_qv,
    has_learnable_sink,
    mha_type,
    dtype,
):
    # Compares flash_attn_varlen_func against the pure-PyTorch attention_ref
    # above: the kernel's error vs an fp32 reference must be within `rtol`
    # times the error of a same-precision PyTorch evaluation (out_pt).
    # NOTE(review): `deterministic` is currently unused in the body (the
    # backward check below is disabled); kept for parametrize compatibility.
    if (
        causal or local
    ):  # Right now we only support causal attention with seqlen_k == seqlen_q
        seqlen_k = seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
    batch_size = 49 if seqlen_q <= 1024 else 7
    nheads = 6
    # batch_size = 1
    # nheads = 1
    nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0]
    attention_chunk_vals = [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        # Round-trip through the kernel dtype so the reference sees exactly
        # the values the kernel will see.
        q_ref = torch.randn(
            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
        )
        if softcap > 0.0:
            # Ensure the values of qk are at least within softcap range.
            q_ref = (q_ref * softcap / 4).detach().requires_grad_()
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        v_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        if has_qv:
            qv_ref = (
                torch.randn(
                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
        else:
            qv_ref = None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (
            (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
        )
        if has_learnable_sink:
            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
        else:
            learnable_sink = None
        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [
                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
                * 2
                for _ in range(3)
            ]
        else:
            q_descale, k_descale, v_descale = None, None, None
        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach() if has_qv else None
        query_padding_mask = generate_random_padding_mask(
            seqlen_q, batch_size, device, mode="random", zero_lengths=False
        )
        # TODO: test zero_lengths
        key_padding_mask = generate_random_padding_mask(
            # seqlen_k, batch_size, device, mode="random", zero_lengths=True
            seqlen_k,
            batch_size,
            device,
            mode="random",
            zero_lengths=False,
        )

        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
            # Split a padding mask into an attention mask plus a disjoint
            # "unused" mask (positions present in exactly one of the two
            # random masks).
            if add_unused:
                another_mask = generate_random_padding_mask(max_seq_len, bs, device)
                attn_mask = torch.logical_and(padding_mask, another_mask)
                unused_mask = torch.logical_xor(
                    torch.logical_or(padding_mask, another_mask), attn_mask
                )
            else:
                attn_mask = padding_mask
                unused_mask = None
            return attn_mask, unused_mask

        query_padding_mask, query_unused_mask = _gen_unused_masks(
            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
        )
        # query_padding_mask[:] = True
        # query_unused_mask = None
        key_padding_mask, key_unused_mask = _gen_unused_masks(
            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
        )

        if causal or local:
            key_padding_mask = query_padding_mask

        (
            q_unpad,
            k_unpad,
            v_unpad,
            qv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            qv,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            qv=qv,
            kvpacked=False,
            query_unused_mask=query_unused_mask,
            key_unused_mask=key_unused_mask,
        )
        q_unpad, k_unpad, v_unpad = [
            x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
        ]
        # fp32 reference (upcast=True).
        out_ref, attn_ref = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            learnable_sink=learnable_sink,
            softcap=softcap,
        )
        # Same-precision baseline: how much error a plain PyTorch evaluation
        # incurs in the kernel's dtype.
        out_pt, attn_pt = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            learnable_sink=learnable_sink,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )

        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

        if query_unused_mask is not None:
            q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")

        # Numerical error if we just do any arithmetic on out_ref
        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
        rtol = 2 if softcap == 0.0 else 3

        pack_gqa_vals = [False, True, None]
        # num_splits_vals = [1, 3]
        num_splits_vals = [1]
        # NOTE(review): num_splits is iterated but not passed to the kernel.
        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
            out_unpad, lse = flash_attn_varlen_func(
                q_unpad,
                k_unpad,
                v_unpad,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=None,
                max_seqlen_k=None,
                # seqused_q=seqused_q,
                # seqused_k=seqused_k,
                causal=causal,
                # qv=qv_unpad,
                # q_descale=q_descale,
                # k_descale=k_descale, v_descale=v_descale,
                window_size=window_size,
                # attention_chunk=attention_chunk,
                sinks=learnable_sink,
                softcap=softcap,
                pack_gqa=pack_gqa,
                return_softmax_lse=True,
            )
            out = output_pad_fn(out_unpad)
            if query_unused_mask is not None:
                out.masked_fill_(q_zero_masking, 0.0)
            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
            # if not causal:
            #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
            # breakpoint()

            # Check that FlashAttention's numerical error is at most 3x the numerical error
            # of a Pytorch implementation.
            assert (out - out_ref).abs().max().item() <= rtol * (
                out_pt - out_ref
            ).abs().max().item() + fwd_atol

        # NOTE(review): the trailing `and False` deliberately disables this
        # backward-pass check; remove it to re-enable gradient testing.
        if (
            dtype != torch.float8_e4m3fn
            and not has_qv
            and not dv > 256
            and not attention_chunk != 0
            and dv == d
            and not has_learnable_sink
            and False
        ):
            g_unpad = torch.randn_like(out_unpad)
            do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
            # import flash_attn_3_cuda
            # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(
            #     g_unpad,
            #     q_unpad,
            #     k_unpad,
            #     v_unpad,
            #     out_unpad,
            #     lse,
            #     None,
            #     None,
            #     None,
            #     cu_seqlens_q,
            #     cu_seqlens_k,
            #     None, None,
            #     max_seqlen_q,
            #     max_seqlen_k,
            #     d ** (-0.5),
            #     causal,
            #     window_size[0], window_size[1],
            #     softcap,
            #     deterministic,
            #     0,  # sm_margin
            # )
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
                out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
            )
            dq = dq_pad_fn(dq_unpad)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            if key_unused_mask is not None:
                k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
                dk.masked_fill_(k_zero_masking, 0.0)
                dv.masked_fill_(k_zero_masking, 0.0)
            if query_unused_mask is not None:
                dq.masked_fill_(q_zero_masking, 0.0)
            # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
            # assert (softmax_d - do_o).abs().max().item() <= 1e-5
            # assert dq_accum.abs().max().item() == 0.0
            g = output_pad_fn(g_unpad)

            # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()
            # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
            # P = torch.softmax(qk, -1)
            # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))
            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())

            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(
                out_ref, (q_ref, k_ref, v_ref), g
            )
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
            # breakpoint()
            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dq - dq_ref).abs().max().item() <= rtol * (
                dq_pt - dq_ref
            ).abs().max().item() + dq_atol
            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dk - dk_ref).abs().max().item() <= rtol * (
                dk_pt - dk_ref
            ).abs().max().item() + dk_atol
            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dv - dv_ref).abs().max().item() <= rtol * (
                dv_pt - dv_ref
            ).abs().max().item() + dv_atol
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly without the pytest CLI.
    pytest.main([__file__])
|
||||
154
sgl-kernel/tests/test_fp4_gemm.py
Normal file
154
sgl-kernel/tests/test_fp4_gemm.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
# nvfp4 requires Blackwell-class GPUs (SM 10.0+).
# NOTE(review): evaluated at import time — assumes a CUDA device is present.
skip_condition = torch.cuda.get_device_capability() < (10, 0)

DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
# Shapes whose m or packed-k are not tile-aligned, to exercise padding.
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)

# Largest representable magnitudes of the fp4 (e2m1) and fp8 (e4m3) formats,
# used to derive per-tensor global scales.
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# Magnitude table for the 3-bit (exponent+mantissa) part of an e2m1 code.
kE2M1ToFloatArray = [
    0.0,
    0.5,
    1.0,
    1.5,
    2.0,
    3.0,
    4.0,
    6.0,
]
|
||||
|
||||
|
||||
def e2m1_to_fp32(int4_value):
    """Decode a single e2m1 (fp4) code in [0, 15] to its float value.

    Bit 3 carries the sign; bits 0-2 index the magnitude table.
    """
    magnitude = kE2M1ToFloatArray[int4_value & 0x7]
    return -magnitude if int4_value & 0x8 else magnitude
|
||||
|
||||
|
||||
def break_fp4_bytes(a, dtype):
    """Unpack a uint8 tensor of packed fp4 (e2m1) pairs into float32 values.

    Arguments:
        a: (m, n) uint8 tensor; each byte packs two e2m1 codes
           ([0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC], i.e. low nibble first).
        dtype: unused; kept for backward-compatible interface.

    Returns:
        (m, n * 2) float32 tensor on ``a.device``.
    """
    assert a.dtype == torch.uint8
    m, n = a.shape
    # Vectorized decode: the original per-element Python loop forced one
    # host sync per nibble on CUDA tensors and was O(m*n) in Python.
    flat = a.flatten().to(torch.long)
    # e2m1 magnitude for the 3 low bits; bit 3 is the sign.
    magnitudes = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
        dtype=torch.float32,
        device=a.device,
    )

    def _decode(nibble):
        sign = 1.0 - 2.0 * ((nibble >> 3) & 1).to(torch.float32)
        return sign * magnitudes[nibble & 0x7]

    high_half = _decode(flat >> 4)  # upper 4 bits of each byte
    low_half = _decode(flat & 0x0F)  # lower 4 bits of each byte
    # Low nibble comes first in the unpacked layout.
    return torch.stack((low_half, high_half), dim=-1).reshape(m, n * 2)
|
||||
|
||||
|
||||
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Undo the 128-row tile swizzle of nvfp4 scale factors.

    Takes the interleaved (swizzled) scale-factor tensor and returns the
    row-major layout, cropped to the first ``m`` rows.
    """
    sf_m, sf_k = a_sf_swizzled.shape  # unused, but enforces a 2-D input
    num_row_tiles = (m + 128 - 1) // 128
    cols_per_tile = block_size * 4
    num_col_tiles = (k + cols_per_tile - 1) // cols_per_tile
    # Each tile stores its scales as a (32, 4, 4) pattern; permute the tile
    # axes back before flattening to recover the linear ordering.
    tiled = a_sf_swizzled.reshape(1, num_row_tiles, num_col_tiles, 32, 4, 4)
    unswizzled = tiled.permute(0, 1, 4, 3, 2, 5)
    linear = unswizzled.reshape(
        num_row_tiles * 128, num_col_tiles * cols_per_tile // block_size
    )
    # NOTE(review): slicing columns with k (not k // block_size) exceeds the
    # actual column count, so it is effectively a no-op on that axis.
    return linear[0:m, 0:k]
|
||||
|
||||
|
||||
def dequantize_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize the fp4 tensor back to high precision.

    tensor_fp4: (m, k // 2) uint8, two e2m1 codes per byte.
    tensor_sf: swizzled per-block scale factors, viewed as float8_e4m3fn.
    global_scale: per-tensor scale divisor applied to the block scales.
    NOTE(review): `dtype` and `device` are threaded through but effectively
    unused (the result stays float32); kept for interface compatibility.
    """
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    # Group columns by quantization block so each block scale broadcasts.
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

    # scale the tensor
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out
|
||||
|
||||
|
||||
def get_ref_results(
    a_fp4,
    b_fp4,
    a_sf,
    b_sf,
    a_global_scale,
    b_global_scale,
    m,
    n,
    dtype,
    block_size,
    device,
):
    """Reference nvfp4 GEMM: dequantize both operands and matmul in float.

    Returns A @ B^T where A, B are the dequantized (m, k) and (n, k) inputs.
    NOTE(review): `m` and `n` are accepted but not used; shapes are taken
    from the tensors themselves.
    """
    _, m_k = a_fp4.shape
    _, n_k = b_fp4.shape
    # Both operands must share the same packed inner dimension.
    assert m_k == n_k
    a_in_dtype = dequantize_to_dtype(
        a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
    )
    b_in_dtype = dequantize_to_dtype(
        b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
    )
    return torch.matmul(a_in_dtype, b_in_dtype.t())
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@torch.inference_mode()
def test_nvfp4_gemm(
    dtype: torch.dtype,
    shape: tuple[int, int],
) -> None:
    """Check cutlass_scaled_fp4_mm against a dequantize-then-matmul reference."""
    m, n, packed_k = shape
    k = packed_k * 2
    block_size = 16
    a_dtype = torch.randn((m, k), dtype=dtype, device="cuda")
    b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")

    # Global scale maps the tensor's max magnitude into the representable
    # fp8-scale x fp4-value range.
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    # alpha undoes both global scales in the GEMM epilogue.
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)

    expected_out = get_ref_results(
        a_fp4,
        b_fp4,
        a_scale_interleaved,
        b_scale_interleaved,
        a_global_scale,
        b_global_scale,
        m,
        n,
        dtype,
        block_size,
        "cuda",
    )
    out = cutlass_scaled_fp4_mm(
        a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
    )

    # Loose tolerances: both paths quantize through fp4.
    torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
261
sgl-kernel/tests/test_fp4_quantize.py
Normal file
261
sgl-kernel/tests/test_fp4_quantize.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import (
|
||||
scaled_fp4_grouped_quant,
|
||||
scaled_fp4_quant,
|
||||
silu_and_mul,
|
||||
silu_and_mul_scaled_fp4_grouped_quant,
|
||||
)
|
||||
|
||||
skip_condition = torch.cuda.get_device_capability() < (10, 0)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
|
||||
PAD_SHAPES = [
|
||||
(90, 64),
|
||||
(150, 64),
|
||||
(128, 48),
|
||||
(128, 80),
|
||||
(150, 80),
|
||||
(90, 48),
|
||||
(90, 128),
|
||||
(150, 128),
|
||||
(150, 48),
|
||||
(90, 80),
|
||||
]
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
# E2M1 to float
|
||||
# 0111 -> 6
|
||||
# 0110 -> 4
|
||||
# 0101 -> 3
|
||||
# 0100 -> 2
|
||||
# 0011 -> 1.5
|
||||
# 0010 -> 1
|
||||
# 0001 -> 0.5
|
||||
# 0000 -> 0
|
||||
E2M1_TO_FLOAT32 = [
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
0.0,
|
||||
-0.5,
|
||||
-1.0,
|
||||
-1.5,
|
||||
-2.0,
|
||||
-3.0,
|
||||
-4.0,
|
||||
-6.0,
|
||||
]
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
def cast_from_fp4(x, m, n):
    """Unpack fp4-packed uint8 data into a float32 tensor of shape (m, n).

    Each byte holds two E2M1 nibbles packed as [v_1st | v_2nd]; the low
    nibble is emitted first, matching the kernels' packing order.
    """
    v_2nd = x & 0xF
    v_1st = (x >> 4) & 0xF
    pairs = torch.stack((v_2nd, v_1st), dim=-1)
    # Vectorized table lookup.  The previous per-element Python loop
    # (`[E2M1_TO_FLOAT32[x] for x in c.flatten()]`) forced a host round-trip
    # for every nibble, which is quadratically slow on CUDA tensors.
    lut = torch.tensor(E2M1_TO_FLOAT32, device=pairs.device)
    out = lut[pairs.long().flatten()]
    return out.reshape(m, n).to(torch.float32)
|
||||
|
||||
|
||||
def cast_to_fp4(x):
    """Round each element to the nearest representable E2M1 (fp4) magnitude.

    Boundary inclusivity alternates so that exact midpoints between two fp4
    values resolve the same way as the original mask chain (ties go to the
    grid value with the even mantissa, e.g. 0.25 -> 0, 0.75 -> 1.0).
    """
    signs = torch.sign(x)
    mag = torch.abs(x)
    for cond, level in (
        ((mag >= 0.0) & (mag <= 0.25), 0.0),
        ((mag > 0.25) & (mag < 0.75), 0.5),
        ((mag >= 0.75) & (mag <= 1.25), 1.0),
        ((mag > 1.25) & (mag < 1.75), 1.5),
        ((mag >= 1.75) & (mag <= 2.5), 2.0),
        ((mag > 2.5) & (mag < 3.5), 3.0),
        ((mag >= 3.5) & (mag <= 5.0), 4.0),
        (mag > 5.0, 6.0),
    ):
        mag[cond] = level
    return mag * signs
|
||||
|
||||
|
||||
def get_reciprocal(x):
    """Return 1/x, with the reciprocal of zero defined as zero."""
    if isinstance(x, torch.Tensor):
        zero = torch.tensor(0.0, dtype=x.dtype)
        return torch.where(x == 0, zero, 1.0 / x)
    if isinstance(x, (float, int)):
        return 1.0 / x if x != 0 else 0.0
    raise TypeError("Input must be a float, int, or a torch.Tensor.")
|
||||
|
||||
|
||||
def ref_nvfp4_quant(x, global_scale):
    """Reference nvfp4 quantization.

    Returns the fp4-rounded values (same shape as ``x``) and the per-block
    scales after they have been rounded through float8_e4m3fn.
    """
    assert global_scale.dtype == torch.float32
    assert x.ndim == 2
    m, n = x.shape
    blocks = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
    # Per-block scale derived from the block amax, then rounded through fp8
    # exactly like the kernel does.
    block_amax = torch.max(torch.abs(blocks), dim=-1, keepdim=True)[0].to(torch.float32)
    scale = global_scale * (block_amax * get_reciprocal(FLOAT4_E2M1_MAX))
    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
    inv_scale = get_reciprocal(scale * get_reciprocal(global_scale))

    scaled = blocks.to(torch.float32) * inv_scale
    clipped = torch.clamp(scaled, -6.0, 6.0).reshape(m, n)
    return cast_to_fp4(clipped), scale.squeeze(-1)
|
||||
|
||||
|
||||
def recover_swizzled_scales(scale, m, n):
    """Undo the 128x4 swizzle of scale factors; returns an (m, n // BLOCK_SIZE) slice."""
    padded_m = ((m + 127) // 128) * 128
    scale_n = n // BLOCK_SIZE
    padded_n = ((scale_n + 3) // 4) * 4
    # The swizzled layout is (1, m/128, n/4, 32, 4, 4); permuting the inner
    # tiles restores row-major order before reshaping.
    tiles = torch.reshape(scale, (1, padded_m // 128, padded_n // 4, 32, 4, 4))
    tiles = torch.permute(tiles, (0, 1, 4, 3, 2, 5))
    linear = torch.reshape(tiles, (padded_m, padded_n)).to(torch.float32)
    return linear[:m, :scale_n]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@torch.inference_mode()
def test_quantize_to_fp4(
    dtype: torch.dtype,
    shape: tuple[int, int],
) -> None:
    """scaled_fp4_quant must match the pure-python reference on aligned shapes."""
    torch.manual_seed(42)
    torch.set_default_device("cuda:0")

    m, n = shape

    x = torch.randn((m, n), dtype=dtype)
    tensor_amax = torch.abs(x).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)

    out, out_scale = scaled_fp4_quant(x, global_scale)
    # The kernel returns packed nibbles and swizzled scales; convert both back
    # to linear float form before comparing against the reference.
    scale_ans = recover_swizzled_scales(out_scale, m, n)
    out_ans = cast_from_fp4(out, m, n)

    torch.testing.assert_close(out_ans, out_ref)
    torch.testing.assert_close(scale_ans, scale_ref)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
@torch.inference_mode()
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
    """Same check as test_quantize_to_fp4, on shapes that require scale padding."""
    torch.manual_seed(42)
    dtype = torch.float16
    torch.set_default_device("cuda:0")

    m, n = pad_shape

    x = torch.randn((m, n), dtype=dtype)

    tensor_amax = torch.abs(x).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)

    out, out_scale = scaled_fp4_quant(x, global_scale)

    # recover_swizzled_scales drops the padded rows/cols, so the comparison
    # below only covers real data.
    scale_ans = recover_swizzled_scales(out_scale, m, n)
    out_ans = cast_from_fp4(out, m, n)

    torch.testing.assert_close(out_ans, out_ref)
    torch.testing.assert_close(scale_ans, scale_ref)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("shape", [(2, 512, 2048), (2, 100, 128), (2, 128, 96)])
def test_quantize_to_fp4_grouped(shape):
    """Grouped quantization must agree with per-expert scaled_fp4_quant."""
    torch.manual_seed(42)
    torch.set_default_device("cuda:0")

    l, m, k = shape
    x = torch.randn((l, m, k), dtype=torch.bfloat16)
    max_m = m // 2
    assert max_m <= m
    # Per-expert valid row counts; rows at index >= mask[i] are padding.
    mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
    tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
    x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    output, output_scales = scaled_fp4_grouped_quant(
        x,
        x_sf_global,
        mask,
    )
    # output in logical (m, k, l), but its physical layout is (l, m, k).
    # So permute first to (l, m, k).
    output = output.permute(2, 0, 1)
    # output_scale in logical (32, 4, rm, 4, rk, l), but its physical layout is (l, rm, rk, 32, 4, 4).
    # So permute first to (l, rm, rk, 32, 4, 4).
    padded_m = ((m + 128 - 1) // 128) * 128
    output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
    for i in range(l):
        # Per-expert reference using the single-tensor quantizer.
        a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i])
        torch.testing.assert_close(a_fp4[: mask[i]], output[i][: mask[i]])
        # Recover swizzled scales to linear layout and drop padded values, so
        # no extra checks on padding are needed.
        scale_ref = recover_swizzled_scales(a_scale_interleaved, m, k)
        scale_ans = recover_swizzled_scales(output_scales[i], m, k)
        torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048), (6, 6144, 2048)])
def test_silu_and_mul_quantize_to_fp4_grouped(shape):
    """Fused silu_and_mul + grouped fp4 quant must match the two-step pipeline."""
    torch.manual_seed(42)
    torch.set_default_device("cuda:0")

    l, m, k = shape
    # Input carries 2*k features: silu_and_mul halves the last dimension.
    x = torch.randn((l, m, k * 2), dtype=torch.bfloat16)
    max_m = m // 2
    assert max_m <= m
    # Per-expert valid row counts; rows at index >= mask[i] are padding.
    mask = torch.randint(1, max_m, (l,), dtype=torch.int32)

    # Reference: activation first, then grouped quantization.
    ref_y = silu_and_mul(x)
    tensor_amax = ref_y.abs().amax(dim=(1, 2)).to(torch.float32)
    y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    ref_output, ref_output_scales = scaled_fp4_grouped_quant(
        ref_y,
        y_sf_global,
        mask,
    )
    output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(
        x,
        y_sf_global,
        mask,
    )

    # output in logical (m, k, l), but its physical layout is (l, m, k).
    # So permute first to (l, m, k).
    output = output.permute(2, 0, 1)
    ref_output = ref_output.permute(2, 0, 1)

    # output_scale in logical (32, 4, rm, 4, rk, l), but its physical layout is (l, rm, rk, 32, 4, 4).
    # So permute first to (l, rm, rk, 32, 4, 4).
    padded_m = ((m + 128 - 1) // 128) * 128
    output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
    ref_output_scales = ref_output_scales.permute(5, 2, 4, 0, 1, 3).view(
        l, padded_m, -1
    )

    for i in range(l):
        torch.testing.assert_close(ref_output[i, : mask[i]], output[i, : mask[i]])
        # We need to recover the swizzled scales to linear layout before applying mask slice.
        scale_ref = recover_swizzled_scales(ref_output_scales[i], m, k)
        scale_ans = recover_swizzled_scales(output_scales[i], m, k)
        torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
93
sgl-kernel/tests/test_fp8_blockwise_gemm.py
Normal file
93
sgl-kernel/tests/test_fp8_blockwise_gemm.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import os
|
||||
import random
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
    """Ceiling division, implemented via floor division of the negation."""
    return -(-a // b)
|
||||
|
||||
|
||||
def scale_shape(shape, group_shape):
    """Shape of the scale tensor for `shape` grouped by `group_shape` (ceil per dim)."""
    assert len(shape) == len(group_shape)
    # cdiv inlined as negated floor division: ceil(dim / grp).
    return tuple(-(dim // -grp) for dim, grp in zip(shape, group_shape))
|
||||
|
||||
|
||||
def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: Type[torch.dtype],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference scaled matmul: (scale_a * a) @ (scale_b * b) [+ bias], cast to out_dtype.

    Group scales are expanded with extended numpy-style broadcasting: a
    dimension of extent d (d != 1) is stretched to extent s by repeating
    each element s // d times along that dimension; extent-1 dimensions are
    left to ordinary implicit broadcasting.
    """

    def group_broadcast(t, shape):
        # Repeat each element along every dimension whose extent divides the
        # target extent; extent-1 dims are handled implicitly by pytorch.
        for axis, target in enumerate(shape):
            extent = t.shape[axis]
            if extent not in (1, target):
                assert target % extent == 0
                t = (
                    t.unsqueeze(axis + 1)
                    .expand(
                        *t.shape[: axis + 1], target // extent, *t.shape[axis + 1 :]
                    )
                    .flatten(axis, axis + 1)
                )
        return t

    full_scale_a = group_broadcast(scale_a, a.shape)
    full_scale_b = group_broadcast(scale_b, b.shape)
    result = torch.mm(
        full_scale_a * a.to(dtype=torch.float32),
        full_scale_b * b.to(dtype=torch.float32),
    ).to(out_dtype)
    if bias is None:
        return result
    return result + bias
|
||||
|
||||
|
||||
def _test_accuracy_once(M, N, K, out_dtype, device):
    """Compare fp8_blockwise_scaled_mm against the python reference for one shape."""
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min
    # Random fp32 data spanning the full fp8 range, then clamped into fp8.
    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
    # Per-token (1x128) scales for A, per-block (128x128) scales for B.
    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
    scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
    scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
    # .t().contiguous().t() makes the scales column-major in memory.
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()
    o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
    o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
    # Loose tolerances: fp8 rounding dominates.
    rtol = 0.02
    atol = 1
    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 14080])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 14080, 16384])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, out_dtype):
    """Accuracy sweep for fp8_blockwise_scaled_mm across shapes and output dtypes."""
    _test_accuracy_once(M, N, K, out_dtype, "cuda")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
221
sgl-kernel/tests/test_fp8_blockwise_moe.py
Executable file
221
sgl-kernel/tests/test_fp8_blockwise_moe.py
Executable file
@@ -0,0 +1,221 @@
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
    """Ceiling division using negated floor division."""
    return -(-a // b)
|
||||
|
||||
|
||||
def scale_shape(shape, group_shape):
    """Per-dimension ceil(shape / group_shape): the shape of the scale tensor.

    The length check mirrors the identical helper in
    test_fp8_blockwise_gemm.py and catches mismatched dimensionality early
    instead of silently truncating via zip-like indexing.
    """
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
|
||||
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
def calc_diff(x, y):
    """Return 1 minus a cosine-like similarity of x and y (0 means identical)."""
    xd, yd = x.double(), y.double()
    similarity = 2 * (xd * yd).sum() / (xd * xd + yd * yd).sum()
    return 1 - similarity
|
||||
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
    """Integer ceiling division."""
    quotient, remainder = divmod(x, y)
    return quotient + (1 if remainder else 0)
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
pad_size = (128 - (n % 128)) % 128
|
||||
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
|
||||
)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
|
||||
x_view.size(0), x_view.size(2)
|
||||
)
|
||||
|
||||
|
||||
def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: type[torch.dtype],
) -> torch.Tensor:
    """Reference scaled matmul: (scale_a * a) @ (scale_b * b), cast to out_dtype."""

    def group_broadcast(t, shape):
        # Stretch each dim whose extent divides the target by repeating
        # elements; extent-1 dims are left to implicit broadcasting.
        for axis, target in enumerate(shape):
            if t.shape[axis] not in (1, target):
                assert target % t.shape[axis] == 0
                reps = target // t.shape[axis]
                t = (
                    t.unsqueeze(axis + 1)
                    .expand(*t.shape[: axis + 1], reps, *t.shape[axis + 1 :])
                    .flatten(axis, axis + 1)
                )
        return t

    sa = group_broadcast(scale_a, a.shape)
    sb = group_broadcast(scale_b, b.shape)

    return torch.mm(sa * a.to(dtype=torch.float32), sb * b.to(dtype=torch.float32)).to(
        out_dtype
    )
|
||||
|
||||
|
||||
def is_sm100_supported(device=None) -> bool:
    """Return True if `device` is SM100 (compute capability 10.x) with CUDA >= 12.8.

    The CUDA version is compared numerically: the previous string comparison
    (`torch.version.cuda >= "12.8"`) mis-orders versions such as "12.10",
    and raises TypeError on CPU-only builds where torch.version.cuda is None.
    """
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, minor = (int(v) for v in torch.version.cuda.split(".")[:2])
    return torch.cuda.get_device_capability(device)[0] == 10 and (major, minor) >= (
        12,
        8,
    )
|
||||
|
||||
|
||||
def is_sm90_supported(device=None) -> bool:
    """Return True if `device` is SM90 (compute capability 9.x) with CUDA >= 12.3.

    Numeric version comparison for the same reasons as is_sm100_supported:
    string comparison mis-orders "12.10" vs "12.3" and crashes when
    torch.version.cuda is None on CPU-only builds.
    """
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, minor = (int(v) for v in torch.version.cuda.split(".")[:2])
    return torch.cuda.get_device_capability(device)[0] == 9 and (major, minor) >= (
        12,
        3,
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not (is_sm100_supported() or is_sm90_supported()),
    reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
)
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128])
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
    """Grouped fp8 GEMM over random per-expert problem sizes vs. a dense baseline."""
    device = "cuda"
    alignment = 128  # NOTE(review): unused below; presumably documents the 128 tiling — confirm.
    # N and K are random multiples of 128; M varies per expert.
    n_g = random.randint(1, 64) * 128
    k_g = random.randint(1, 64) * 128

    expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
    problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
    layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
    layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)

    a_tensors = []
    b_tensors = []
    a_scales_tensors = []
    b_scales_tensors = []
    baseline_tensors = []

    # Build one quantized (A, B) pair and a high-precision baseline per expert.
    for g in range(num_experts):
        m_g = random.randint(1, 256)
        expert_offsets[g + 1] = expert_offsets[g] + m_g
        problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)

        a = torch.randn((m_g, k_g), device=device, dtype=out_dtype)  # (M, K):(K, 1)
        b = torch.randn((n_g, k_g), device=device, dtype=out_dtype).t()  # (K, N):(1, K)

        a_g, a_scale = per_token_cast_to_fp8(
            a
        )  # ag -- (M, K):(K, 1), a_scale() -- (M, k):(k, 1)
        b_g, b_scale = per_block_cast_to_fp8(
            b
        )  # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
        a_tensors.append(a_g)
        b_tensors.append(b_g)
        a_scales_tensors.append(a_scale)
        b_scales_tensors.append(b_scale)

        baseline = torch.mm(a, b)
        baseline_tensors.append(baseline)
    # Pack all experts into contiguous stacked buffers for the kernel.
    a_stack = torch.empty(
        (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
    )
    b_stack = torch.empty(
        (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
    )
    a_scale_stack = torch.empty(
        (expert_offsets[-1], (k_g // 128)), device=device, dtype=torch.float32
    )
    b_scale_stack = torch.empty(
        (num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32
    )

    for g in range(num_experts):
        # Matrix A is Row-Major.
        a_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_tensors[
            g
        ]  # a_stack[expert_offsets[g] : expert_offsets[g + 1], :] -- (M, K):(K, 1)
        b_stack[g] = b_tensors[g].t()  # b_stack[g] -- (N, K):(K, 1)

        # We need K-Major scale factor
        a_scale_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_scales_tensors[
            g
        ]
        b_scale_stack[g] = b_scales_tensors[
            g
        ].t()  # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
    b_stack = b_stack.transpose(1, 2)  # Transpose Matrix B to Column-Major.
    b_scale_stack = b_scale_stack.transpose(1, 2)

    c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
    a_strides = torch.full(
        (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
    )
    c_strides = torch.full(
        (num_experts,), c_out.stride(0), device=device, dtype=torch.int64
    )
    # 1 GiB scratch workspace plus per-expert pointer arrays filled by the kernel.
    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
    a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)

    fp8_blockwise_scaled_grouped_mm(
        c_out,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a_stack,
        b_stack,
        a_scale_stack,
        b_scale_stack,
        a_strides,
        a_strides,
        c_strides,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        expert_offsets[:-1],
        workspace,
    )

    # Compare each expert's slice of the packed output against its baseline.
    for g in range(num_experts):
        baseline = baseline_tensors[g]
        actual = c_out[expert_offsets[g] : expert_offsets[g + 1]]
        diff = calc_diff(actual, baseline)
        assert diff < 0.001
        print(
            f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
49
sgl-kernel/tests/test_fp8_gemm.py
Normal file
49
sgl-kernel/tests/test_fp8_gemm.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import fp8_scaled_mm
|
||||
|
||||
|
||||
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    """Reference scaled matmul in float32: row scales for A, column scales for B.

    The bias, if given, is added after the cast to out_dtype.
    """
    acc = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float32)
    acc = acc * scale_a.view(-1, 1)
    acc = acc * scale_b.view(1, -1)
    out = acc.to(out_dtype)
    if bias is not None:
        out = out + bias.view(1, -1)
    return out
|
||||
|
||||
|
||||
def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
    """Compare fp8_scaled_mm to the python reference for one (M, N, K) config."""
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min
    # Random fp32 data spanning the full fp8 range, clamped into fp8.
    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    # Per-row scale for A, per-column scale for B (after the transpose below).
    scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
    scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
    if with_bias:
        bias = torch.randn((N,), device=device, dtype=out_dtype)
    else:
        bias = None
    # B becomes a (K, N) transposed view, the layout the kernel expects.
    b_fp8 = b_fp8.t()
    o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
    o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
    # Loose tolerances: fp8 rounding dominates.
    rtol = 0.02
    atol = 1
    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("with_bias", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, with_bias, out_dtype):
    """Accuracy sweep for fp8_scaled_mm across shapes, bias, and output dtypes."""
    _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
131
sgl-kernel/tests/test_gptq_kernel.py
Normal file
131
sgl-kernel/tests/test_gptq_kernel.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import gptq_gemm
|
||||
|
||||
from sglang.srt.layers.quantization.utils import pack_cols, pack_rows
|
||||
|
||||
|
||||
def torch_dequantize(q_weight, q_zeros, scales, g_idx, use_shuffle, bit, K, N):
    """Reference GPTQ dequantization (4-bit, sequential groups only).

    g_idx and use_shuffle are accepted for signature parity with the kernel
    but are not used: the group index is recomputed as i // group_size,
    which only matches the non-shuffled (sequential) layout.
    """
    assert bit == 4, "Reference dequantization only supports 4-bit"
    group_size = K // scales.shape[0]
    pack_factor = 32 // bit

    # unpack q_weight: (K//pack_factor, N) -> (K, N)
    unpacked_q_weight = torch.empty(
        q_weight.shape[0] * pack_factor,
        q_weight.shape[1],
        dtype=torch.uint8,
        device=q_weight.device,
    )
    for i in range(pack_factor):
        unpacked_q_weight[i::pack_factor, :] = (q_weight >> (i * 4)) & 0x0F

    # unpack q_zeros: (num_groups, N//pack_factor) -> (num_groups, N)
    unpacked_q_zeros = torch.empty(
        q_zeros.shape[0],
        q_zeros.shape[1] * pack_factor,
        dtype=torch.uint8,
        device=q_zeros.device,
    )
    for i in range(pack_factor):
        unpacked_q_zeros[:, i::pack_factor] = (q_zeros >> (i * 4)) & 0x0F

    # GPTQ stores zero-points minus one; undo that offset here.
    unpacked_q_zeros += 1
    unpacked_q_zeros = unpacked_q_zeros.to(scales.dtype)

    scale_zeros = unpacked_q_zeros * scales  # (num_groups, N)

    # Sequential group index: row i belongs to group i // group_size.
    current_g_idx = torch.tensor(
        [i // group_size for i in range(K)], dtype=torch.int32, device=q_weight.device
    )

    scale_mat = scales[current_g_idx]  # (K, N)
    scale_zeros_mat = scale_zeros[current_g_idx]  # (K, N)

    # dequant: weight * scale - scale_zeros
    dequantized_b = unpacked_q_weight.to(scales.dtype) * scale_mat - scale_zeros_mat

    return dequantized_b.reshape(K, N)
|
||||
|
||||
|
||||
def torch_gptq_gemm(
    a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
):
    """Reference GPTQ GEMM: dequantize B to full precision, then compute a @ B."""
    K = a.shape[1]
    N = b_q_weight.shape[1]

    b_dequant = torch_dequantize(
        b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit, K, N
    )
    return torch.matmul(a, b_dequant)
|
||||
|
||||
|
||||
def _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, device="cuda"):
    """Quantize a random weight with sequential groups and compare both GEMMs."""

    b_fp = torch.randn(K, N, dtype=dtype, device=device)

    assert K % group_size == 0, "K must be divisible by group_size"
    num_groups = K // group_size

    if use_shuffle:
        # Shuffled (act-order) path is not covered by this reference test.
        return
    else:
        # Sequential group index: row i belongs to group i // group_size, so
        # b_shuffled == b_fp here (identity permutation).
        g_idx = torch.tensor(
            [i // group_size for i in range(K)], dtype=torch.int32, device=device
        )
        b_shuffled = b_fp[g_idx]

        b_grouped = b_shuffled.reshape(num_groups, group_size, N)

        # Asymmetric min/max quantization per (group, column).
        b_max = torch.max(b_grouped, dim=1, keepdim=True)[0]
        b_min = torch.min(b_grouped, dim=1, keepdim=True)[0]

        scales = (b_max - b_min) / (2**bit - 1)
        scales = scales.clamp(min=1e-6)

        zeros_float = (-b_min / scales).round()

        q_b = (
            (b_grouped / scales + zeros_float).round().clamp(0, 2**bit - 1).to(torch.uint8)
        )

        # GPTQ convention stores zero-points minus one.
        q_zeros_unpacked = zeros_float.to(torch.uint8) - 1

        b_q_weight = pack_rows(q_b.reshape(K, N), bit, K, N)

        q_zeros_unpacked = q_zeros_unpacked.reshape(num_groups, N)
        b_gptq_qzeros = pack_cols(q_zeros_unpacked, bit, num_groups, N)
        b_gptq_scales = scales.squeeze(1)

        a = torch.randn(M, K, dtype=dtype, device=device)

        c_ref = torch_gptq_gemm(
            a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
        )
        c_out = gptq_gemm(
            a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
        )

        # Loose tolerances: 4-bit quantization error dominates.
        rtol = 4e-2
        atol = 4e-2
        torch.testing.assert_close(c_ref, c_out, rtol=rtol, atol=atol)
        print(
            f"✅ Test passed: M={M}, N={N}, K={K}, bit={bit}, group_size={group_size}, use_shuffle={use_shuffle}, dtype={dtype}"
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [1, 8, 128])
@pytest.mark.parametrize("N", [2048, 4096])
@pytest.mark.parametrize("K", [2048, 4096])
@pytest.mark.parametrize("bit", [4])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("use_shuffle", [False])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gptq_gemm(M, N, K, bit, group_size, use_shuffle, dtype):
    """Sweep GPTQ GEMM shapes; skipped entirely when CUDA is unavailable."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, "cuda")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
48
sgl-kernel/tests/test_int8_gemm.py
Normal file
48
sgl-kernel/tests/test_int8_gemm.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import int8_scaled_mm
|
||||
from utils import is_sm10x
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Clamp into [-128, 127], round to nearest, and cast to int8."""
    return tensor.clamp(min=-128, max=127).round().to(dtype=torch.int8)
|
||||
|
||||
|
||||
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    """Reference int8 scaled matmul; bias (if any) is added before the final cast."""
    acc = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    scaled = acc.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1)
    if bias is not None:
        scaled = scaled + bias
    return scaled.to(out_dtype)
|
||||
|
||||
|
||||
def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
    """Compare int8_scaled_mm with the python reference for one shape."""
    a = to_int8(torch.randn((M, K), device=device) * 5)
    # B is a (K, N) transposed (non-contiguous) view of (N, K) data.
    b = to_int8(torch.randn((N, K), device=device).t() * 5)
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    if with_bias:
        bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
    else:
        bias = None
    o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    torch.testing.assert_close(o, o1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    is_sm10x(),
    reason="int8_scaled_mm is only supported on sm90 and lower",
)
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("with_bias", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
def test_accuracy(M, N, K, with_bias, out_dtype):
    """Accuracy sweep for int8_scaled_mm across shapes, bias, and output dtypes."""
    _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
485
sgl-kernel/tests/test_kvcacheio.py
Normal file
485
sgl-kernel/tests/test_kvcacheio.py
Normal file
@@ -0,0 +1,485 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel.kvcacheio import (
|
||||
transfer_kv_all_layer,
|
||||
transfer_kv_all_layer_direct_lf_pf,
|
||||
transfer_kv_all_layer_mla,
|
||||
transfer_kv_direct,
|
||||
transfer_kv_per_layer,
|
||||
transfer_kv_per_layer_direct_pf_lf,
|
||||
transfer_kv_per_layer_mla,
|
||||
)
|
||||
|
||||
|
||||
def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
|
||||
dst_pool[dst_indices] = src_pool[src_indices].to(dst_pool.device)
|
||||
|
||||
|
||||
def ref_copy_with_indices_pf_direct(
|
||||
src_pool, dst_pool, src_indices, dst_indices, page_size, layer_id, lf_to_pf=False
|
||||
):
|
||||
if lf_to_pf:
|
||||
for i in range(0, len(src_indices), page_size):
|
||||
dst_pool[dst_indices[i] // page_size][layer_id] = src_pool[layer_id][
|
||||
src_indices[i : i + page_size]
|
||||
].to(dst_pool.device)
|
||||
else:
|
||||
for i in range(0, len(src_indices), page_size):
|
||||
dst_pool[layer_id][dst_indices[i : i + page_size]] = src_pool[
|
||||
src_indices[i] // page_size
|
||||
][layer_id].to(dst_pool.device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
|
||||
@pytest.mark.parametrize("page_size", [1, 16, 64])
|
||||
@pytest.mark.parametrize("item_size", [256])
|
||||
@pytest.mark.parametrize("total_items_in_pool", [10240])
|
||||
@pytest.mark.parametrize("is_mla", [False, True])
|
||||
@pytest.mark.parametrize("all_layers", [False, True])
|
||||
def test_transfer_kv(
|
||||
dtype: torch.dtype,
|
||||
num_items_to_transfer: int,
|
||||
item_size: int,
|
||||
page_size: int,
|
||||
total_items_in_pool: int,
|
||||
is_mla: bool,
|
||||
all_layers: bool,
|
||||
):
|
||||
"""
|
||||
Tests the per-layer transfer functions, treating tensors as memory pools.
|
||||
"""
|
||||
|
||||
original_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
device = "cuda"
|
||||
torch.cuda.manual_seed(42)
|
||||
|
||||
num_layers = 4 # A small number of layers for pool creation
|
||||
|
||||
total_pages_in_pool = total_items_in_pool // page_size
|
||||
num_pages_to_transfer = num_items_to_transfer // page_size
|
||||
if num_pages_to_transfer == 0:
|
||||
torch.set_default_dtype(original_dtype)
|
||||
return
|
||||
page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
|
||||
src_indices_host = torch.cat(
|
||||
[
|
||||
torch.arange(p * page_size, (p + 1) * page_size)
|
||||
for p in page_indices[:num_pages_to_transfer]
|
||||
]
|
||||
)
|
||||
src_indices_device = src_indices_host.to(device)
|
||||
dst_indices_host = torch.cat(
|
||||
[
|
||||
torch.arange(p * page_size, (p + 1) * page_size)
|
||||
for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
|
||||
]
|
||||
)
|
||||
dst_indices_device = dst_indices_host.to(device)
|
||||
|
||||
# Prepare memory pools based on whether it's an MLA case.
|
||||
if is_mla:
|
||||
src_pool_host = torch.randn(
|
||||
num_layers, total_items_in_pool, item_size
|
||||
).pin_memory()
|
||||
dst_pool_ref = torch.zeros_like(src_pool_host).to(device)
|
||||
dst_pool_kernel = torch.zeros_like(dst_pool_ref)
|
||||
dst_pool_direct = torch.zeros_like(dst_pool_ref)
|
||||
else:
|
||||
src_k_pool = torch.randn(
|
||||
num_layers, total_items_in_pool, item_size
|
||||
).pin_memory()
|
||||
src_v_pool = torch.randn(
|
||||
num_layers, total_items_in_pool, item_size
|
||||
).pin_memory()
|
||||
dst_k_pool_ref = torch.zeros_like(src_k_pool).to(device)
|
||||
dst_v_pool_ref = torch.zeros_like(src_v_pool).to(device)
|
||||
dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
|
||||
dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# We will test the per-layer function on the first layer (index 0) of the pool.
|
||||
layer_idx_to_test = 0
|
||||
|
||||
if is_mla:
|
||||
if not all_layers:
|
||||
ref_copy_with_indices(
|
||||
src_pool_host[layer_idx_to_test],
|
||||
dst_pool_ref[layer_idx_to_test],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
transfer_kv_per_layer_mla(
|
||||
src_pool_host[layer_idx_to_test],
|
||||
dst_pool_kernel[layer_idx_to_test],
|
||||
src_indices_device,
|
||||
dst_indices_device,
|
||||
item_size=item_size * dtype.itemsize,
|
||||
)
|
||||
transfer_kv_direct(
|
||||
[src_pool_host[layer_idx_to_test]],
|
||||
[dst_pool_direct[layer_idx_to_test]],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
page_size=page_size,
|
||||
)
|
||||
else:
|
||||
for layer_id in range(num_layers):
|
||||
ref_copy_with_indices(
|
||||
src_pool_host[layer_id],
|
||||
dst_pool_ref[layer_id],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
src_layers_device = torch.tensor(
|
||||
[src_pool_host[layer_id].data_ptr() for layer_id in range(num_layers)],
|
||||
dtype=torch.uint64,
|
||||
device=device,
|
||||
)
|
||||
dst_layers_device = torch.tensor(
|
||||
[
|
||||
dst_pool_kernel[layer_id].data_ptr()
|
||||
for layer_id in range(num_layers)
|
||||
],
|
||||
dtype=torch.uint64,
|
||||
device=device,
|
||||
)
|
||||
transfer_kv_all_layer_mla(
|
||||
src_layers_device,
|
||||
dst_layers_device,
|
||||
src_indices_device,
|
||||
dst_indices_device,
|
||||
item_size=item_size * dtype.itemsize,
|
||||
num_layers=num_layers,
|
||||
)
|
||||
transfer_kv_direct(
|
||||
[src_pool_host[layer_id] for layer_id in range(num_layers)],
|
||||
[dst_pool_direct[layer_id] for layer_id in range(num_layers)],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
page_size=page_size,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
|
||||
torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
|
||||
else:
|
||||
if not all_layers:
|
||||
ref_copy_with_indices(
|
||||
src_k_pool[layer_idx_to_test],
|
||||
dst_k_pool_ref[layer_idx_to_test],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
ref_copy_with_indices(
|
||||
src_v_pool[layer_idx_to_test],
|
||||
dst_v_pool_ref[layer_idx_to_test],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
transfer_kv_per_layer(
|
||||
src_k_pool[layer_idx_to_test],
|
||||
dst_k_pool_kernel[layer_idx_to_test],
|
||||
src_v_pool[layer_idx_to_test],
|
||||
dst_v_pool_kernel[layer_idx_to_test],
|
||||
src_indices_device,
|
||||
dst_indices_device,
|
||||
item_size=item_size * dtype.itemsize,
|
||||
)
|
||||
transfer_kv_direct(
|
||||
[src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
|
||||
[
|
||||
dst_k_pool_direct[layer_idx_to_test],
|
||||
dst_v_pool_direct[layer_idx_to_test],
|
||||
],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
page_size=page_size,
|
||||
)
|
||||
else:
|
||||
for layer_id in range(num_layers):
|
||||
ref_copy_with_indices(
|
||||
src_k_pool[layer_id],
|
||||
dst_k_pool_ref[layer_id],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
ref_copy_with_indices(
|
||||
src_v_pool[layer_id],
|
||||
dst_v_pool_ref[layer_id],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
)
|
||||
|
||||
src_k_layers_device = torch.tensor(
|
||||
[src_k_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
|
||||
dtype=torch.uint64,
|
||||
device=device,
|
||||
)
|
||||
src_v_layers_device = torch.tensor(
|
||||
[src_v_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
|
||||
dtype=torch.uint64,
|
||||
device=device,
|
||||
)
|
||||
dst_k_layers_device = torch.tensor(
|
||||
[
|
||||
dst_k_pool_kernel[layer_id].data_ptr()
|
||||
for layer_id in range(num_layers)
|
||||
],
|
||||
dtype=torch.uint64,
|
||||
device=device,
|
||||
)
|
||||
dst_v_layers_device = torch.tensor(
|
||||
[
|
||||
dst_v_pool_kernel[layer_id].data_ptr()
|
||||
for layer_id in range(num_layers)
|
||||
],
|
||||
dtype=torch.uint64,
|
||||
device=device,
|
||||
)
|
||||
transfer_kv_all_layer(
|
||||
src_k_layers_device,
|
||||
dst_k_layers_device,
|
||||
src_v_layers_device,
|
||||
dst_v_layers_device,
|
||||
src_indices_device,
|
||||
dst_indices_device,
|
||||
item_size=item_size * dtype.itemsize,
|
||||
num_layers=num_layers,
|
||||
)
|
||||
transfer_kv_direct(
|
||||
[src_k_pool[layer_id] for layer_id in range(num_layers)]
|
||||
+ [src_v_pool[layer_id] for layer_id in range(num_layers)],
|
||||
[dst_k_pool_direct[layer_id] for layer_id in range(num_layers)]
|
||||
+ [dst_v_pool_direct[layer_id] for layer_id in range(num_layers)],
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
page_size=page_size,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
|
||||
torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
|
||||
torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
|
||||
torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
|
||||
|
||||
torch.set_default_dtype(original_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("num_items_to_transfer", [128, 1024, 8192])
|
||||
@pytest.mark.parametrize("page_size", [16, 64, 128])
|
||||
@pytest.mark.parametrize("item_size", [256])
|
||||
@pytest.mark.parametrize("total_items_in_pool", [20480])
|
||||
@pytest.mark.parametrize("is_mla", [False, True])
|
||||
@pytest.mark.parametrize("lf_to_pf", [False, True])
|
||||
def test_transfer_kv_pf_direct(
|
||||
dtype: torch.dtype,
|
||||
num_items_to_transfer: int,
|
||||
item_size: int,
|
||||
page_size: int,
|
||||
total_items_in_pool: int,
|
||||
is_mla: bool,
|
||||
lf_to_pf: bool,
|
||||
):
|
||||
original_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
device = "cuda"
|
||||
torch.cuda.manual_seed(42)
|
||||
|
||||
num_layers = 4
|
||||
|
||||
total_pages_in_pool = total_items_in_pool // page_size
|
||||
num_pages_to_transfer = num_items_to_transfer // page_size
|
||||
if num_pages_to_transfer == 0:
|
||||
torch.set_default_dtype(original_dtype)
|
||||
return
|
||||
page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
|
||||
src_indices_host = torch.cat(
|
||||
[
|
||||
torch.arange(p * page_size, (p + 1) * page_size)
|
||||
for p in page_indices[:num_pages_to_transfer]
|
||||
]
|
||||
)
|
||||
src_indices_device = src_indices_host.to(device)
|
||||
dst_indices_host = torch.cat(
|
||||
[
|
||||
torch.arange(p * page_size, (p + 1) * page_size)
|
||||
for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
|
||||
]
|
||||
)
|
||||
dst_indices_device = dst_indices_host.to(device)
|
||||
|
||||
# We will test the per-layer function on the first layer (index 0) of the pool.
|
||||
layer_idx_to_test = 0
|
||||
|
||||
if lf_to_pf:
|
||||
if is_mla:
|
||||
src_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
|
||||
device
|
||||
)
|
||||
src_pool_ptrs = [src_pool[i] for i in range(num_layers)]
|
||||
dst_pool_ref = torch.zeros(
|
||||
total_pages_in_pool, num_layers, page_size, item_size
|
||||
).pin_memory()
|
||||
dst_pool_direct = torch.zeros_like(dst_pool_ref)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
transfer_kv_all_layer_direct_lf_pf(
|
||||
src_pool_ptrs,
|
||||
[dst_pool_direct],
|
||||
src_indices_host,
|
||||
dst_indices_host,
|
||||
page_size,
|
||||
)
|
||||
for i in range(num_layers):
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_pool,
|
||||
dst_pool_ref,
|
||||
src_indices_device,
|
||||
dst_indices_host,
|
||||
page_size,
|
||||
i,
|
||||
lf_to_pf=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
|
||||
|
||||
else:
|
||||
src_k_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
|
||||
device
|
||||
)
|
||||
src_k_pool_ptrs = [src_k_pool[i] for i in range(num_layers)]
|
||||
src_v_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
|
||||
device
|
||||
)
|
||||
src_v_pool_ptrs = [src_v_pool[i] for i in range(num_layers)]
|
||||
dst_k_pool_ref = torch.zeros(
|
||||
total_pages_in_pool, num_layers, page_size, item_size
|
||||
).pin_memory()
|
||||
dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
transfer_kv_all_layer_direct_lf_pf(
|
||||
src_k_pool_ptrs + src_v_pool_ptrs,
|
||||
[dst_k_pool_direct, dst_v_pool_direct],
|
||||
src_indices_host,
|
||||
dst_indices_host,
|
||||
page_size,
|
||||
)
|
||||
for i in range(num_layers):
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_k_pool,
|
||||
dst_k_pool_ref,
|
||||
src_indices_device,
|
||||
dst_indices_host,
|
||||
page_size,
|
||||
i,
|
||||
lf_to_pf=True,
|
||||
)
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_v_pool,
|
||||
dst_v_pool_ref,
|
||||
src_indices_device,
|
||||
dst_indices_host,
|
||||
page_size,
|
||||
i,
|
||||
lf_to_pf=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
|
||||
torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
|
||||
else:
|
||||
if is_mla:
|
||||
src_pool = torch.randn(
|
||||
total_pages_in_pool, num_layers, page_size, item_size
|
||||
).pin_memory()
|
||||
|
||||
dst_pool_ref = torch.zeros(num_layers, total_items_in_pool, item_size).to(
|
||||
device
|
||||
)
|
||||
dst_pool_direct = torch.zeros_like(dst_pool_ref)
|
||||
dst_pool_direct_ptrs = [dst_pool_direct[i] for i in range(num_layers)]
|
||||
torch.cuda.synchronize()
|
||||
|
||||
transfer_kv_per_layer_direct_pf_lf(
|
||||
[src_pool],
|
||||
[dst_pool_direct_ptrs[layer_idx_to_test]],
|
||||
src_indices_host,
|
||||
dst_indices_host,
|
||||
layer_idx_to_test,
|
||||
page_size,
|
||||
)
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_pool,
|
||||
dst_pool_ref,
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
page_size,
|
||||
layer_idx_to_test,
|
||||
lf_to_pf=False,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
|
||||
else:
|
||||
src_k_pool = torch.randn(
|
||||
total_pages_in_pool, num_layers, page_size, item_size
|
||||
).pin_memory()
|
||||
src_v_pool = torch.randn(
|
||||
total_pages_in_pool, num_layers, page_size, item_size
|
||||
).pin_memory()
|
||||
|
||||
dst_k_pool_ref = torch.zeros(num_layers, total_items_in_pool, item_size).to(
|
||||
device
|
||||
)
|
||||
dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_k_pool_direct_ptrs = [dst_k_pool_direct[i] for i in range(num_layers)]
|
||||
|
||||
dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
|
||||
dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
|
||||
dst_v_pool_direct_ptrs = [dst_v_pool_direct[i] for i in range(num_layers)]
|
||||
torch.cuda.synchronize()
|
||||
|
||||
transfer_kv_per_layer_direct_pf_lf(
|
||||
[src_k_pool, src_v_pool],
|
||||
[
|
||||
dst_k_pool_direct_ptrs[layer_idx_to_test],
|
||||
dst_v_pool_direct_ptrs[layer_idx_to_test],
|
||||
],
|
||||
src_indices_host,
|
||||
dst_indices_host,
|
||||
layer_idx_to_test,
|
||||
page_size,
|
||||
)
|
||||
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_k_pool,
|
||||
dst_k_pool_ref,
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
page_size,
|
||||
layer_idx_to_test,
|
||||
lf_to_pf=False,
|
||||
)
|
||||
ref_copy_with_indices_pf_direct(
|
||||
src_v_pool,
|
||||
dst_v_pool_ref,
|
||||
src_indices_host,
|
||||
dst_indices_device,
|
||||
page_size,
|
||||
layer_idx_to_test,
|
||||
lf_to_pf=False,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
|
||||
torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
|
||||
torch.set_default_dtype(original_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
84
sgl-kernel/tests/test_lightning_attention_decode.py
Normal file
84
sgl-kernel/tests/test_lightning_attention_decode.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import lightning_attention_decode
|
||||
|
||||
|
||||
def naive_lightning_attention_decode(q, k, v, past_kv, slope):
|
||||
"""Naive implementation of lightning attention decode"""
|
||||
original_dtype = q.dtype
|
||||
ratio = torch.exp(-slope) # [h, 1, 1]
|
||||
|
||||
kv = past_kv
|
||||
b, h, n, d = q.shape
|
||||
|
||||
output = []
|
||||
for i in range(n):
|
||||
kv = ratio * kv.to(torch.float32) + torch.einsum(
|
||||
"... n d, ... n e -> ... d e",
|
||||
k[:, :, i : i + 1],
|
||||
v[:, :, i : i + 1],
|
||||
)
|
||||
qkv = torch.einsum(
|
||||
"... n e, ... e d -> ... n d",
|
||||
q[:, :, i : i + 1].to(torch.float32),
|
||||
kv.to(torch.float32),
|
||||
)
|
||||
output.append(qkv)
|
||||
output = torch.cat(output, dim=-2)
|
||||
|
||||
return output.to(original_dtype), kv
|
||||
|
||||
|
||||
configs = [
|
||||
# (batch_size, num_heads, dim, embed_dim)
|
||||
(1, 8, 64, 64),
|
||||
(2, 8, 64, 64),
|
||||
(1, 32, 32, 64),
|
||||
(2, 32, 32, 64),
|
||||
(4, 32, 64, 64),
|
||||
(4, 32, 64, 64),
|
||||
(16, 64, 96, 96),
|
||||
(64, 64, 96, 96),
|
||||
]
|
||||
|
||||
dtypes = [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs)
|
||||
def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim):
|
||||
device = torch.device("cuda")
|
||||
|
||||
q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
|
||||
k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
|
||||
v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype)
|
||||
past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device)
|
||||
slope = torch.randn(num_heads, 1, 1, device=device)
|
||||
|
||||
ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope)
|
||||
|
||||
output = torch.empty_like(ref_output)
|
||||
new_kv = torch.empty_like(ref_new_kv)
|
||||
lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
|
||||
|
||||
rtol = 1e-2
|
||||
atol = 1e-2
|
||||
|
||||
torch.testing.assert_close(
|
||||
output,
|
||||
ref_output,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
new_kv,
|
||||
ref_new_kv,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
121
sgl-kernel/tests/test_marlin_gemm.py
Normal file
121
sgl-kernel/tests/test_marlin_gemm.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import gptq_marlin_gemm
|
||||
from sgl_kernel.scalar_type import scalar_types
|
||||
|
||||
from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
|
||||
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 1, 1),
|
||||
(1, 4, 8),
|
||||
(1, 7, 5),
|
||||
(13, 17, 67),
|
||||
(26, 37, 13),
|
||||
(67, 13, 11),
|
||||
(257, 13, 11),
|
||||
(658, 13, 11),
|
||||
]
|
||||
|
||||
|
||||
# uint4 for awq
|
||||
# uint4b8 for gptq
|
||||
@pytest.mark.parametrize("k_chunk", [128])
|
||||
@pytest.mark.parametrize("n_chunk", [64, 256])
|
||||
@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("act_order", [False, True])
|
||||
@pytest.mark.parametrize("is_k_full", [False, True])
|
||||
@pytest.mark.parametrize("use_atomic_add", [False, True])
|
||||
@pytest.mark.parametrize("use_fp32_reduce", [False, True])
|
||||
def test_gptq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
quant_type,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
act_order,
|
||||
is_k_full,
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size == size_k:
|
||||
return
|
||||
if has_zp:
|
||||
return
|
||||
|
||||
if size_k % group_size != 0:
|
||||
return
|
||||
|
||||
a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
|
||||
b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + runtime zero-point
|
||||
if group_size == 16:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
||||
b_weight, quant_type, group_size
|
||||
)
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
marlin_s2 = None
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
if group_size == 16:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, act_order
|
||||
)
|
||||
marlin_zp = None
|
||||
marlin_s2 = None
|
||||
|
||||
workspace = marlin_make_workspace(w_ref.device)
|
||||
|
||||
# marlin gemm
|
||||
output = gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
)
|
||||
# ref gemm
|
||||
output_ref = torch.matmul(a_input, w_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
torch.abs(output_ref)
|
||||
)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import subprocess
|
||||
|
||||
subprocess.call(["pytest", "--tb=short", str(__file__)])
|
||||
148
sgl-kernel/tests/test_marlin_repack.py
Normal file
148
sgl-kernel/tests/test_marlin_repack.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import awq_marlin_repack, gptq_marlin_repack
|
||||
from sgl_kernel.scalar_type import scalar_types
|
||||
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
gptq_quantize_weights,
|
||||
pack_cols,
|
||||
pack_rows,
|
||||
quantize_weights,
|
||||
sort_weights,
|
||||
)
|
||||
from sglang.test.test_marlin_utils import get_weight_perm, marlin_weights
|
||||
|
||||
GPTQ_MARLIN_TILE = 16
|
||||
MARLIN_K_CHUNKS = [128]
|
||||
MARLIN_N_CHUNKS = [64, 256]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 1, 1),
|
||||
(1, 4, 8),
|
||||
(1, 7, 5),
|
||||
(13, 17, 67),
|
||||
(26, 37, 13),
|
||||
(67, 13, 11),
|
||||
(257, 13, 11),
|
||||
(658, 13, 11),
|
||||
]
|
||||
|
||||
|
||||
def awq_pack(
|
||||
q_w: torch.Tensor,
|
||||
num_bits: int,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
):
|
||||
assert q_w.shape == (size_k, size_n)
|
||||
|
||||
# Interleave column dim (for the dequantize code) and pack it to int32
|
||||
if num_bits == 4:
|
||||
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
|
||||
elif num_bits == 8:
|
||||
interleave = np.array([0, 2, 1, 3])
|
||||
else:
|
||||
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
||||
|
||||
q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
|
||||
q_w = q_w.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return pack_cols(q_w, num_bits, size_k, size_n)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)])
|
||||
@pytest.mark.parametrize("group_size", [16, 32])
|
||||
def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size):
|
||||
tile_k, tile_n = 16, 64
|
||||
size_k = k_tiles * tile_k
|
||||
size_n = n_tiles * tile_n
|
||||
pack_factor = 32 // num_bits
|
||||
|
||||
b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
|
||||
|
||||
w_ref, q_w, s, zp = quantize_weights(
|
||||
b_weight, scalar_types.uint4, group_size, zero_points=True
|
||||
)
|
||||
|
||||
q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)
|
||||
|
||||
weight_perm = get_weight_perm(num_bits)
|
||||
q_w_marlin = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
|
||||
|
||||
out_gpu = awq_marlin_repack(q_w_awq, size_k, size_n, num_bits)
|
||||
assert out_gpu.is_cuda and out_gpu.dtype == torch.int32
|
||||
|
||||
expected_cols = size_n * tile_k // pack_factor
|
||||
assert list(out_gpu.shape) == [size_k // tile_k, expected_cols]
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(out_gpu, q_w_marlin)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", [scalar_types.uint4b8])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.parametrize("act_order", [False, True])
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_gptq_marlin_repack(
|
||||
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size == size_k:
|
||||
return
|
||||
|
||||
# Normalize group_size
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
if size_k % group_size != 0:
|
||||
pytest.skip("size_k must be divisible by group_size")
|
||||
|
||||
# Create input
|
||||
b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
|
||||
|
||||
# Quantize (and apply act_order if provided)
|
||||
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
|
||||
b_weight, quant_type, group_size, act_order
|
||||
)
|
||||
|
||||
q_w_gptq = pack_rows(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx" so that group ids are
|
||||
# increasing
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
|
||||
if act_order:
|
||||
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
|
||||
|
||||
marlin_layout_perm = get_weight_perm(quant_type.size_bits)
|
||||
q_w_marlin_ref = marlin_weights(
|
||||
q_w, size_k, size_n, quant_type.size_bits, marlin_layout_perm
|
||||
)
|
||||
|
||||
# Run Marlin repack GPU kernel
|
||||
q_w_marlin = gptq_marlin_repack(
|
||||
q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(q_w_marlin, q_w_marlin_ref)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import subprocess
|
||||
|
||||
subprocess.call(["pytest", "--tb=short", str(__file__)])
|
||||
142
sgl-kernel/tests/test_merge_state.py
Normal file
142
sgl-kernel/tests/test_merge_state.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/flashinfer/triton/kernels/cascade.py
|
||||
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import merge_state
|
||||
|
||||
|
||||
def check_input(x: torch.Tensor):
|
||||
assert x.is_cuda, f"{str(x)} must be a CUDA Tensor"
|
||||
assert x.is_contiguous(), f"{str(x)} must be contiguous"
|
||||
|
||||
|
||||
def check_dim(d, x: torch.Tensor):
|
||||
assert x.dim() == d, f"{str(x)} must be a {d}D tensor"
|
||||
|
||||
|
||||
def check_shape(a: torch.Tensor, b: torch.Tensor):
|
||||
assert a.dim() == b.dim(), "tensors should have same dim"
|
||||
for i in range(a.dim()):
|
||||
assert a.size(i) == b.size(
|
||||
i
|
||||
), f"tensors shape mismatch, {a.size()} and {b.size()}"
|
||||
|
||||
|
||||
def check_device(tensors: List[torch.Tensor]):
|
||||
device = tensors[0].device
|
||||
for t in tensors:
|
||||
assert (
|
||||
t.device == device
|
||||
), f"All tensors should be on the same device, but got {device} and {t.device}"
|
||||
|
||||
|
||||
@triton.jit
|
||||
def state_merge(o, m, d, other_o, other_m, other_d):
|
||||
m_max = tl.maximum(m, other_m)
|
||||
d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max)
|
||||
o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max)
|
||||
return o, m_max, d
|
||||
|
||||
|
||||
@triton.jit
|
||||
def state_normalize(o, m, d):
|
||||
o = o / d
|
||||
return o, m, d
|
||||
|
||||
|
||||
@triton.jit
|
||||
def state_get_lse(o, m, d):
|
||||
return m + tl.log2(d)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def merge_state_kernel(
|
||||
v_a_ptr,
|
||||
s_a_ptr,
|
||||
v_b_ptr,
|
||||
s_b_ptr,
|
||||
v_merged_ptr,
|
||||
s_merged_ptr,
|
||||
num_heads,
|
||||
head_dim,
|
||||
bdx: tl.constexpr,
|
||||
bdy: tl.constexpr,
|
||||
):
|
||||
pos = tl.program_id(axis=0)
|
||||
for tx in tl.range(bdx):
|
||||
for head_idx in tl.range(bdy):
|
||||
s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx)
|
||||
s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx)
|
||||
|
||||
offsets = (pos * num_heads + head_idx) * head_dim + tx
|
||||
v_a = tl.load(v_a_ptr + offsets)
|
||||
v_b = tl.load(v_b_ptr + offsets)
|
||||
|
||||
v_merged, s_max, d = state_merge(
|
||||
o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1
|
||||
)
|
||||
v_merged, s_max, d = state_normalize(v_merged, s_max, d)
|
||||
v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx
|
||||
tl.store(v_merged_ptr + v_merged_offset, v_merged)
|
||||
|
||||
if s_merged_ptr:
|
||||
tl.store(
|
||||
s_merged_ptr + pos * num_heads + head_idx,
|
||||
tl.log2(d) + s_max,
|
||||
)
|
||||
|
||||
|
||||
def merge_state_triton(
|
||||
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
|
||||
):
|
||||
check_input(v_a)
|
||||
check_input(s_a)
|
||||
check_input(v_b)
|
||||
check_input(s_b)
|
||||
check_device([v_a, s_a, v_b, s_b])
|
||||
check_dim(3, v_a)
|
||||
check_dim(2, s_a)
|
||||
check_dim(3, v_b)
|
||||
check_dim(2, s_b)
|
||||
check_shape(v_a, v_b)
|
||||
check_shape(s_a, s_b)
|
||||
assert v_a.size(0) == s_a.size(0)
|
||||
assert v_a.size(1) == s_b.size(1)
|
||||
s_a = s_a.to(torch.float32)
|
||||
s_b = s_b.to(torch.float32)
|
||||
seq_len = v_a.size(0)
|
||||
num_heads = v_a.size(1)
|
||||
head_dim = v_a.size(2)
|
||||
v_merged = torch.empty_like(v_a).to(s_a.device)
|
||||
s_merged = torch.empty((seq_len, num_heads)).to(s_a.device)
|
||||
bdx = head_dim
|
||||
bdy = num_heads
|
||||
|
||||
merge_state_kernel[lambda meta: (seq_len,)](
|
||||
v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
|
||||
)
|
||||
|
||||
return v_merged, s_merged
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_len", [2048])
@pytest.mark.parametrize("num_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_merge_state(seq_len, num_heads, head_dim):
    """Triton merge kernel must agree with the sgl_kernel reference op."""
    device = "cuda:0"
    v_lhs = torch.randn(seq_len, num_heads, head_dim).half().to(device)
    s_lhs = torch.randn(seq_len, num_heads, dtype=torch.float32).to(device)
    v_rhs = torch.randn(seq_len, num_heads, head_dim).half().to(device)
    s_rhs = torch.randn(seq_len, num_heads, dtype=torch.float32).to(device)

    v_out, s_out = merge_state_triton(v_lhs, s_lhs, v_rhs, s_rhs)
    v_ref, s_ref = merge_state(v_lhs, s_lhs, v_rhs, s_rhs)

    assert torch.allclose(v_out, v_ref, atol=1e-2)
    assert torch.allclose(s_out, s_ref, atol=1e-2)
|
||||
|
||||
|
||||
# Allow running this test file directly without invoking pytest explicitly.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
400
sgl-kernel/tests/test_merge_state_v2.py
Normal file
400
sgl-kernel/tests/test_merge_state_v2.py
Normal file
@@ -0,0 +1,400 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import merge_state, merge_state_v2
|
||||
|
||||
|
||||
# Merge two softmax attention states (output, lse) per (token, head):
# rescales both partial outputs by their relative softmax mass and sums them.
# Launched on a 2D grid: axis 0 = token index, axis 1 = head index.
@triton.jit
def merge_state_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,  # HEAD_SIZE rounded up to a power of two
    OUTPUT_LSE: tl.constexpr,  # whether to also write the merged lse
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
    # +inf marks an empty/invalid partial state; map it to -inf so its
    # exp() contribution becomes zero below.
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    # Subtract the pairwise max for numerically stable exponentials.
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    out_se = tl.exp(p_lse) + tl.exp(s_lse)

    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)

    # Mask handles HEAD_SIZE that is not a power of two.
    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(
        prefix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )
    s_out = tl.load(
        suffix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )

    # Convex combination weighted by each state's softmax mass.
    p_scale = tl.exp(p_lse) / out_se
    s_scale = tl.exp(s_lse) / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(
        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
        out,
        mask=head_mask,
    )
|
||||
|
||||
|
||||
def merge_state_triton(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    """Launch the Triton merge kernel over all (token, head) pairs.

    Args:
        prefix_output, suffix_output: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
        prefix_lse, suffix_lse: [NUM_TOKENS, NUM_HEADS] log-sum-exp values.
        output, output_lse: optional pre-allocated result buffers; allocated
            here when not provided.

    Returns:
        (output, output_lse) — the merged state tensors.
    """
    # Avoid creating new tensors if they are already provided.
    # BUG FIX: allocate the defaults *before* reading shapes; the original
    # read output.shape first and raised AttributeError when output=None.
    if output is None:
        output = torch.empty_like(prefix_output)
    if output_lse is None:
        output_lse = torch.empty_like(prefix_lse)

    num_tokens = output.shape[0]
    num_query_heads = output.shape[1]
    head_size = output.shape[2]
    # The kernel masks lanes beyond head_size, so pad to a power of two.
    padded_head_size = triton.next_power_of_2(head_size)

    merge_state_kernel[(num_tokens, num_query_heads)](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        output_lse is not None,
    )
    return output, output_lse
|
||||
|
||||
|
||||
# Naive PyTorch implementation of merging attention states (reference).
def merge_state_torch(
    prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS]
    suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS]
    output: Optional[torch.Tensor] = None,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse: Optional[torch.Tensor] = None,  # [NUM_TOKENS, NUM_HEADS]
):
    """Merge two attention states in pure PyTorch.

    Rescales each partial output by its relative softmax mass and sums
    them; also returns the combined log-sum-exp. Does NOT modify the
    input lse tensors (see fix below). Note: provided ``output`` /
    ``output_lse`` buffers are only used as placeholders — fresh result
    tensors are returned, matching the original behavior.
    """
    # Avoid creating new tensors if they are already provided
    if output is None:
        output = torch.empty_like(prefix_output)
    if output_lse is None:
        output_lse = torch.empty_like(prefix_lse)
    # BUG FIX: operate on copies. The original aliased prefix_lse/suffix_lse
    # and the masked assignments below mutated the caller's tensors in place
    # (callers had to clone defensively before every call).
    p_lse = prefix_lse.clone()
    s_lse = suffix_lse.clone()
    # inf marks an empty/invalid state -> -inf so exp() contributes zero
    p_lse[p_lse == torch.inf] = -torch.inf
    s_lse[s_lse == torch.inf] = -torch.inf
    # max_lse [NUM_TOKENS, NUM_HEADS] — stabilizes the exponentials
    max_lse = torch.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    p_lse_exp = torch.exp(p_lse)
    s_lse_exp = torch.exp(s_lse)
    out_se = p_lse_exp + s_lse_exp
    if output_lse is not None:
        output_lse = torch.log(out_se) + max_lse
    p_scale = p_lse_exp / out_se
    s_scale = s_lse_exp / out_se
    p_scale = p_scale.unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
    s_scale = s_scale.unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
    output = prefix_output * p_scale + suffix_output * s_scale
    return output, output_lse
|
||||
|
||||
|
||||
# Sweep dimensions shared by the parametrized benchmark/correctness test.
NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536]
NUM_QUERY_HEADS = [8, 16, 32]
HEAD_SIZES = [32, 48, 64, 128, 256]
DTYPES = [torch.half, torch.bfloat16]

# One timing record appended per finished test case; consumed by
# generate_markdown_table() once every parameter combination has run.
all_case_info: list[tuple] = []
|
||||
|
||||
|
||||
def generate_markdown_table():
    """Print the collected benchmark results as a Markdown table.

    Reads the module-level ``all_case_info`` records appended by
    ``test_merge_attn_states``. The perf helper reports 0 for kernels it
    skipped or that failed, so speedup ratios guard against dividing by zero.
    """
    table_header = (
        "| tokens | heads | headsize | dtype "
        "| device | torch | triton | v1 | v2 | speedup(vs triton) | speedup(vs v1)|"
    )
    table_separator = (
        "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |"
    )

    def shortly_dtype(dtype: torch.dtype) -> str:
        # e.g. torch.bfloat16 -> "bfloat16"
        return str(dtype).removeprefix("torch.")

    def shortly_device(device: str) -> str:
        # e.g. "NVIDIA H100" -> "H100"
        return device.removeprefix("NVIDIA").strip()

    print(table_header)
    print(table_separator)
    # NOTE: `global` is unnecessary for a read-only module-level name.
    for info in all_case_info:
        (
            num_tokens,
            num_heads,
            head_size,
            dtype,
            device,
            time_torch,
            time_triton,
            time_v1,
            time_v2,
        ) = info
        dtype = shortly_dtype(dtype)
        device = shortly_device(device)
        # BUG FIX: time_v2 may be 0 (skipped/failed kernel); report 0x
        # instead of raising ZeroDivisionError.
        improved_triton = time_triton / time_v2 if time_v2 else 0.0
        improved_v1 = time_v1 / time_v2 if time_v2 else 0.0
        print(
            f"| {num_tokens} | {num_heads} | {head_size} "
            f"| {dtype} | {device} | {time_torch:.4f}ms "
            f"| {time_triton:.4f}ms "
            f"| {time_v1:.4f}ms "
            f"| {time_v2:.4f}ms "
            f"| {improved_triton:.4f}x "
            f"| {improved_v1:.4f}x |"
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(
    num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
):
    """Benchmark and cross-check four merge-attention-state implementations.

    Runs the naive Torch reference, the Triton kernel, and the CUDA v1/v2
    kernels on the same randomized inputs (with some lse entries forced to
    +inf), compares v2 against the Triton reference, and records timings
    into the module-level ``all_case_info`` for the final Markdown table.
    """
    if not torch.cuda.is_available():
        pytest.skip(
            "Currently only support compare triton merge_attn_states "
            "with custom cuda merge_attn_states kernel"
        )

    NUM_TOKENS = num_tokens
    NUM_HEADS = num_query_heads
    HEAD_SIZE = head_size

    print(
        f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
        f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
        f"Device: {torch.cuda.get_device_name()}"
    )

    # prefix_lse and suffix_lse contain inf and normal values
    prefix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
    suffix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")

    # Generate boolean masks
    mask_prefix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1
    mask_suffix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1
    # Ensure that the same position is not True at the same time
    combined_mask = torch.logical_and(mask_prefix, mask_suffix)
    mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
    mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)

    # Inject +inf entries (empty-state markers) to exercise the inf -> -inf path.
    prefix_lse[mask_prefix] = float("inf")
    suffix_lse[mask_suffix] = float("inf")

    # Other input tensors (need to be initialized but
    # no actual calculation needed)
    output = torch.zeros(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )
    output_lse = torch.zeros(
        (NUM_TOKENS, NUM_HEADS), dtype=torch.float32, device="cuda"
    )
    prefix_output = torch.randn(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )
    suffix_output = torch.randn(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )

    warmup_times = 2
    repeat_times = 20

    def perf_kernel_fn(
        output_fn: torch.Tensor,
        output_lse_fn: torch.Tensor,
        kernel_fn: callable,
        fn_type: str = "torch",
    ):
        # Time one merge implementation with CUDA events; returns
        # (avg_ms, output, output_lse). Returns 0 ms for skipped or
        # failing kernels — consumers must treat 0 as "not measured".
        # Avoid inplace inf -> -inf, we have to use prefix_lse
        # and suffix_lse for other kernel.
        if fn_type == "torch":
            prefix_lse_ = prefix_lse.clone()
            suffix_lse_ = suffix_lse.clone()
        else:
            prefix_lse_ = prefix_lse
            suffix_lse_ = suffix_lse

        if fn_type == "cuda_v1":
            # merge_state v1 kernel not support float32
            if output_dtype not in (torch.half, torch.bfloat16):
                return 0, output_fn, output_lse_fn

        total_time = 0
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        try:
            for _ in range(warmup_times):
                output_fn, output_lse_fn = kernel_fn(
                    prefix_output,
                    prefix_lse_,
                    suffix_output,
                    suffix_lse_,
                    output_fn,
                    output_lse_fn,
                )
            torch.cuda.synchronize()

            for _ in range(repeat_times):
                start.record()
                output_fn, output_lse_fn = kernel_fn(
                    prefix_output,
                    prefix_lse_,
                    suffix_output,
                    suffix_lse_,
                    output_fn,
                    output_lse_fn,
                )
                end.record()
                torch.cuda.synchronize()
                total_time += start.elapsed_time(end)

            avg_time = total_time / repeat_times
            return avg_time, output_fn, output_lse_fn
        except Exception as e:
            # NOTE(review): failures are swallowed and reported as 0 ms —
            # consider at least logging `e` so broken kernels are visible.
            return 0, output_fn, output_lse_fn

    # 0. Run the Torch kernel
    output_torch = output.clone()
    output_lse_torch = output_lse.clone()
    time_torch, output_torch, output_lse_torch = perf_kernel_fn(
        output_torch, output_lse_torch, merge_state_torch, fn_type="torch"
    )

    # 1. Run the Triton kernel
    output_ref_triton = output.clone()
    output_lse_ref_triton = output_lse.clone()
    time_triton, output_ref_triton, output_lse_ref_triton = perf_kernel_fn(
        output_ref_triton,
        output_lse_ref_triton,
        merge_state_triton,
        fn_type="triton",
    )

    # 2. Run the merge_state V1 kernel
    output_v1 = output.clone()
    output_lse_v1 = output_lse.clone()
    time_v1, output_v1, output_lse_v1 = perf_kernel_fn(
        output_v1, output_lse_v1, merge_state, fn_type="cuda_v1"
    )

    # 3. Run the merge_state V2 kernel
    output_v2 = output.clone()
    output_lse_v2 = output_lse.clone()
    time_v2, output_v2, output_lse_v2 = perf_kernel_fn(
        output_v2, output_lse_v2, merge_state_v2, fn_type="cuda_v2"
    )

    # 4. Performance compare
    improved = time_triton / time_v2
    print(f" Torch time: {time_torch:.6f}ms")
    print(f" Triton time: {time_triton:.6f}ms")
    print(f"CUDA v1 time: {time_v1:.6f}ms")
    print(f"CUDA v2 time: {time_v2:.6f}ms, Performance: {improved:.5f}x")
    print("-" * 100)

    # 5. Correctness compare
    # Liger Kernel: Efficient Triton Kernels for LLM Training
    # https://arxiv.org/pdf/2410.10989, 3.3 Correctness
    # use rtol = 1e-2 for bfloat16.
    rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3

    def diff(a: torch.Tensor, b: torch.Tensor):
        # Max absolute element-wise difference, computed in float32.
        max_diff = torch.max(torch.abs(a.float() - b.float()))
        return max_diff

    # Use Triton output as reference because we want to replace
    # the Triton kernel with custom CUDA kernel for merge attn
    # states operation.
    output_ref = output_ref_triton
    output_lse_ref = output_lse_ref_triton
    torch.testing.assert_close(
        output_v2.float(), output_ref.float(), atol=1e-3, rtol=rtol
    )
    print("Output all match, max abs diff:")
    print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
    print(f"(CUDA v2 vs Torch) : {diff(output_torch, output_v2)}")
    print(f"(CUDA v2 vs Triton): {diff(output_ref, output_v2)}")
    print("-" * 100)

    torch.testing.assert_close(
        output_lse_v2.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
    )
    print("Output LSE all match, max abs diff:")
    print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
    print(f"(CUDA v2 vs Torch) : {diff(output_lse_torch, output_lse_v2)}")
    print(f"(CUDA v2 vs Triton): {diff(output_lse_ref, output_lse_v2)}")
    print("-" * 100)

    print(
        "All output values test passed! All inf values "
        "are correctly replaced with -inf."
    )
    print("-" * 100)

    # Record this case; once the whole sweep has run, emit the summary table.
    device = torch.cuda.get_device_name()
    all_case_info.append(
        (
            NUM_TOKENS,
            NUM_HEADS,
            HEAD_SIZE,
            output_dtype,
            device,
            time_torch,
            time_triton,
            time_v1,
            time_v2,
        )
    )
    if len(all_case_info) == (
        len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
    ):
        generate_markdown_table()
|
||||
|
||||
|
||||
# Allow running this test file directly without invoking pytest explicitly.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
250
sgl-kernel/tests/test_moe_align.py
Normal file
250
sgl-kernel/tests/test_moe_align.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import moe_align_block_size
|
||||
|
||||
|
||||
def ceil_div(a, b):
    """Integer ceiling division of ``a`` by ``b``."""
    quotient = (a + b - 1) // b
    return quotient
|
||||
|
||||
|
||||
# Stage 1: each program scans its contiguous slice of topk_ids and
# accumulates per-(program, expert) token counts into row (pid + 1)
# of tokens_cnts (row 0 stays zero as the base row for the prefix sum).
@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,  # [num_experts + 1, num_experts] count matrix
    num_experts: tl.constexpr,
    numel: tl.constexpr,  # total number of topk ids
    tokens_per_thread: tl.constexpr,  # slice size per program
):
    pid = tl.program_id(0)
    start_idx = pid * tokens_per_thread
    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
|
||||
|
||||
|
||||
# Stage 2: one program per expert computes a running prefix sum down the
# rows of tokens_cnts, so row i ends up holding the cumulative count of
# tokens routed to expert `pid` by programs 0..i-1 of stage 1.
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)
    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
|
||||
|
||||
|
||||
# Stage 3 (single program): turn the final per-expert totals into
# block-size-padded cumulative offsets, and publish the total padded
# token count.
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,  # [num_experts + 1] padded cumulative offsets (index 0 = 0)
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    # The last row of tokens_cnts holds the final per-expert totals
    # after the stage-2 prefix sum.
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        # Round each expert's token count up to a multiple of block_size.
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)
|
||||
|
||||
|
||||
# Stage 4: fill expert_ids for every padded block, then scatter each
# token index into its expert-sorted, padded position in sorted_token_ids.
@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    # Every block in this expert's padded range belongs to expert `pid`.
    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    # Re-walk this program's slice of topk_ids (same slicing as stage 1).
    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        # Row `pid` of tokens_cnts (prefix-summed in stage 2) gives this
        # token's rank among earlier programs' tokens for the same expert.
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
|
||||
|
||||
|
||||
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Triton reference implementation of ``moe_align_block_size``.

    Groups token indices by expert and pads each expert's segment to a
    multiple of ``block_size``, writing into the provided output tensors
    in place (``sorted_token_ids``, ``expert_ids``,
    ``num_tokens_post_pad``). The four stages must run in order.
    """
    numel = topk_ids.numel()
    grid = (num_experts,)
    # Row 0 is the zero base row; rows 1..num_experts collect per-program
    # counts (stage 1) and become per-expert prefix sums (stage 2).
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    tokens_per_thread = ceil_div(numel, num_experts)

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
    )
    moe_align_block_size_stage2[grid](
        tokens_cnts,
        num_experts,
    )
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "block_size,num_tokens,topk,num_experts,pad_sorted_token_ids",
    list(
        itertools.product(
            [32, 64, 128, 256],  # block_size
            [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],  # num_tokens
            [1, 2, 4, 8, 16, 32, 64],  # topk
            [64, 160, 256, 257, 260, 264],  # num_experts
            [True, False],  # pad_sorted_token_ids
        )
    ),
)
def test_moe_align_block_size_compare_implementations(
    block_size, num_tokens, topk, num_experts, pad_sorted_token_ids
):
    """CUDA moe_align_block_size must agree with the Triton reference.

    Compares expert_ids, the padded token total, and (for one selected
    expert's range) the set of sorted token ids.
    """

    # Random distinct expert assignments per token via argsort of noise.
    topk_ids = torch.argsort(torch.rand(num_tokens, num_experts, device="cuda"), dim=1)[
        :, :topk
    ]

    # Worst-case padded length: every expert segment padded by block_size - 1.
    # NOTE(review): num_experts + 1 is used consistently below — presumably
    # reserving one extra (shared/invalid) expert slot; confirm with kernel docs.
    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)

    sorted_ids_cuda = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    # When the kernel does not pad, pre-fill with the out-of-range sentinel
    # (numel) so unpadded tails still compare equal to the Triton output.
    if not pad_sorted_token_ids:
        sorted_ids_cuda.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids_cuda = torch.zeros(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad_cuda = torch.empty(
        (1), dtype=torch.int32, device=topk_ids.device
    )
    cumsum_buffer = torch.empty(
        num_experts + 2, dtype=torch.int32, device=topk_ids.device
    )

    sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
    sorted_ids_triton.fill_(topk_ids.numel())
    expert_ids_triton = torch.zeros_like(expert_ids_cuda)
    num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)

    moe_align_block_size(
        topk_ids,
        num_experts + 1,
        block_size,
        sorted_ids_cuda,
        expert_ids_cuda,
        num_tokens_post_pad_cuda,
        cumsum_buffer,
        pad_sorted_token_ids,
    )

    moe_align_block_size_triton(
        topk_ids,
        num_experts + 1,
        block_size,
        sorted_ids_triton,
        expert_ids_triton,
        num_tokens_post_pad_triton,
    )

    assert torch.allclose(expert_ids_cuda, expert_ids_triton, atol=0, rtol=0), (
        f"Expert IDs mismatch for block_size={block_size}, "
        f"num_tokens={num_tokens}, topk={topk}\n"
        f"CUDA expert_ids: {expert_ids_cuda}\n"
        f"Triton expert_ids: {expert_ids_triton}"
    )

    assert torch.allclose(
        num_tokens_post_pad_cuda, num_tokens_post_pad_triton, atol=0, rtol=0
    ), (
        f"Num tokens post pad mismatch for block_size={block_size}, "
        f"num_tokens={num_tokens}, topk={topk}\n"
        f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
        f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}"
    )

    # Select an expert to check
    expert_idx = expert_ids_cuda.max().item()

    # Get the first and last block id where expert_ids_cuda == expert_idx
    matching_indices = torch.where(expert_ids_cuda == expert_idx)[0]
    block_sorted_start = matching_indices[0].item() * block_size
    block_sorted_end = min(
        (matching_indices[-1].item() + 1) * block_size, num_tokens_post_pad_cuda.item()
    )

    # Within a block the ordering is implementation-defined, so compare
    # the sorted sets of token ids rather than raw positions.
    selected_sorted_ids_cuda = sorted_ids_cuda[
        block_sorted_start:block_sorted_end
    ].sort()[0]
    selected_sorted_ids_triton = sorted_ids_triton[
        block_sorted_start:block_sorted_end
    ].sort()[0]

    assert torch.allclose(
        selected_sorted_ids_cuda,
        selected_sorted_ids_triton,
        atol=0,
        rtol=0,
    ), (
        f"Sorted IDs mismatch for block_size={block_size}, "
        f"num_tokens={num_tokens}, topk={topk}\n"
        f"CUDA sorted_ids: {selected_sorted_ids_cuda}\n"
        f"Triton sorted_ids: {selected_sorted_ids_triton}"
    )
|
||||
|
||||
|
||||
# Allow running this test file directly without invoking pytest explicitly.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
103
sgl-kernel/tests/test_moe_fused_gate.py
Normal file
103
sgl-kernel/tests/test_moe_fused_gate.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import moe_fused_gate
|
||||
|
||||
from sglang.srt.layers.moe.topk import biased_grouped_topk
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "seq_length",
    list(range(1, 10))
    + [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
)
@pytest.mark.parametrize(
    "params",
    [
        # (num_experts, num_expert_group, topk_group, topk)
        (128, 4, 2, 4),
        (256, 8, 4, 8),  # deepseek v3
        (512, 16, 8, 16),
    ],
)
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
def test_moe_fused_gate_combined(
    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
):
    """moe_fused_gate must match the biased_grouped_topk reference.

    Compares selected expert indices and routing weights (order-insensitive,
    via sorting) across sequence lengths, gating configurations, fused
    shared experts, and output scaling.
    """
    num_experts, num_expert_group, topk_group, topk = params
    dtype = torch.float32

    # Seed per case so each parametrization is reproducible.
    torch.manual_seed(seq_length)
    tensor = torch.rand((seq_length, num_experts), dtype=dtype, device="cuda")
    scores = tensor.clone()
    bias = torch.rand(num_experts, dtype=dtype, device="cuda")
    # Fused shared experts occupy extra topk slots past the routed ones.
    topk = topk + num_fused_shared_experts

    output, indices = moe_fused_gate(
        tensor,
        bias,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        topk=topk,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=2.5,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
    ref_output, ref_indices = biased_grouped_topk(
        scores,
        scores,
        bias,
        topk=topk,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=2.5,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )

    # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
    if num_fused_shared_experts > 0:
        original_indices = indices.clone()
        original_ref_indices = ref_indices.clone()

        # Drop the shared-expert slot; only range-check it below.
        indices = indices[:, :-1]
        ref_indices = ref_indices[:, :-1]

        # Shared-expert ids live in [num_experts, num_experts + n_shared).
        valid_min = num_experts
        valid_max = num_experts + num_fused_shared_experts
        shared_indices = original_indices[:, -1]
        shared_ref_indices = original_ref_indices[:, -1]
        # NOTE(review): these tensors are never None (tensor indexing),
        # so the guards below are effectively always taken.
        if shared_indices is not None:
            assert torch.all(
                (shared_indices >= valid_min) & (shared_indices < valid_max)
            ), f"Shared expert indices out of range: found values outside [{valid_min}, {valid_max})"
        if shared_ref_indices is not None:
            assert torch.all(
                (shared_ref_indices >= valid_min) & (shared_ref_indices < valid_max)
            ), f"Shared expert reference indices out of range: found values outside [{valid_min}, {valid_max})"

    # Order-insensitive comparison: both sides sorted along topk dim.
    idx_check = torch.allclose(
        ref_indices.sort()[0].to(torch.int32),
        indices.sort()[0].to(torch.int32),
        rtol=1e-04,
        atol=1e-05,
    )
    output_check = torch.allclose(
        ref_output.sort()[0].to(torch.float32),
        output.sort()[0].to(torch.float32),
        rtol=1e-02,
        atol=1e-03,
    )

    assert idx_check, (
        f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
        f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
    )
    assert output_check, (
        f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
        f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
    )
|
||||
|
||||
|
||||
# Allow running this test file directly without invoking pytest explicitly.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
139
sgl-kernel/tests/test_moe_topk_softmax.py
Normal file
139
sgl-kernel/tests/test_moe_topk_softmax.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import topk_softmax
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_tokens, num_experts, topk",
    list(
        itertools.product(
            [1, 16, 128, 512, 1024, 2048],  # num_tokens
            [4, 8, 16, 32, 64, 128, 256],  # num_experts
            [1, 2, 4],  # topk
        )
    ),
)
def test_topk_softmax(num_tokens, num_experts, topk):
    """topk_softmax must match torch.softmax + torch.topk."""
    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=torch.float32, device="cuda"
    )

    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")

    topk_softmax(
        topk_weights,
        topk_indices,
        gating_output,
    )

    # Native torch implementation
    softmax_output = torch.softmax(gating_output, dim=-1)
    topk_weights_ref, topk_indices_ref = torch.topk(softmax_output, topk, dim=-1)

    # Verify the top-k weights and indices match the torch native ones.
    # BUG FIX: the weights message previously printed topk_indices_ref.
    assert torch.allclose(
        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
    ), f"Weights mismatch: torch={topk_weights_ref} vs SGLang={topk_weights}"

    assert torch.allclose(
        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
    ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_tokens, num_experts, topk, dtype",
    list(
        itertools.product(
            [1, 16, 128, 512, 1024, 2048],  # num_tokens
            [4, 8, 16, 32, 64, 128, 256],  # num_experts
            [1, 2, 4],  # topk
            [torch.float16, torch.bfloat16, torch.float32],  # dtype
        )
    ),
)
def test_topk_softmax_dtype_regression(num_tokens, num_experts, topk, dtype):
    """topk_softmax on half/bfloat16 gating must match the float32 path."""
    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")

    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")

    topk_softmax(
        topk_weights,
        topk_indices,
        gating_output,
    )

    # Reference: same op fed float32 gating.
    topk_weights_ref = torch.empty(
        (num_tokens, topk), dtype=torch.float32, device="cuda"
    )
    topk_indices_ref = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")

    topk_softmax(
        topk_weights_ref,
        topk_indices_ref,
        gating_output.float(),
    )

    # BUG FIX: the weights message previously printed topk_indices_ref.
    assert torch.allclose(
        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
    ), f"Weights mismatch: SGLang old interface={topk_weights_ref} vs SGLang new interface={topk_weights}"

    assert torch.allclose(
        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
    ), f"Indices mismatch: SGLang old interface={topk_indices_ref}, SGLang new interface={topk_indices}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_tokens, num_experts, topk",
    list(
        itertools.product(
            [1, 16, 128, 512, 1024, 2048],  # num_tokens
            [4, 8, 16, 32, 64, 128, 256],  # num_experts
            [1, 2, 4],  # topk
        )
    ),
)
def test_topk_softmax_renormalize(num_tokens, num_experts, topk):
    """Fused renormalize=True must equal manual post-normalization."""
    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=torch.bfloat16, device="cuda"
    )

    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")

    topk_softmax(
        topk_weights,
        topk_indices,
        gating_output,
        renormalize=True,
    )

    # Reference: plain topk_softmax followed by manual renormalization.
    # (Removed an unused token_expert_indices_ref allocation.)
    topk_weights_ref = torch.empty(
        (num_tokens, topk), dtype=torch.float32, device="cuda"
    )
    topk_indices_ref = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")

    topk_softmax(
        topk_weights_ref,
        topk_indices_ref,
        gating_output,
    )
    topk_weights_ref = topk_weights_ref / topk_weights_ref.sum(dim=-1, keepdim=True)

    # BUG FIX: the weights message previously printed topk_indices_ref.
    assert torch.allclose(
        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
    ), f"Weights mismatch: SGLang w/o fused renormalize={topk_weights_ref} vs SGLang w/ fused renormalize={topk_weights}"

    assert torch.allclose(
        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
    ), f"Indices mismatch: SGLang w/o fused renormalize={topk_indices_ref}, SGLang w/ fused renormalize={topk_indices}"
|
||||
|
||||
|
||||
# Allow running this test file directly without invoking pytest explicitly.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
146
sgl-kernel/tests/test_mscclpp.py
Normal file
146
sgl-kernel/tests/test_mscclpp.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import socket
|
||||
import unittest
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
import sgl_kernel.allreduce as custom_ops
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class MscclContextSelection(IntEnum):
    """Selector for the MSCCL++ all-reduce context variant."""

    # One-shot LL all-reduce within a single node.
    MSCCL1SHOT1NODELL = 1
    # One-shot LL all-reduce spanning two nodes.
    MSCCL1SHOT2NODELL = 2
|
||||
|
||||
|
||||
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    """Worker process: set up NCCL + MSCCL++ and compare mscclpp_allreduce
    against torch.distributed.all_reduce for several sizes and dtypes.

    Args:
        world_size: total number of ranks.
        rank: this worker's rank.
        distributed_init_port: TCP port used for process-group rendezvous.
        test_sizes: iterable of element counts to test.
    """
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD
    # Gloo group so the MSCCL++ unique id can be broadcast on CPU.
    cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
    if rank == 0:
        unique_id = [custom_ops.mscclpp_generate_unique_id()]
    else:
        unique_id = [None]
    dist.broadcast_object_list(
        unique_id, src=0, device=torch.device("cpu"), group=cpu_group
    )
    unique_id = unique_id[0]
    # Global maps: rank -> node index and rank -> IB device index (8 GPUs/node).
    rank_to_node, rank_to_ib = list(range(world_size)), list(range(world_size))
    for r in range(world_size):
        rank_to_node[r] = r // 8
        # BUGFIX: was `rank % 8`, which filled every entry with *this*
        # process's IB index; the map must be per-rank (`r % 8`).
        rank_to_ib[r] = r % 8
    MAX_BYTES = 2**20
    scratch = torch.empty(
        MAX_BYTES * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    put_buffer = torch.empty(
        MAX_BYTES, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    print(f"[{rank}] start mscclpp_context init")
    nranks_per_node = torch.cuda.device_count()
    selection = int(MscclContextSelection.MSCCL1SHOT1NODELL)
    mscclpp_context = custom_ops.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        selection,
    )
    try:
        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                # Skip sizes that would not fit the pre-allocated buffers.
                if sz * dtype.itemsize > MAX_BYTES:
                    continue
                if rank == 0:
                    print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
                for _ in range(test_loop):
                    # Small integers avoid float rounding differences
                    # between the two reduction implementations.
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)
                    custom_ops.mscclpp_allreduce(
                        mscclpp_context, inp1, out1, nthreads=512, nblocks=21
                    )
                    dist.all_reduce(inp1_ref, group=group)
                    torch.testing.assert_close(out1, inp1_ref)
    finally:
        # Always tear down the process group, even on assertion failure.
        dist.barrier(group=group)
        dist.destroy_process_group(group=group)
|
||||
|
||||
|
||||
def get_open_port() -> int:
    """Return a free TCP port on localhost.

    Binds to port 0 so the OS chooses an unused port; tries IPv4 first and
    falls back to IPv6 when the IPv4 bind fails.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", 0))
            return sock.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
            sock.bind(("::1", 0))
            return sock.getsockname()[1]
|
||||
|
||||
|
||||
def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    """Spawn one worker process per rank and wait for all to succeed.

    Each worker is invoked as
    ``test_target(world_size, rank, distributed_init_port, *target_args)``.

    Raises:
        AssertionError: if any worker exits with a non-zero code.
    """
    mp.set_start_method("spawn", force=True)

    distributed_init_port = get_open_port()
    workers = []
    for rank in range(world_size):
        worker = mp.Process(
            target=test_target,
            args=(world_size, rank, distributed_init_port) + target_args,
            name=f"Worker-{rank}",
        )
        worker.start()
        workers.append(worker)

    for rank, worker in enumerate(workers):
        worker.join()
        assert (
            worker.exitcode == 0
        ), f"Process {rank} failed with exit code {worker.exitcode}"
|
||||
|
||||
|
||||
class TestMSCCLAllReduce(unittest.TestCase):
    """End-to-end correctness test for the MSCCL++ all-reduce kernel."""

    # Element counts to exercise, from small to large.
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
    ]
    # Tensor-parallel sizes to test (single node only).
    world_sizes = [8]

    def test_correctness(self):
        """Run the multi-process worker for each supported world size."""
        for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            # Skip configurations that need more GPUs than this host has.
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}, found {available_gpus} and now ray is not supported here"
                )
                continue

            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"custom allreduce tp = {world_size}: OK")
|
||||
|
||||
|
||||
# Allow running this test file directly with the unittest runner.
if __name__ == "__main__":
    unittest.main()
|
||||
133
sgl-kernel/tests/test_norm.py
Normal file
133
sgl-kernel/tests/test_norm.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py
|
||||
|
||||
import pytest
|
||||
import sgl_kernel
|
||||
import torch
|
||||
|
||||
|
||||
def llama_rms_norm(x, w, eps=1e-6):
    """Reference RMSNorm (LLaMA style): normalize by RMS, scale by `w`.

    Computation happens in fp32 and the result is cast back to x's dtype.
    """
    dtype_in = x.dtype
    xf = x.float()
    rms_inv = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = xf * rms_inv * w.float()
    return out.to(dtype_in)
|
||||
|
||||
|
||||
def gemma_rms_norm(x, w, eps=1e-6):
    """Reference RMSNorm (Gemma style): scale by (1 + w) instead of w."""
    dtype_in = x.dtype
    xf = x.float()
    rms_inv = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = xf * rms_inv * (1.0 + w.float())
    return out.to(dtype_in)
|
||||
|
||||
|
||||
def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
    """Reference fused residual-add + Gemma RMSNorm.

    Returns:
        (normalized output cast to x's dtype, updated residual = x + residual
        in the original dtype).
    """
    dtype_in = x.dtype
    summed = x + residual
    xf = summed.float()
    rms_inv = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (xf * rms_inv * (1.0 + w.float())).to(dtype_in)
    return out, summed
|
||||
|
||||
|
||||
def fused_add_rms_norm(x, residual, weight, eps):
    """Reference fused residual-add + LLaMA RMSNorm with fp32 accumulation.

    Returns:
        (normalized output, updated residual), both in x's original dtype.
    """
    dtype_in = x.dtype
    acc = x.to(torch.float32) + residual.to(torch.float32)
    new_residual = acc.to(dtype_in)

    rms_inv = torch.rsqrt(acc.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (acc * rms_inv * weight.float()).to(dtype_in)
    return out, new_residual
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
    """Compare sgl_kernel.rmsnorm against the LLaMA reference."""
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)

    expected = llama_rms_norm(x, w)
    if specify_out:
        out = torch.empty_like(x)
        sgl_kernel.rmsnorm(x, w, out=out)
    else:
        out = sgl_kernel.rmsnorm(x, w)

    torch.testing.assert_close(expected, out, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    """Compare sgl_kernel.fused_add_rmsnorm (in-place) against the reference."""
    eps = 1e-6

    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    expected_x, expected_residual = fused_add_rms_norm(
        x.clone(), residual.clone(), weight, eps
    )

    actual_x = x.clone()
    actual_residual = residual.clone()
    sgl_kernel.fused_add_rmsnorm(actual_x, actual_residual, weight, eps)

    torch.testing.assert_close(actual_x, expected_x, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(actual_residual, expected_residual, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
    """Compare sgl_kernel.gemma_rmsnorm against the Gemma reference."""
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)

    expected = gemma_rms_norm(x, w)
    if specify_out:
        out = torch.empty_like(x)
        sgl_kernel.gemma_rmsnorm(x, w, out=out)
    else:
        out = sgl_kernel.gemma_rmsnorm(x, w)

    torch.testing.assert_close(expected, out, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    """Compare sgl_kernel.gemma_fused_add_rmsnorm (in-place) against the reference."""
    eps = 1e-6

    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    expected_x, expected_residual = gemma_fused_add_rms_norm(
        x.clone(), residual.clone(), weight, eps
    )

    actual_x = x.clone()
    actual_residual = residual.clone()
    sgl_kernel.gemma_fused_add_rmsnorm(actual_x, actual_residual, weight, eps)

    torch.testing.assert_close(actual_x, expected_x, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(actual_residual, expected_residual, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
# Allow running this test file directly: `python <file>` invokes pytest on it.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
67
sgl-kernel/tests/test_per_tensor_quant_fp8.py
Normal file
67
sgl-kernel/tests/test_per_tensor_quant_fp8.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# Platform-appropriate FP8 dtype: ROCm exposes e4m3fnuz, CUDA e4m3fn.
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
|
||||
def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-tensor FP8 quantization via the sgl_kernel kernel.

    If `scale` is provided, static quantization with that scale is used;
    otherwise the kernel computes the scale dynamically and writes it into
    the returned scale tensor.

    Returns:
        (quantized tensor, scale) — scale is a 1-element float32 tensor.
    """
    # BUGFIX: a local `fp8_type_` previously hard-coded torch.float8_e4m3fn,
    # shadowing the module-level HIP-aware dtype (float8_e4m3fnuz on ROCm).
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = True
    if scale is None:
        # Dynamic quantization: the kernel fills in the scale.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        is_static = False
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)

    return output, scale
|
||||
|
||||
|
||||
def torch_scaled_fp8_quant(tensor, inv_scale):
|
||||
# The reference implementation that fully aligns to
|
||||
# the kernel being tested.
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
scale = inv_scale.reciprocal()
|
||||
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
qweight = qweight.to(torch.float8_e4m3fn)
|
||||
return qweight
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_tokens,hidden_dim",
    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
)
def test_per_tensor_quant_compare_implementations(
    num_tokens: int,
    hidden_dim: int,
):
    """Kernel vs torch reference, for both dynamic and static scales."""
    device = torch.device("cuda")
    x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)

    # Dynamic path: kernel computes the scale; reuse it for the reference.
    kernel_out, kernel_scale = sglang_scaled_fp8_quant(x)
    ref_out = torch_scaled_fp8_quant(x, kernel_scale)
    torch.testing.assert_close(
        kernel_out.float(), ref_out.float(), rtol=1e-3, atol=1e-3
    )

    # Static path: both implementations get the same random scale.
    scale = torch.rand(1, dtype=torch.float32, device=device)
    kernel_out, kernel_scale = sglang_scaled_fp8_quant(x, scale)
    ref_out = torch_scaled_fp8_quant(x, scale)
    torch.testing.assert_close(
        kernel_out.float(), ref_out.float(), rtol=1e-3, atol=1e-3
    )
|
||||
|
||||
|
||||
# Allow running this test file directly: `python <file>` invokes pytest on it.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
97
sgl-kernel/tests/test_per_token_group_quant_8bit.py
Normal file
97
sgl-kernel/tests/test_per_token_group_quant_8bit.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
||||
from sglang.srt.layers.quantization.utils import assert_fp8_all_close
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# Platform-appropriate FP8 dtype: ROCm exposes e4m3fnuz, CUDA e4m3fn.
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_tokens, hidden_dim, group_size, dst_dtype, flags",
    list(
        itertools.product(
            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
            [256, 512, 1024, 2048, 4096],  # hidden_dim
            [8, 16, 32, 64, 128],  # group_size
            # TODO test int8
            [fp8_type_],  # dtype
            # Scale-layout flag combinations, from plain row-major to
            # TMA-aligned column-major with ue8m0 scale encoding.
            [
                dict(
                    column_major_scales=False,
                    scale_tma_aligned=False,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=False,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=True,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=True,
                    scale_ue8m0=True,
                ),
            ],
        )
    ),
)
def test_per_token_group_quant_with_column_major(
    num_tokens,
    hidden_dim,
    group_size,
    dst_dtype,
    flags,
):
    """Compare the Triton and sgl-kernel per-token-group 8-bit quantization
    paths for identical quantized values and scales."""
    # ue8m0 scales are only exercised for group_size 128 with hidden_dim
    # a multiple of 512.
    if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
        pytest.skip()
        return
    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
        pytest.skip("scale_ue8m0 only supported on Blackwell")
        return

    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)

    # Identical kwargs for both implementations.
    execute_kwargs = dict(
        x=x,
        group_size=group_size,
        eps=1e-10,
        dst_dtype=dst_dtype,
        **flags,
    )

    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)

    # Debug helpers, kept for local troubleshooting:
    # torch.set_printoptions(profile="full")
    # print(f"{x_q_triton=}")
    # print(f"{x_s_triton=}")
    # print(f"{x_q_sglang=}")
    # print(f"{x_s_sglang=}")
    # torch.set_printoptions(profile="default")

    assert_fp8_all_close(x_q_triton, x_q_sglang)
    torch.testing.assert_close(
        x_s_triton.contiguous(),
        x_s_sglang.contiguous(),
        rtol=1e-3,
        atol=1e-5,
        msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
    )
|
||||
|
||||
|
||||
# Allow running this test file directly: `python <file>` invokes pytest on it.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
57
sgl-kernel/tests/test_per_token_quant_fp8.py
Normal file
57
sgl-kernel/tests/test_per_token_quant_fp8.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import sgl_per_token_quant_fp8
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# Platform-appropriate FP8 dtype: ROCm exposes e4m3fnuz, CUDA e4m3fn.
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
|
||||
def torch_per_token_quant_fp8(tensor, inv_scale):
|
||||
# The reference implementation that fully aligns to
|
||||
# the kernel being tested.
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
inv_scale = inv_scale.view(-1, 1)
|
||||
scale = inv_scale.reciprocal()
|
||||
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
qweight = qweight.to(torch.float8_e4m3fn)
|
||||
return qweight
|
||||
|
||||
|
||||
def sglang_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token FP8 quantization via the sgl_kernel kernel.

    Returns:
        (quantized tensor, per-token scales as a column vector).
    """
    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)

    # The kernel fills both `output` and the per-row `scale` in place.
    sgl_per_token_quant_fp8(input, output, scale)

    return output, scale.reshape(-1, 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_tokens,hidden_dim",
    list(itertools.product([128, 256, 512], [512, 1368, 2048, 4096])),
)
def test_per_token_quant_compare_implementations(
    num_tokens: int,
    hidden_dim: int,
):
    """Kernel per-token FP8 quantization vs the pure-torch reference."""
    device = torch.device("cuda")
    x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)

    # Let the kernel compute per-token scales, then feed them to the reference.
    kernel_out, kernel_scale = sglang_per_token_quant_fp8(x)
    ref_out = torch_per_token_quant_fp8(x, kernel_scale)

    torch.testing.assert_close(
        kernel_out.float(), ref_out.float(), rtol=1e-3, atol=1e-3
    )
|
||||
|
||||
|
||||
# Allow running this test file directly: `python <file>` invokes pytest on it.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
118
sgl-kernel/tests/test_qserve_w4a8_per_chn_gemm.py
Normal file
118
sgl-kernel/tests/test_qserve_w4a8_per_chn_gemm.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import qserve_w4a8_per_chn_gemm
|
||||
|
||||
|
||||
# Adapted from https://github.com/mit-han-lab/omniserve/blob/main/omniserve/modeling/layers/quantized_linear/w4a8_linear.py
|
||||
def convert_to_qserve_format(qweight, scale, zero):
    """Repack per-channel INT4 weights into the QServe kernel layout.

    Args:
        qweight: (out_features, in_features) int tensor with values in [0, 15].
        scale: per-output-channel scales.
        zero: per-output-channel zero points.

    Returns:
        (packed weight with two 4-bit values per int8 byte,
         fp16 per-channel scales,
         fp16 per-channel scaled zeros: zero * scale).
    """
    assert qweight.min() >= 0 and qweight.max() <= 15, "Quantized weight out of range"
    in_features = qweight.shape[1]
    out_features = qweight.shape[0]
    assert in_features % 32 == 0, "Input features must be divisible by 32"
    assert out_features % 32 == 0, "Output features must be divisible by 32"

    # ---- Repack the weight ---- #
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    # Interleave 32x32 weight tiles into the register layout the tensor-core
    # kernel expects.
    qweight_unpack_reorder = (
        qweight.reshape(
            out_features // 32,
            2,
            2,
            8,
            in_features // 32,
            2,
            4,
            4,
        )
        .permute(0, 4, 3, 6, 1, 5, 2, 7)
        .contiguous()
    )
    # Move the pairing axis last so adjacent 4-bit values can be fused below.
    qweight_unpack_reorder = (
        qweight_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7, 4)
        .contiguous()
        .to(torch.int8)
    )
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    # Fuse each pair of 4-bit values into one int8 byte (high nibble | low nibble).
    qweight_unpack_repacked = (
        qweight_unpack_reorder[..., 1] << 4
    ) + qweight_unpack_reorder[..., 0]
    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
        out_features // 32, in_features // 32, 32, 16
    )
    # Two 4-bit weights per byte => packed width is in_features // 2.
    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
        out_features, in_features // 2
    ).contiguous()

    # ---- Pack the scales ---- #
    scale = scale.reshape(out_features).to(torch.float16).contiguous()
    # QServe consumes zero * scale ("szero") rather than the raw zero point.
    szero = zero.reshape(out_features).to(torch.float16).contiguous() * scale

    return qweight_unpack_repacked, scale, szero
|
||||
|
||||
|
||||
# INT4 Quantization
|
||||
def asym_quantize_tensor(tensor):
    """Asymmetric per-row INT4 quantization into the range [0, 15].

    Returns:
        (int8 tensor holding the 4-bit codes, fp16 per-row scale,
         int8 per-row zero point).
    """
    q_min, q_max = 0, 15
    row_min = tensor.min(dim=-1, keepdim=True)[0]
    row_max = tensor.max(dim=-1, keepdim=True)[0]
    scale = (row_max - row_min) / (q_max - q_min)
    zero = q_min - torch.round(row_min / scale)
    quantized = torch.clamp(torch.round(tensor / scale) + zero, q_min, q_max).to(
        torch.int8
    )
    return quantized, scale.to(torch.float16), zero.to(torch.int8)
|
||||
|
||||
|
||||
# INT8 Quantization
|
||||
def sym_quantize_tensor(tensor):
    """Symmetric per-row INT8 quantization into [-128, 127].

    Returns:
        (int8 quantized tensor, fp16 per-row scale).
    """
    scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 127
    quantized = torch.clamp(torch.round(tensor / scale), -128, 127).to(torch.int8)
    return quantized, scale.to(torch.float16)
|
||||
|
||||
|
||||
def torch_w4a8_per_chn_gemm(a, b, a_scale, b_scale, b_zero, out_dtype):
|
||||
print(a.shape)
|
||||
print(b.shape)
|
||||
print(b_zero.shape)
|
||||
o = torch.matmul(
|
||||
a.to(torch.float16), (b.to(torch.float16) - b_zero.to(torch.float16)).t()
|
||||
)
|
||||
o = o * a_scale.view(-1, 1) * b_scale.view(1, -1)
|
||||
return o.to(out_dtype)
|
||||
|
||||
|
||||
def _test_accuracy_once(M, N, K, out_dtype, device):
    """Run one W4A8 per-channel GEMM accuracy check for shape (M, N, K)."""
    # to avoid overflow, multiply 0.01
    a = torch.randn((M, K), device=device, dtype=torch.float32) * 0.01
    b = torch.randn((N, K), device=device, dtype=torch.float32) * 0.01

    # symmetric quantize a
    a_q, a_scale = sym_quantize_tensor(a)
    # asymmetric quantize b
    b_q, b_scale, b_zero = asym_quantize_tensor(b)
    # convert to qserve format
    b_q_format, b_scale_format, b_szero_format = convert_to_qserve_format(
        b_q, b_scale, b_zero
    )

    # cal sum of every row of a
    # (used by the kernel's zero-point correction term)
    a_sum = a.sum(dim=-1, keepdim=True).to(torch.float16)
    out = qserve_w4a8_per_chn_gemm(
        a_q, b_q_format, b_scale_format, a_scale, b_szero_format, a_sum
    )
    ref_out = torch_w4a8_per_chn_gemm(a_q, b_q, a_scale, b_scale, b_zero, out_dtype)
    torch.testing.assert_close(out, ref_out, rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("out_dtype", [torch.float16])
def test_accuracy(M, N, K, out_dtype):
    """Sweep GEMM shapes and check kernel vs reference on CUDA."""
    _test_accuracy_once(M, N, K, out_dtype, "cuda")
|
||||
|
||||
|
||||
# Allow running this test file directly: `python <file>` invokes pytest on it.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
183
sgl-kernel/tests/test_qserve_w4a8_per_group_gemm.py
Normal file
183
sgl-kernel/tests/test_qserve_w4a8_per_group_gemm.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import qserve_w4a8_per_group_gemm
|
||||
|
||||
|
||||
# Adapted from https://github.com/mit-han-lab/omniserve/blob/main/omniserve/modeling/layers/quantized_linear/w4a8_linear.py
|
||||
def convert_to_qserve_format(qweight, chn_scale, scale_i8, zero_i8, group_size):
    """Repack per-group INT4 weights and scales into the QServe kernel layout.

    Args:
        qweight: (out_features, in_features) int tensor with values in [0, 15].
        chn_scale: per-output-channel (fp) scale of the first quantization level.
        scale_i8: per-group int8 scales.
        zero_i8: per-group int8 zero points.
        group_size: quantization group size (must be 128).

    Returns:
        (packed weight, per-channel scales, repacked group scales,
         repacked group zeros pre-multiplied by the group scales).
    """
    assert qweight.min() >= 0 and qweight.max() <= 15, "Quantized weight out of range"
    in_features = qweight.shape[1]
    out_features = qweight.shape[0]
    assert in_features % 32 == 0, "Input features must be divisible by 32"
    assert out_features % 32 == 0, "Output features must be divisible by 32"
    assert group_size == 128, "Group size must be 128"
    assert (
        in_features % group_size == 0
    ), "Input features must be divisible by group_size"

    # ---- Repack the weight ---- #
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    # Interleave 32x32 weight tiles into the register layout the tensor-core
    # kernel expects.
    qweight_unpack_reorder = (
        qweight.reshape(
            out_features // 32,
            2,
            2,
            8,
            in_features // 32,
            2,
            4,
            4,
        )
        .permute(0, 4, 3, 6, 1, 5, 2, 7)
        .contiguous()
    )
    # Move the pairing axis last so adjacent 4-bit values can be fused below.
    qweight_unpack_reorder = (
        qweight_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7, 4)
        .contiguous()
        .to(torch.int8)
    )
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    # Fuse each pair of 4-bit values into one int8 byte (high nibble | low nibble).
    qweigth_unpack_repacked = (
        qweight_unpack_reorder[..., 1] << 4
    ) + qweight_unpack_reorder[..., 0]
    qweigth_unpack_repacked = qweigth_unpack_repacked.reshape(
        out_features // 32, in_features // 32, 32, 16
    )
    # Two 4-bit weights per byte => packed width is in_features // 2.
    qweigth_unpack_repacked = qweigth_unpack_repacked.reshape(
        out_features, in_features // 2
    )

    # ---- Pack the scales ---- #
    chn_scale = chn_scale.reshape(out_features)

    # Group scales: transpose to (groups, out_features) then interleave each
    # 32-channel block as 0, 8, 16, 24, 1, 9, ... for the tensor-core layout.
    scale_i8 = (
        scale_i8.reshape(out_features, in_features // group_size)
        .transpose(0, 1)
        .contiguous()
    )
    scale_i8 = scale_i8.reshape(in_features // group_size, out_features // 32, 32)
    scale_i8 = (
        scale_i8.reshape(in_features // group_size, out_features // 32, 4, 8)
        .transpose(-2, -1)
        .contiguous()
    )
    scale_i8 = scale_i8.reshape(in_features // group_size, out_features).contiguous()

    # ---- Pack the zeros ---- #
    # The kernel consumes negated zeros (added rather than subtracted).
    zero_i8 = -zero_i8
    # zero_i8 = zero_i8.int() # convert to 2-complement

    zero_i8 = (
        zero_i8.reshape(out_features, in_features // group_size)
        .transpose(0, 1)
        .contiguous()
    )
    zero_i8 = zero_i8.reshape(in_features // group_size, out_features // 32, 32)
    # for the last dimension, organize as 0, 8, 16, 24, 1, 9, 17, 25, ... following the requirement of tensor core gemm
    zero_i8 = (
        zero_i8.reshape(in_features // group_size, out_features // 32, 4, 8)
        .transpose(-2, -1)
        .contiguous()
    )
    # Pre-multiply by the (identically laid out) group scales.
    zero_i8 = (
        zero_i8.reshape(in_features // group_size, out_features).contiguous() * scale_i8
    )

    return qweigth_unpack_repacked, chn_scale, scale_i8, zero_i8
|
||||
|
||||
|
||||
# Progressive Group INT4 Quantization
|
||||
def progressive_group_quantize_tensor(tensor, group_size):
    """Two-level ("progressive") quantization: per-channel INT8, then
    per-group INT4 on top of the INT8 values.

    Args:
        tensor: (out_features, in_features) float weights.
        group_size: group width for the second level (must be 128).

    Returns:
        (int8 tensor of 4-bit codes, fp16 per-channel scale,
         int8 per-group scale, int8 per-group zero point).
    """
    assert group_size == 128, "Group size must be 128"
    assert (
        tensor.shape[-1] % group_size == 0
    ), "Input features must be divisible by group_size"
    # Channel scale
    # NOTE(HandH1998): use protective quantization range
    chn_scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 119
    tensor_i8 = torch.clamp(torch.round(tensor / chn_scale), -119, 119)

    # Group scale
    # NOTE(review): a constant group yields scale_i8 == 0 and a division by
    # zero below — presumably never hit with random test data; confirm.
    tensor_i8 = tensor_i8.reshape(-1, group_size)
    tensor_i8_min = tensor_i8.min(dim=-1, keepdim=True)[0]
    tensor_i8_max = tensor_i8.max(dim=-1, keepdim=True)[0]
    q_min = 0
    q_max = 15
    scale_i8 = torch.round((tensor_i8_max - tensor_i8_min) / (q_max - q_min))
    zero_i8 = q_min - torch.round(tensor_i8_min / scale_i8)
    tensor_q = (
        torch.clamp(torch.round(tensor_i8 / scale_i8) + zero_i8, q_min, q_max)
        .reshape(tensor.shape[0], -1)
        .to(torch.int8)
    )
    return (
        tensor_q,
        chn_scale.to(torch.float16),
        scale_i8.reshape(tensor.shape[0], -1).to(torch.int8),
        zero_i8.reshape(tensor.shape[0], -1).to(torch.int8),
    )
|
||||
|
||||
|
||||
# INT8 Quantization
|
||||
def sym_quantize_tensor(tensor):
    """Symmetric per-row INT8 quantization into [-128, 127].

    Returns:
        (int8 quantized tensor, fp16 per-row scale).
    """
    scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 127
    quantized = torch.clamp(torch.round(tensor / scale), -128, 127).to(torch.int8)
    return quantized, scale.to(torch.float16)
|
||||
|
||||
|
||||
def torch_w4a8_per_group_gemm(
    a, b, a_scale, b_chn_scale, b_scale_i8, b_zero_i8, group_size, out_dtype
):
    """Reference W4A8 per-group GEMM.

    Dequantizes b group-by-group (subtract the group zero, multiply by the
    group scale), multiplies in fp32, then applies the activation and
    per-channel weight scales.
    """
    assert group_size == 128, "Group size must be 128"
    # Dequantize b: each row of the grouped view shares one zero/scale pair.
    grouped = b.reshape(-1, group_size).to(torch.float32)
    zeros = b_zero_i8.reshape(-1, 1).to(torch.float32)
    scales = b_scale_i8.reshape(-1, 1).to(torch.float32)
    b_dq = ((grouped - zeros) * scales).reshape(b.shape[0], b.shape[1])

    o = torch.matmul(a.to(torch.float32), b_dq.t())
    # Row-wise activation scale times column-wise channel scale.
    o = o * a_scale.view(-1, 1) * b_chn_scale.view(1, -1)
    return o.to(out_dtype)
|
||||
|
||||
|
||||
def _test_accuracy_once(M, N, K, group_size, out_dtype, device):
    """Run one W4A8 per-group GEMM accuracy check for shape (M, N, K)."""
    # to avoid overflow, multiply 0.01
    a = torch.randn((M, K), device=device, dtype=torch.float32) * 0.01
    b = torch.randn((N, K), device=device, dtype=torch.float32) * 0.01

    # symmetric quantize a
    a_q, a_scale = sym_quantize_tensor(a)
    # asymmetric quantize b
    b_q, b_chn_scale, b_scale_i8, b_zero_i8 = progressive_group_quantize_tensor(
        b, group_size
    )
    # convert to qserve format
    b_q_format, b_chn_scale_format, b_scale_i8_format, b_zero_i8_format = (
        convert_to_qserve_format(b_q, b_chn_scale, b_scale_i8, b_zero_i8, group_size)
    )

    out = qserve_w4a8_per_group_gemm(
        a_q,
        b_q_format,
        b_zero_i8_format,
        b_scale_i8_format,
        b_chn_scale_format,
        a_scale,
    )
    ref_out = torch_w4a8_per_group_gemm(
        a_q, b_q, a_scale, b_chn_scale, b_scale_i8, b_zero_i8, group_size, out_dtype
    )
    torch.testing.assert_close(out, ref_out, rtol=1e-3, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("out_dtype", [torch.float16])
def test_accuracy(M, N, K, group_size, out_dtype):
    """Sweep GEMM shapes and check kernel vs reference on CUDA."""
    _test_accuracy_once(M, N, K, group_size, out_dtype, "cuda")
||||
|
||||
|
||||
# Allow running this test file directly: `python <file>` invokes pytest on it.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
138
sgl-kernel/tests/test_rotary_embedding.py
Normal file
138
sgl-kernel/tests/test_rotary_embedding.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
|
||||
from sgl_kernel.testing.rotary_embedding import (
|
||||
FlashInferRotaryEmbedding,
|
||||
MHATokenToKVPool,
|
||||
RotaryEmbedding,
|
||||
create_inputs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads, save_kv_cache",
    [
        # GPT-OSS cases
        *[
            (
                64,
                64,
                4096,
                8000,
                True,
                torch.bfloat16,
                "cuda",
                batch_size,
                seq_len,
                64,
                8,
                save_kv_cache,
            )
            for batch_size, seq_len in (
                (1, 1),
                (32, 1),
                (128, 1),
                (512, 1),
                (2, 512),
                (4, 4096),
            )
            for save_kv_cache in (False, True)
        ],
        # Other cases
        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1, False),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2, False),
        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False),
        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
    ],
)
def test_correctness(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    device: str,
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    save_kv_cache: bool,
):
    """Check the FlashInfer rotary-embedding CUDA path against the native one.

    Runs both implementations on identical cloned inputs and compares the
    rotated query/key outputs.  When ``save_kv_cache`` is set, the reference
    path writes the KV pool explicitly while the CUDA path uses the fused
    set-kv-buffer argument; both pools must then match.
    """
    config = dict(
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position_embeddings=max_position_embeddings,
        base=base,
        is_neox_style=is_neox_style,
        dtype=dtype,
    )

    rope_ref = RotaryEmbedding(**config).to(device)
    rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)

    inputs = create_inputs(
        head_size=head_size,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        dtype=dtype,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
    )

    if save_kv_cache:
        pool_ref = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
        pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

    # Clone so both implementations see pristine copies of the same inputs
    # (forward_cuda rotates in place).
    query_ref, key_ref = inputs["query"].clone(), inputs["key"].clone()
    query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()

    query_ref_out, key_ref_out = rope_ref.forward_native(
        inputs["pos_ids"], query_ref, key_ref
    )
    if save_kv_cache:
        # Reference path: write rotated keys / raw values into the pool explicitly.
        pool_ref.set_kv_buffer(
            loc=inputs["out_cache_loc"],
            cache_k=key_ref_out.view(-1, num_kv_heads, head_size),
            cache_v=inputs["value"].view(-1, num_kv_heads, head_size),
        )

    # CUDA path: optionally fuse the KV-cache write into the rope kernel.
    query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
        inputs["pos_ids"],
        query_flashinfer,
        key_flashinfer,
        fused_set_kv_buffer_arg=(
            FusedSetKVBufferArg(
                value=inputs["value"],
                k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
                v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
                k_scale=None,
                v_scale=None,
                cache_loc=inputs["out_cache_loc"],
            )
            if save_kv_cache
            else None
        ),
    )

    torch.testing.assert_close(
        query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)
    if save_kv_cache:
        for field in ["k_buffer", "v_buffer"]:
            x_ref = getattr(pool_ref, field)[0]
            x_flashinfer = getattr(pool_flashinfer, field)[0]
            torch.testing.assert_close(x_ref, x_flashinfer, atol=1e-2, rtol=1e-2)
            # BUG FIX: the flashinfer nonzero pattern was computed from
            # `x_ref != 0` (comparing the reference buffer against itself,
            # so the assert could never fail); compare `x_flashinfer` instead.
            nonzero_ref = x_ref != 0
            nonzero_flashinfer = x_flashinfer != 0
            assert torch.all(nonzero_ref == nonzero_flashinfer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow invoking this test module directly (outside a pytest session).
    pytest.main(args=[__file__])
|
||||
185
sgl-kernel/tests/test_sampling.py
Normal file
185
sgl-kernel/tests/test_sampling.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py
|
||||
|
||||
import pytest
|
||||
import sgl_kernel
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_top_k_first_alignment(
    batch_size, vocab_size, k, p
):
    """The logits- and probs-based samplers must agree in top_k_first order.

    Both kernels are driven by generators sharing one cloned RNG state, so
    the drawn token ids are directly comparable.
    """
    torch.manual_seed(42)
    logits = 5 * torch.randn(batch_size, vocab_size, device="cuda:0")
    probs = torch.softmax(logits, dim=-1)

    # Identical RNG state for both kernels.
    gen_from_logits = torch.Generator("cuda:0")
    gen_from_probs = gen_from_logits.clone_state()

    drawn_from_logits = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="top_k_first", generator=gen_from_logits
    )
    drawn_from_probs = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
        probs,
        k,
        p,
        filter_apply_order="top_k_first",
        generator=gen_from_probs,
    )
    assert torch.equal(drawn_from_logits, drawn_from_probs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_joint_alignment(
    batch_size, vocab_size, k, p
):
    """The logits- and probs-based samplers must agree in joint filter order.

    Both kernels are driven by generators sharing one cloned RNG state, so
    the drawn token ids are directly comparable.
    """
    torch.manual_seed(42)
    logits = 5 * torch.randn(batch_size, vocab_size, device="cuda:0")
    probs = torch.softmax(logits, dim=-1)

    # Identical RNG state for both kernels.
    gen_from_logits = torch.Generator("cuda:0")
    gen_from_probs = gen_from_logits.clone_state()

    drawn_from_logits = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="joint", generator=gen_from_logits
    )
    drawn_from_probs = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
        probs,
        k,
        p,
        filter_apply_order="joint",
        generator=gen_from_probs,
    )
    assert torch.equal(drawn_from_logits, drawn_from_probs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
    """Joint top-k/top-p sampling must only draw tokens surviving both filters.

    A reference admissibility mask is built on the host (intersection of the
    top-p nucleus mask and the top-k mask) and every sampled token over many
    trials is checked against it.
    """
    torch.manual_seed(42)
    # Pair each p with a complementary k fraction, as in the original cases.
    k_fraction = {0.1: 0.5, 0.5: 0.1}
    if p not in k_fraction:
        raise ValueError("p not recognized")
    k = int(vocab_size * k_fraction[p])
    eps = 1e-4

    raw = torch.rand(batch_size, vocab_size, device="cuda:0")
    probs = raw / raw.sum(dim=-1, keepdim=True)

    # Reference top-p mask: tokens whose ascending CDF reaches past 1 - p.
    asc_prob, asc_idx = torch.sort(probs, descending=False)
    cdf = torch.cumsum(asc_prob, dim=-1)
    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
    mask_top_p.scatter_add_(1, asc_idx, (cdf > (1 - p) - eps).int())

    # Reference top-k mask: tokens at least as likely as the k-th largest.
    desc_prob, _ = torch.sort(probs, descending=True)
    kth_prob = desc_prob[:, k - 1]
    mask_top_k = (probs >= kth_prob.unsqueeze(-1)).int()

    # A token is admissible only if it passes both filters.
    allowed = torch.minimum(mask_top_p, mask_top_k)

    top_p_tensor = torch.full((batch_size,), p, device="cuda:0")
    top_k_tensor = torch.full((batch_size,), k, device="cuda:0")

    num_trails = 1000
    for _ in range(num_trails):
        samples = sgl_kernel.top_k_top_p_sampling_from_probs(
            probs,
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        assert torch.all(allowed[torch.arange(batch_size), samples] == 1), probs[
            torch.arange(batch_size), samples
        ]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
def test_top_p_renorm_probs(batch_size, vocab_size, p):
    """top_p_renorm_prob must zero tokens outside the nucleus and renormalize.

    The ground truth keeps exactly the tokens whose ascending CDF reaches
    1 - p, zeroes the rest, and renormalizes the survivors to sum to one.
    """
    torch.manual_seed(42)
    raw = torch.rand(batch_size, vocab_size, device="cuda:0")
    probs = raw / raw.sum(dim=-1, keepdim=True)

    # Nucleus membership computed from the ascending cumulative distribution.
    asc_prob, asc_idx = torch.sort(probs, descending=False)
    cdf = torch.cumsum(asc_prob, dim=-1)
    keep = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
    keep.scatter_add_(1, asc_idx, (cdf >= (1 - p)).int())

    expected = probs.clone()
    expected[keep == 0] = 0
    expected = expected / expected.sum(dim=-1, keepdim=True)

    actual = sgl_kernel.top_p_renorm_prob(probs, p)
    torch.testing.assert_close(
        expected,
        actual,
        rtol=1e-3,
        atol=1e-3,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_renorm_probs(batch_size, vocab_size, k):
    """top_k_renorm_prob must zero all but the top-k tokens and renormalize."""
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    raw = torch.rand(batch_size, vocab_size, device="cuda:0")
    probs = raw / raw.sum(dim=-1, keepdim=True)

    # Ground truth: keep tokens at least as likely as the k-th largest one.
    desc_prob, _ = torch.sort(probs, descending=True)
    kth_prob = desc_prob[:, k - 1]
    keep = (probs >= kth_prob.unsqueeze(-1)).int()

    expected = probs.clone()
    expected[keep == 0] = 0
    expected = expected / expected.sum(dim=-1, keepdim=True)

    actual = sgl_kernel.top_k_renorm_prob(probs, k)
    # Compare row by row so a failure pinpoints the offending batch entry.
    for row in range(batch_size):
        torch.testing.assert_close(
            expected[row],
            actual[row],
            rtol=1e-3,
            atol=1e-3,
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
def test_min_p_sampling(batch_size, vocab_size, p):
    """Every sample from min_p_sampling_from_probs must satisfy the min-p rule.

    A token is admissible when its probability is at least ``p`` times the
    row's maximum probability; the kernel is sampled many times and each draw
    is checked against a host-built admissibility mask.
    """
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    # scale min-p by each row's top probability
    top_probs = sorted_prob[:, -1].unsqueeze(-1)
    scaled_p = p * top_probs
    # min-p mask: tokens with probability >= scaled threshold are admissible
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
    mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
    min_p_tensor = torch.full((batch_size,), p, device="cuda:0")

    num_trails = 1000
    for _ in range(num_trails):
        samples = sgl_kernel.min_p_sampling_from_probs(
            normalized_prob,
            min_p_tensor,
        )

        # FIX: this assertion block appeared twice back-to-back in the
        # original; the verbatim duplicate has been removed.
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
            torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
        ]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow invoking this test module directly (outside a pytest session).
    pytest.main(args=[__file__])
|
||||
492
sgl-kernel/tests/test_sparse_flash_attn.py
Normal file
492
sgl-kernel/tests/test_sparse_flash_attn.py
Normal file
@@ -0,0 +1,492 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from sgl_kernel.sparse_flash_attn import (
|
||||
convert_vertical_slash_indexes,
|
||||
convert_vertical_slash_indexes_mergehead,
|
||||
sparse_attn_func,
|
||||
)
|
||||
from test_flash_attention import construct_local_mask, is_fa3_supported
|
||||
|
||||
|
||||
def ref_attn(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Pure-PyTorch reference attention used as ground truth by the sparse tests.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        lse: (batch_size, nheads, seqlen_q)
    """
    # Causal masking is expressed as a local window with no look-ahead.
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # Expand grouped KV heads so K/V match the query head count (GQA/MQA).
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    # Scale q (default) or k (reorder_ops=True) by 1/sqrt(d); mathematically
    # equal, used to probe sensitivity to operation order.
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))

    # NOTE(review): lse is taken from the raw scores, before softcap and the
    # padding/window masks are applied below — callers appear to compare it
    # only in fully unmasked configurations; confirm if masks are ever used.
    lse_ref = scores.logsumexp(dim=-1)

    if softcap > 0:
        # Soft capping: softcap * tanh(scores / softcap).
        scores = scores / softcap
        scores = scores.tanh()
        scores = scores * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
        )
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(
            torch.all(local_mask, dim=-1, keepdim=True), 0.0
        )
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
        )
    # dropout_scaling is 1.0 when dropout_p == 0, so the no-dropout path is exact.
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)

    return output.to(dtype=dtype_og), lse_ref
|
||||
|
||||
|
||||
def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Reference paged attention computed per sequence on gathered KV blocks.

    ``query`` packs all sequences along dim 0 (lengths in ``query_lens``);
    ``key_cache``/``value_cache`` are [num_blocks, block_size, kv_heads, dim]
    and ``block_tables[i]`` lists the cache blocks of sequence ``i``.
    Applies causal masking, an optional sliding window, and optional
    tanh soft-capping of the logits.
    """
    tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    per_seq_out: List[torch.Tensor] = []
    q_start = 0
    for seq_idx, (q_len, kv_len) in enumerate(zip(query_lens, kv_lens)):
        # Clone so scaling does not clobber the caller's query tensor.
        q = query[q_start : q_start + q_len].clone()
        q *= scale

        # Gather just enough cache blocks to cover kv_len tokens.
        n_blocks = -(-kv_len // block_size)  # ceil division
        picked = tables[seq_idx, :n_blocks]
        k = key_cache[picked].view(-1, num_kv_heads, head_size)[:kv_len]
        v = value_cache[picked].view(-1, num_kv_heads, head_size)[:kv_len]

        # Expand grouped KV heads to match the query head count.
        if q.shape[1] != k.shape[1]:
            group = q.shape[1] // k.shape[1]
            k = torch.repeat_interleave(k, group, dim=1)
            v = torch.repeat_interleave(v, group, dim=1)

        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # Causal mask: query i may attend keys up to kv_len - q_len + i.
        ones = torch.ones(q_len, kv_len)
        mask = torch.triu(ones, diagonal=kv_len - q_len + 1).bool()
        if sliding_window is not None:
            # Also forbid keys that fall out of the sliding window.
            outside_window = (
                torch.triu(ones, diagonal=kv_len - (q_len + sliding_window) + 1)
                .bool()
                .logical_not()
            )
            mask |= outside_window
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        per_seq_out.append(torch.einsum("hqk,khd->qhd", attn, v))
        q_start += q_len

    return torch.cat(per_seq_out, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize(
    "seq_lens",
    [
        (1, 1),
        (1, 1024),
        (1, 2048),
        (1023, 2049),
        (1023, 1023),
        (32, 32),
        (65, 65),
        (129, 129),
    ],
)
@pytest.mark.parametrize("num_heads", [1, 2, 4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32])
@torch.inference_mode()
def test_sparse_attention(
    batch_size,
    seq_lens,
    num_heads,
    head_size,
    dtype,
    NNZ_S,
) -> None:
    """Dense-equivalence check of sparse_attn_func.

    The index metadata is constructed so every key position is covered
    (``NNZ_S`` slash blocks plus ``NNZ_V`` explicit columns), so the sparse
    kernel must reproduce full dense attention from ``ref_attn``.
    """
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    block_size_M = 64
    block_size_N = 64
    seqlen_q, seqlen_k = seq_lens
    q = torch.randn(
        batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    k = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    v = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
    if NNZ_S * block_size_N > seqlen_k:
        # Not enough key positions for the requested slash blocks; skip combo.
        return
    # Remaining key positions are enumerated as explicit columns.
    NNZ_V = seqlen_k - NNZ_S * block_size_N
    block_count = torch.tensor(
        [NNZ_S] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
    ).reshape(batch_size, num_heads, NUM_ROWS)
    column_count = torch.tensor(
        [NNZ_V] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
    ).reshape(batch_size, num_heads, NUM_ROWS)
    block_offset = torch.tensor(
        [[i * block_size_N for i in range(NNZ_S)]] * batch_size * NUM_ROWS * num_heads,
        dtype=torch.int32,
    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
    column_index = torch.tensor(
        [[NNZ_S * block_size_N + i for i in range(NNZ_V)]]
        * batch_size
        * NUM_ROWS
        * num_heads,
        dtype=torch.int32,
    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        return_softmax_lse=True,
    )

    ref_out, ref_lse = ref_attn(q, k, v)

    # FIX: the original wrote `assert_close(...), f"..."`, which builds a
    # dead tuple — the diagnostic string was never attached to a failure.
    # Attach it lazily through assert_close's `msg=` callable instead.
    torch.testing.assert_close(
        out,
        ref_out,
        atol=2e-2,
        rtol=1e-2,
        msg=lambda default: f"{default}\nmax abs diff: {torch.max(torch.abs(out - ref_out))}",
    )
    torch.testing.assert_close(
        lse,
        ref_lse,
        atol=2e-2,
        rtol=1e-2,
        msg=lambda default: f"{default}\nmax abs diff: {torch.max(torch.abs(lse - ref_lse))}",
    )
|
||||
|
||||
|
||||
# sparse attention utils
|
||||
# origin
|
||||
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes(causal):
    """Smoke-check convert_vertical_slash_indexes on a tiny hand-built case.

    One sequence of 4 tokens with 2x2 blocks gives two row blocks; the
    resulting column_index tensor is compared against golden values recorded
    for causal and non-causal modes.
    """
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    # [BATCH, N_HEADS, NNZ_V] vertical columns and [BATCH, N_HEADS, NNZ_S] slashes.
    vertical_indexes = torch.tensor([[[1, 3]]], dtype=torch.int32, device="cuda")
    slash_indexes = torch.tensor([[[2]]], dtype=torch.int32, device="cuda")
    context_size = 4
    block_size_M = 2
    block_size_N = 2

    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            context_size,
            block_size_M,
            block_size_N,
            causal=causal,
        )
    )

    # Golden column indices per (row block, slot).  In causal mode the
    # vertical columns are absorbed into slash blocks for these rows; in
    # non-causal mode they stay as explicit columns.
    if causal:
        expected_column_index = torch.tensor(
            [[[[0, 0], [0, 0]]]], dtype=torch.int32, device="cuda"
        )
    else:
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]]]], dtype=torch.int32, device="cuda"
        )

    assert torch.equal(column_index, expected_column_index)
|
||||
|
||||
|
||||
# mergehead
|
||||
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes_mergehead(causal):
    """Smoke test for the mergehead variant on a tiny two-head example."""
    # Prepare small, hand-checkable inputs for mergehead version
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [
            [
                [1, 3],  # head 0
                [2, 0],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [
            [
                [2, 0],  # head 0
                [1, 3],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_S]
    # Per-head counts of valid entries: head 0 uses 2 verticals / 1 slash,
    # head 1 uses 1 vertical / 2 slashes.
    vertical_indices_count = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
    slash_indices_count = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
    context_size = 4
    block_size_M = 2
    block_size_N = 2

    # Call your CUDA kernel wrapper
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes_mergehead(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            vertical_indices_count,
            slash_indices_count,
            context_size,
            block_size_M,
            block_size_N,
            causal=causal,
        )
    )

    # Manually create expected outputs for this input
    # For demonstration, assume:
    # - batch=1, head=2, num_rows=2, nnz_v=2, nnz_s=2
    # Fill these expected tensors according to your kernel's behavior

    # NOTE(review): the large negative entries for head 1 look like
    # uninitialized-buffer contents rather than meaningful column indices.
    # If the kernel does not guarantee deterministic padding of unused
    # slots, comparing against these recorded values makes this test
    # flaky — confirm against the kernel's output contract.
    expected_column_index = torch.tensor(
        [[[[1, 0], [1, 3]], [[-1079459945, -1077788999], [-1080050043, -1104625879]]]],
        dtype=torch.int32,
        device="cuda",
    )

    if not causal:
        # If non-causal mode output is different, update these values
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]], [[2, -1077788999], [2, -1104625879]]]],
            dtype=torch.int32,
            device="cuda",
        )

    # Assert that outputs match expectations
    assert torch.equal(column_index, expected_column_index)
|
||||
|
||||
|
||||
# skip cause use fa2 for test
|
||||
# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
|
||||
# [(1024, 1328), (1, 2048)],
|
||||
# [(1025, 1328), (2, 2048)],
|
||||
# [(1025, 2049), (2, 1281)],
|
||||
# ])
|
||||
# @pytest.mark.parametrize("head_size", [128])
|
||||
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
# @torch.inference_mode()
|
||||
# def test_sparse_attention_varlen(
|
||||
# seq_lens,
|
||||
# head_size,
|
||||
# dtype,
|
||||
# ) -> None:
|
||||
# torch.set_default_device("cuda")
|
||||
# torch.cuda.manual_seed_all(0)
|
||||
# block_size_M = 64
|
||||
# block_size_N = 64
|
||||
# num_seqs = len(seq_lens)
|
||||
# query_lens = [x[0] for x in seq_lens]
|
||||
# kv_lens = [x[1] for x in seq_lens]
|
||||
# num_heads = 1
|
||||
# query = torch.randn(sum(query_lens),
|
||||
# num_heads,
|
||||
# head_size,
|
||||
# dtype=dtype)
|
||||
# key = torch.randn(sum(kv_lens),
|
||||
# num_heads,
|
||||
# head_size,
|
||||
# dtype=dtype)
|
||||
# value = torch.randn_like(key)
|
||||
# cu_query_lens = torch.tensor([0] + query_lens,
|
||||
# dtype=torch.int32).cumsum(dim=0,
|
||||
# dtype=torch.int32)
|
||||
# cu_kv_lens = torch.tensor([0] + kv_lens,
|
||||
# dtype=torch.int32).cumsum(dim=0,
|
||||
# dtype=torch.int32)
|
||||
# max_query_len = max(query_lens)
|
||||
# max_kv_len = max(kv_lens)
|
||||
|
||||
# NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M
|
||||
# NNZ_S = 20
|
||||
# NNZ_V = 2048
|
||||
# batch_size = len(query_lens)
|
||||
|
||||
# block_counts = []
|
||||
# column_counts = []
|
||||
# block_offsets = []
|
||||
# column_indices = []
|
||||
# for b in range(batch_size):
|
||||
# block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
|
||||
# columns = kv_lens[b] - NNZ_S * block_size_N
|
||||
# column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
|
||||
# block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S))
|
||||
# column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V))
|
||||
# block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS)
|
||||
# column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS)
|
||||
# block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
|
||||
# column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
|
||||
# out, lse = sparse_attn_varlen_func(
|
||||
# query,
|
||||
# key,
|
||||
# value,
|
||||
# block_count,
|
||||
# block_offset,
|
||||
# column_count,
|
||||
# column_index,
|
||||
# cu_seqlens_q=cu_query_lens,
|
||||
# cu_seqlens_k=cu_kv_lens,
|
||||
# max_seqlen_q=max_query_len,
|
||||
# max_seqlen_k=max_kv_len,
|
||||
# return_softmax_lse=True,
|
||||
# )
|
||||
|
||||
# max_num_blocks_per_seq = (max_kv_len + 2048 - 1) // 2048
|
||||
# block_tables = torch.randint(0,
|
||||
# 2048,
|
||||
# (len(query_lens), max_num_blocks_per_seq),
|
||||
# dtype=torch.int32)
|
||||
# scale = head_size**-0.5
|
||||
|
||||
# ref_out, ref_lse, _ = ref_paged_attn(
|
||||
# query,
|
||||
# key,
|
||||
# value,
|
||||
# query_lens=query_lens,
|
||||
# kv_lens=kv_lens,
|
||||
# block_tables=block_tables,
|
||||
# scale=scale
|
||||
# )
|
||||
|
||||
# torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
|
||||
# f"{torch.max(torch.abs(out - ref_out))}"
|
||||
# torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
|
||||
# f"{torch.max(torch.abs(lse - ref_lse))}"
|
||||
|
||||
if __name__ == "__main__":
    # Allow invoking this test module directly (outside a pytest session).
    pytest.main(args=[__file__])
|
||||
9
sgl-kernel/tests/utils.py
Normal file
9
sgl-kernel/tests/utils.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import torch
|
||||
|
||||
|
||||
def is_sm10x():
    """Whether the current CUDA device reports compute capability >= 10.0.

    NOTE(review): this also matches capabilities beyond the 10.x family
    (e.g. a future 11.0); confirm that ">= (10, 0)" rather than
    "major == 10" is the intended check.
    """
    capability = torch.cuda.get_device_capability()
    return capability >= (10, 0)
|
||||
|
||||
|
||||
def is_hopper():
    """Whether the current CUDA device is exactly SM 9.0 (Hopper)."""
    capability = torch.cuda.get_device_capability()
    return capability == (9, 0)
|
||||
Reference in New Issue
Block a user