sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1,25 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import create_greenctx_stream_by_value, get_sm_available
def test_green_ctx():
    """Smoke-test green-context streams: matmuls launched on two
    SM-partitioned streams must match a default-stream matmul.
    """
    A = torch.randn(5120, 5120).cuda()
    B = torch.randn(5120, 5120).cuda()
    # Reference result computed on the default stream.
    C = torch.matmul(A, B)
    sm_counts = get_sm_available(0)  # total SMs available on device 0
    # Split the device into two green contexts, each with half the SMs.
    stream_group = create_greenctx_stream_by_value(sm_counts // 2, sm_counts // 2, 0)
    with torch.cuda.stream(stream_group[0]):
        for _ in range(100):
            result_0 = torch.matmul(A, B)
    with torch.cuda.stream(stream_group[1]):
        for _ in range(100):
            result_1 = torch.matmul(A, B)
    # Wait for both partitioned streams before comparing.
    torch.cuda.synchronize()
    assert torch.allclose(result_0, C)
    assert torch.allclose(result_1, C)


if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,87 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import verify_tree_greedy
def test_verify_tree_greedy():
    """Greedy tree verification must accept exactly the draft tokens that
    match the target model's argmax along each request's tree path.
    """
    # Draft token ids per request, in tree-node order (bs=2, 6 nodes each).
    candidates = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [7, 8, 9, 10, 11, 12],
        ],
        dtype=torch.int64,
        device="cuda",
    )
    # Flat position of each tree node in the packed token buffer.
    retrive_index = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [6, 7, 8, 9, 10, 11],
        ],
        dtype=torch.int64,
        device="cuda",
    )
    # First-child pointer per node (-1 = leaf).
    retrive_next_token = torch.tensor(
        [
            [1, 2, -1, 4, 5, -1],
            [4, 2, 3, -1, 5, -1],
        ],
        dtype=torch.int64,
        device="cuda",
    )
    # Next-sibling pointer per node (-1 = last sibling).
    retrive_next_sibling = torch.tensor(
        [
            [-1, 3, -1, -1, -1, -1],
            [-1, -1, -1, -1, 1, -1],
        ],
        dtype=torch.int64,
        device="cuda",
    )
    # Plant a dominant logit (10 vs background 1) along the intended path.
    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
    target_logits[0, 0, 3] = 10
    target_logits[0, 3, 4] = 10
    target_logits[0, 4, 5] = 10
    target_logits[1, 0, 11] = 10
    target_logits[1, 4, 12] = 10
    # Positions without a planted winner fall back to token 18.
    for i in range(target_logits.shape[0]):
        for j in range(target_logits.shape[1]):
            if torch.max(target_logits[i][j]) < 10:
                target_logits[i][j][18] = 10
    target_predict = torch.argmax(target_logits, dim=-1)
    predict_shape = (12,)
    bs = candidates.shape[0]
    num_spec_step = 4
    # Output buffers, filled in place by the kernel.
    predicts = torch.full(
        predict_shape, -1, dtype=torch.int32, device="cuda"
    )  # mutable
    accept_index = torch.full(
        (bs, num_spec_step), -1, dtype=torch.int32, device="cuda"
    )  # mutable
    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda")  # mutable
    verify_tree_greedy(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        target_predict=target_predict,
    )
    # Check the expected output.
    assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
    assert accept_index.tolist() == [
        [0, 3, 4, 5],
        [6, 10, 11, -1],
    ]
    assert accept_token_num.tolist() == [3, 2]


if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,129 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import tree_speculative_sampling_target_only
# Each case: (threshold_single, threshold_acc,
#             expected_predicts, expected_accept_index, expected_accept_token_num)
test_cases = [
    (
        1,
        1,
        [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],
        [[0, 3, 4, 5], [6, 10, 11, -1]],
        [3, 2],
    ),
    (
        0,  # threshold_single
        0,  # threshold_acc
        [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
        [[0, 1, 2, -1], [6, 10, 11, -1]],
        [2, 2],
    ),
]
@pytest.mark.parametrize(
    "threshold_single, threshold_acc, expected_predicts, expected_accept_index, expected_accept_token_num",
    test_cases,
)
def test_tree_speculative_sampling_target_only(
    threshold_single,
    threshold_acc,
    expected_predicts,
    expected_accept_index,
    expected_accept_token_num,
):
    """
    Tests the tree_speculative_sampling_target_only function using Pytest parameterization.
    """
    device = "cuda"
    # Draft tokens in tree-node order for each of the two requests.
    candidates = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [7, 8, 9, 10, 11, 12],
        ],
        dtype=torch.int64,
        device=device,
    )
    # Flat position of each tree node in the packed batch.
    retrive_index = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [6, 7, 8, 9, 10, 11],
        ],
        dtype=torch.int64,
        device=device,
    )
    # First-child pointer per node (-1 = leaf).
    retrive_next_token = torch.tensor(
        [
            [1, 2, -1, 4, 5, -1],
            [4, 2, 3, -1, 5, -1],
        ],
        dtype=torch.int64,
        device=device,
    )
    # Next-sibling pointer per node (-1 = last sibling).
    retrive_next_sibling = torch.tensor(
        [
            [-1, 3, -1, -1, -1, -1],
            [-1, -1, -1, -1, 1, -1],
        ],
        dtype=torch.int64,
        device=device,
    )
    # Plant a dominant logit (10 vs background 1) along the intended path.
    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device=device)
    target_logits[0, 0, 3] = 10
    target_logits[0, 3, 4] = 10
    target_logits[0, 4, 5] = 10
    target_logits[1, 0, 11] = 10
    target_logits[1, 4, 12] = 10
    # Positions without a planted winner fall back to token 18.
    for i in range(target_logits.shape[0]):
        for j in range(target_logits.shape[1]):
            if torch.max(target_logits[i, j]) < 10:
                target_logits[i, j, 18] = 10
    # Near-zero temperature makes the softmax effectively one-hot (greedy).
    temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device=device)
    bs, num_draft_tokens = candidates.shape
    num_spec_step = len(expected_accept_index[0])
    predict_shape = (len(expected_predicts),)
    # Output buffers, mutated in place by the kernel.
    predicts = torch.full(predict_shape, -1, dtype=torch.int32, device=device)
    accept_index = torch.full((bs, num_spec_step), -1, dtype=torch.int32, device=device)
    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device=device)
    expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1)
    target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
    draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
    # Uniform coins drive accept/reject sampling; deterministic=True makes
    # the kernel reproducible for a fixed set of coins.
    coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)
    coins_for_final_sampling = torch.rand(bs, device=device).to(torch.float32)
    tree_speculative_sampling_target_only(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        uniform_samples=coins,
        uniform_samples_for_final_sampling=coins_for_final_sampling,
        target_probs=target_probs,
        draft_probs=draft_probs,
        threshold_single=threshold_single,
        threshold_acc=threshold_acc,
        deterministic=True,
    )
    assert (
        predicts.tolist() == expected_predicts
    ), f"Predicts mismatch for thresholds ({threshold_single}, {threshold_acc})"
    assert (
        accept_index.tolist() == expected_accept_index
    ), f"Accept index mismatch for thresholds ({threshold_single}, {threshold_acc})"
    assert (
        accept_token_num.tolist() == expected_accept_token_num
    ), f"Accept token num mismatch for thresholds ({threshold_single}, {threshold_acc})"


if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,39 @@
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py
import pytest
import sgl_kernel
import torch
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_silu_mul(dim, batch_size, seq_len):
    """silu_and_mul(x) must equal x[..., dim:] * silu(x[..., :dim])."""
    # .to(0) places the tensor on CUDA device 0.
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim])
    y = sgl_kernel.silu_and_mul(x)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
    """gelu_tanh_and_mul(x) must match the tanh-approximated GELU gate."""
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
    y = sgl_kernel.gelu_tanh_and_mul(x)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_mul(dim, batch_size, seq_len):
    """gelu_and_mul(x) must match the exact (erf-based) GELU gate."""
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none")
    y = sgl_kernel.gelu_and_mul(x)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,23 @@
import pytest
import torch
from sgl_kernel import apply_token_bitmask_inplace_cuda
def test_apply_token_bitmask_inplace_kernel():
    """Logits whose bit is 0 in the packed bitmask must become -inf in place."""
    neginf = float("-inf")
    # Token i is kept iff bit i of the packed bitmask is set.
    bool_mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
    logits = torch.tensor(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], dtype=torch.float32
    )
    expected = torch.where(bool_mask, logits, neginf)
    logits_gpu = logits.to("cuda")
    # 0b1010101010 packs bool_mask with bit i corresponding to token i.
    bitmask = torch.tensor([0b1010101010], dtype=torch.int32).to("cuda")
    apply_token_bitmask_inplace_cuda(logits_gpu, bitmask)
    torch.cuda.synchronize()
    torch.testing.assert_close(logits_gpu, expected.to("cuda"))


if __name__ == "__main__":
    test_apply_token_bitmask_inplace_kernel()
    pytest.main([__file__])

View File

@@ -0,0 +1,115 @@
import itertools
from typing import Optional, Tuple
import pytest
import torch
from sgl_kernel import awq_dequantize
def reverse_awq_order(t: torch.Tensor):
    """Undo AWQ's interleaved nibble packing along the last dimension.

    AWQ stores the eight 4-bit values of each 32-bit word in the column
    order [0, 4, 1, 5, 2, 6, 3, 7]; this permutes every group of eight
    columns back to natural order and keeps only the low nibble of each
    entry.
    """
    nibble_order = [0, 4, 1, 5, 2, 6, 3, 7]
    pack_factor = 32 // 4  # eight 4-bit nibbles per int32 word
    columns = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    permutation = columns.view(-1, pack_factor)[:, nibble_order].reshape(-1)
    return t[:, permutation] & 0xF
# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
def awq_dequantize_torch(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
) -> torch.Tensor:
    """Pure-PyTorch AWQ dequantization reference.

    Unpacks eight 4-bit values from each int32 of qweight/qzeros, undoes
    AWQ's interleaved nibble order, then applies (w - zero) * scale with
    scales/zeros broadcast over each group of ``group_size`` rows.
    ``group_size == -1`` means a single group spanning all rows.
    """
    if group_size == -1:
        group_size = qweight.shape[0]
    bits = 4
    # Shift amounts 0, 4, ..., 28 extract the eight nibbles of each int32.
    shifts = torch.arange(0, 32, bits, device=qzeros.device)
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    iweights = iweights.view(iweights.shape[0], -1)
    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    zeros = zeros.view(qzeros.shape[0], -1)
    # Restore natural column order before applying zero-point/scale.
    zeros = reverse_awq_order(zeros)
    iweights = reverse_awq_order(iweights)
    # Keep only the low nibble of every unpacked value.
    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)
    # Expand per-group scales/zeros to one row per weight row.
    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales
def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Thin wrapper over the sgl_kernel CUDA AWQ dequantize kernel."""
    dequantized = awq_dequantize(qweight, scales, qzeros)
    return dequantized
@pytest.mark.parametrize(
    "qweight_row,qweight_col,is_bf16_act",
    list(
        itertools.product(
            [3584, 18944, 128, 256, 512, 1024, 1536],
            [448, 576, 4736, 16, 32, 64, 128, 72],
            [True, False],
        )
    ),
)
def test_awq_dequant_compare_implementations(
    qweight_row: int, qweight_col: int, is_bf16_act: bool
):
    """CUDA awq_dequantize must match the pure-PyTorch reference for random
    packed weights, with both fp16 and bf16 scales."""
    device = torch.device("cuda")
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    # One quantization group spanning all rows.
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8  # eight nibbles unpacked per int32 column
    if is_bf16_act:
        scales = torch.rand(scales_row, scales_col, dtype=torch.bfloat16, device=device)
    else:
        scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    # Run both implementations
    torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
    # Compare results
    torch.testing.assert_close(
        torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )


if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,43 @@
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import bmm_fp8
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype):
    """FP8 batched matmul must stay directionally close (cosine similarity)
    to the bf16 reference despite quantization error."""
    if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2:
        pytest.skip("Invalid combination: both input and mat2 are e5m2")
    input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
    input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)
    # mat2 row major -> column major
    mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(
        -2, -1
    )
    mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype)
    res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype)
    bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res)
    reference = torch.bmm(input, mat2)
    # Elementwise tolerances are too strict for fp8; compare directions.
    cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
    assert cos_sim > 0.99


if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,489 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/mamba/test_causal_conv1d.py
from typing import Optional
import torch
from sgl_kernel import causal_conv1d_fwd
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
PAD_SLOT_ID = -1
def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    query_start_loc: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into sequence. prepended by 0.
        for example: query_start_loc = torch.Tensor([0,10,16,17]),
        x.shape=(dim,17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        indicates whether should the kernel take the current state as initial
        state for the calculations
    conv_states: (...,dim,width - 1) itype
        updated inplace if provided
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # The CUDA kernel requires a contiguous innermost dimension.
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None
    causal_conv1d_fwd(
        x,
        weight,
        bias,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        activation in ["silu", "swish"],
        pad_slot_id,
    )
    # NOTE(review): causal_conv1d_fwd appears to write its output into x in
    # place (x is returned as the result) -- confirm against the kernel.
    return x
def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state
        starting at the index
        @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError(
            f"activation must be None, silu, or swish, actual: {activation}"
        )
    # The kernel takes a bool "apply activation" flag, not the name.
    activation_val = activation in ["silu", "swish"]
    # Promote a (batch, dim) single-token update to (batch, dim, 1).
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    causal_conv1d_update_kernel(
        x,
        conv_state,
        weight,
        bias,
        activation_val,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
    if unsqueeze:
        x = x.squeeze(-1)
    # NOTE(review): the kernel appears to write its output into x in place
    # (x is returned as the result) -- confirm against the kernel.
    return x
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch
import torch.nn.functional as F
def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    Pure-PyTorch reference for the causal depthwise 1-D convolution.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        # Causal left padding via the conv itself; depthwise (groups=dim).
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        # Prepend the carried-over state instead of zero padding.
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        # The last (width - 1) inputs become the next state; a negative pad
        # amount crops from the left when x is longer than width - 1.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in
        )  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_update_ref(
    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
):
    """
    Pure-PyTorch reference for the decode-time conv update; mutates
    conv_state in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the
        conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.
    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    # Treat a single-token (batch, dim) update as seqlen == 1.
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    if cache_seqlens is None:
        # Linear buffer: append x, keep the trailing state_len entries.
        x_new = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype
        )  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # Circular buffer: gather the (width - 1) entries preceding each
        # cache position, then scatter x back in modulo state_len.
        width_idx = torch.arange(
            -(width - 1), 0, dtype=torch.long, device=x.device
        ).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = (
            torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        )
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(
            0
        ) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, copy_idx, x)
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
        :, :, -seqlen:
    ]
    if unsqueeze:
        out = out.squeeze(-1)
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    "seqlen", [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096]
)
@pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize("batch", [1])
def test_causal_conv1d(
    batch, dim, seqlen, width, has_bias, silu_activation, has_initial_state, itype
):
    """Compare the CUDA causal_conv1d_fn against the PyTorch reference."""
    device = "cuda"
    # Looser tolerances for lower-precision dtypes.
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype).contiguous()
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    if has_initial_state:
        initial_states = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
        has_initial_state_tensor = torch.ones(batch, dtype=torch.bool, device=x.device)
    else:
        initial_states = None
        has_initial_state_tensor = None
    # Clone inputs: the kernel may update x / initial_states in place.
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    initial_states_ref = initial_states.clone() if initial_states is not None else None
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_fn(
        x,
        weight,
        bias,
        activation=activation,
        conv_states=initial_states,
        has_initial_state=has_initial_state_tensor,
    )
    out_ref, final_states_ref = causal_conv1d_ref(
        x_ref,
        weight_ref,
        bias_ref,
        initial_states=initial_states_ref,
        return_final_states=True,
        activation=activation,
    )
    if has_initial_state:
        # The kernel writes the final conv state back into initial_states.
        assert initial_states is not None and final_states_ref is not None
        assert torch.allclose(initial_states, final_states_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
    """Single-step decode update: CUDA kernel vs PyTorch reference, including
    the in-place conv_state update."""
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    batch = 2
    x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
    x_ref = x.clone()
    conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
    out_ref = causal_conv1d_update_ref(
        x_ref, conv_state_ref, weight, bias, activation=activation
    )
    # Both the updated state (exact) and the output (tolerance) must match.
    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(
    with_padding, dim, width, seqlen, has_bias, silu_activation, itype
):
    """Decode update with conv_state_indices gather plus PAD_SLOT_ID entries
    that the kernel must leave untouched."""
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size
    x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
    x_ref = x.clone()
    # Random state slots for the real sequences within the larger cache.
    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[conv_state_indices] = False
    # Append PAD_SLOT_ID entries that the kernel must skip.
    padded_state_indices = torch.concat(
        [
            conv_state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )
    conv_state = torch.randn(total_entries, dim, width - 1, device=device, dtype=itype)
    conv_state_for_padding_test = conv_state.clone()
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation=activation,
        conv_state_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
    )
    out_ref = causal_conv1d_update_ref(
        x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
    )
    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
    # Padded slots must not have modified any unused cache entries.
    assert torch.equal(
        conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
    )
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    "seqlen", [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]
)
@pytest.mark.parametrize("dim", [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_varlen(
    with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
    """Varlen path: split one packed sequence at random EOS positions and
    compare the CUDA kernel against per-sequence reference convolutions."""
    device = "cuda"
    torch.cuda.empty_cache()
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    seqlens = []
    batch_size = 4
    if seqlen < 10:
        batch_size = 1
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1
    # Random split points yield per-sequence lengths summing to seqlen.
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens.append(
        torch.diff(
            torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
        ).tolist()
    )
    assert sum(seqlens[-1]) == seqlen
    assert all(s > 0 for s in seqlens[-1])
    total_entries = batch_size * 10
    # Cumulative start offsets (query_start_loc), prepended with 0.
    cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
    # Slice out of a larger tensor so x is deliberately non-contiguous.
    x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, dtype=itype)[
        :, 4096 : 4096 + dim, :
    ]
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = None if not silu_activation else "silu"
    final_states = torch.randn(
        total_entries, dim, width - 1, device=x.device, dtype=x.dtype
    )
    final_states_ref = final_states.clone()
    # Randomly mark which sequences carry an initial state.
    has_initial_states = torch.randint(
        0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
    )
    state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
        :batch_size
    ]
    # Pad the state-index list with PAD_SLOT_ID entries the kernel must skip.
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=-1,
    )
    out = causal_conv1d_fn(
        x.squeeze(0),
        weight,
        bias,
        cumsum.cuda(),
        padded_state_indices,
        has_initial_states,
        final_states,
        activation,
        PAD_SLOT_ID,
    )
    out_ref = []
    out_ref_b = []
    splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
    for i in range(len(seqlens[0])):
        x_s = [v[i].unsqueeze(0) for v in splits][0]
        if padded_state_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight_ref,
                bias_ref,
                activation=activation,
                return_final_states=True,
                final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
                initial_states=(
                    final_states_ref[padded_state_indices[i]].unsqueeze(0)
                    if has_initial_states[i]
                    else None
                ),
            )
        )
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
    out_ref_tensor = torch.cat(out_ref, dim=0)
    # Compare only the real (unpadded) portion of the output.
    unpadded_out = out[:, : out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
    assert torch.allclose(
        final_states[state_indices],
        final_states_ref[state_indices],
        rtol=rtol,
        atol=atol,
    )


if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,185 @@
import ctypes
import multiprocessing as mp
import random
import socket
import unittest
from typing import Any, List, Optional
import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    """Per-rank worker: compare sgl_kernel's custom all-reduce against NCCL.

    Spawned once per rank by multi_process_parallel. For every size in
    test_sizes and each of {fp32, fp16, bf16}, runs the custom all-reduce
    and asserts it matches dist.all_reduce on the same input.
    """
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD
    # Fix: pre-initialize resources so the `finally` block never raises a
    # NameError when setup fails before these are assigned (the original
    # referenced them unconditionally, masking the real error).
    custom_ptr = None
    buffer_ptrs = None
    meta_ptrs = None
    try:
        max_size = 8192 * 1024
        meta_ptrs = TestCustomAllReduce.create_shared_buffer(
            custom_ops.meta_size() + max_size, group=group
        )
        rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)
        custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
        custom_ops.register_buffer(custom_ptr, buffer_ptrs)
        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(test_loop):
                    # Small integer values keep the reduction exact in fp16/bf16.
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)
                    custom_ops.all_reduce(
                        custom_ptr, inp1, out1, buffer_ptrs[rank], max_size
                    )
                    dist.all_reduce(inp1_ref, group=group)
                    torch.testing.assert_close(out1, inp1_ref)
    finally:
        # Tear down in reverse order; guard each resource in case setup
        # failed partway through.
        dist.barrier(group=group)
        if custom_ptr is not None:
            custom_ops.dispose(custom_ptr)
        if buffer_ptrs:
            TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
        if meta_ptrs:
            TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)
        dist.destroy_process_group(group=group)
def get_open_port() -> int:
    """Bind an ephemeral port on loopback and return its number.

    Tries IPv4 first; on failure (e.g. IPv4 unavailable) retries on the
    IPv6 loopback, letting any second failure propagate.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as ipv4_sock:
            ipv4_sock.bind(("127.0.0.1", 0))
            _, port = ipv4_sock.getsockname()
            return port
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as ipv6_sock:
            ipv6_sock.bind(("::1", 0))
            return ipv6_sock.getsockname()[1]
def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    """Spawn one process per rank running *test_target* and assert all exit 0.

    Each worker receives (world_size, rank, distributed_init_port, *target_args).
    """
    # 'spawn' is required for CUDA; a forked child would inherit an already
    # initialized CUDA context.
    mp.set_start_method("spawn", force=True)
    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)
    for i in range(world_size):
        procs[i].join()
        assert (
            procs[i].exitcode == 0
        ), f"Process {i} failed with exit code {procs[i].exitcode}"
class TestCustomAllReduce(unittest.TestCase):
    """End-to-end correctness test for sgl_kernel's custom CUDA all-reduce."""

    # Message sizes (elements) to exercise.
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
        1048576,
        2097152,
    ]
    # Tensor-parallel degrees to test; sizes above the GPU count are skipped.
    world_sizes = [2, 4, 8]

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        """Allocate a CUDA buffer and exchange CUDA IPC handles so every
        rank gets a pointer to every rank's buffer.

        Returns a list of device pointers indexed by rank; entry ``rank``
        is the local allocation, the others are IPC-opened remote buffers.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        if group is None:
            group = dist.group.WORLD
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        # Serialize the ctypes IPC handle to raw bytes and all-gather it.
        handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
        input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
        gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, input_tensor, group=group)
        # Reconstruct each rank's handle object from the gathered bytes.
        handles = []
        handle_type = type(handle)
        for tensor in gathered_tensors:
            bytes_list = tensor.cpu().tolist()
            bytes_data = bytes(bytes_list)
            handle_obj = handle_type()
            ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
            handles.append(handle_obj)
        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                # Own buffer: use the raw pointer, not an IPC handle.
                pointers.append(pointer.value)
            else:
                try:
                    opened_ptr = lib.cudaIpcOpenMemHandle(h)
                    pointers.append(opened_ptr.value)
                except Exception as e:
                    print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
                    raise
        dist.barrier(group=group)
        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
        """Free this rank's own allocation from a create_shared_buffer list."""
        if group is None:
            group = dist.group.WORLD
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        if pointers and len(pointers) > rank and pointers[rank] is not None:
            lib.cudaFree(ctypes.c_void_p(pointers[rank]))
        dist.barrier(group=group)

    def test_correctness(self):
        """Run the multi-process correctness check for each feasible world size."""
        for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
                )
                continue
            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"custom allreduce tp = {world_size}: OK")
# Allow running this test module directly (outside pytest) via unittest.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,104 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
from torch import Tensor
# Disable tests on SM103 until the accuracy issues are fixed.
# NOTE(review): this guard skips on *every* capability other than exactly
# (10, 0), not only SM103 -- confirm that broader skip is intended.
if torch.cuda.get_device_capability() != (10, 0):
    pytest.skip(
        reason="Cutlass MLA Requires compute capability of 10.",
        allow_module_level=True,
    )
def ref_mla(
out: Tensor, # (bs, num_heads, v_head_dim)
query: Tensor, # (bs, num_heads, head_dim)
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
scale: float,
block_tables: Tensor, # (bs, max_num_blocks)
seq_lens: Tensor, # (bs,)
):
bs, num_heads, v_head_dim = out.shape
head_dim = query.shape[2]
for i in range(bs):
# gather and flatten KV-cache
kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim)
kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim)
v = kv[:, :, :v_head_dim]
q = query[i].view(num_heads, 1, head_dim)
o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
out[i] = o.view(num_heads, v_head_dim)
return out
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [1, 16, 64, 128])
@pytest.mark.parametrize("num_heads", [16, 32, 64, 128])
@pytest.mark.parametrize("num_kv_splits", [-1, 1])
def test_cutlass_mla_decode(
    dtype: torch.dtype,
    mean_seq_len: int,
    bs: int,
    varlen: bool,
    block_size: int,
    num_heads: int,
    num_kv_splits: int,
):
    """Compare cutlass_mla_decode against the pure-PyTorch ref_mla reference."""
    torch.set_default_dtype(dtype)
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    # d = 576 = dv (512) + 64; presumably latent + rope dims -- confirm
    # against the kernel's layout.
    d = 576
    h_q = num_heads
    dv = 512
    q_nope_dim = 128
    q_pe_dim = 64
    scale = (q_nope_dim + q_pe_dim) ** (-0.5)
    if varlen:
        # Random lengths around the mean; clipped to >= 2 so every sequence is valid.
        seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
        seq_lens = seq_lens.clip(2).to(torch.int32)
    else:
        seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32)
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size
    # Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
    # One 128-wide tile can hold (128 // block_size) small blocks.
    pack_factor = 128 // block_size
    block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
    # Larger q values to detect split kv error
    q = torch.randn(bs, h_q, d) * 100.0
    block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)
    kv_cache = torch.randn(block_table.numel(), block_size, d)
    workspace_size = cutlass_mla_get_workspace_size(
        block_num * block_size, bs, num_kv_splits=num_kv_splits
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
    # q_nope is built with a transposed (non-contiguous) layout -- presumably
    # to exercise the kernel's stride handling; confirm before "simplifying".
    q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
    q_nope.copy_(q[:, :, :dv])
    q_pe = q[:, :, dv:].clone()
    out_ref = q.new_zeros(bs, h_q, dv)
    ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
    out = cutlass_mla_decode(
        q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, scale, num_kv_splits
    )
    torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
# Allow running this test file directly as a script.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,283 @@
import pytest
import torch
from sgl_kernel import cutlass_w4a8_moe_mm, sgl_per_tensor_quant_fp8
from utils import is_hopper
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
    """Pack pairs of int4 values (one per element) into single int8 bytes.

    Element 2j becomes the low nibble and element 2j+1 the high nibble of
    output byte j, halving the last dimension.

    Raises:
        ValueError: if the last dimension is odd (values cannot be paired).
    """
    if int4_values_interleaved.shape[-1] % 2 != 0:
        raise ValueError(
            "the last dim size of int4_values_interleaved tensor must be even."
        )
    as_int8 = int4_values_interleaved.to(torch.int8)
    low = as_int8[..., ::2] & 0x0F  # mask so negative int4s don't leak sign bits
    high = as_int8[..., 1::2] << 4
    return (high | low).to(torch.int8)
def pack_interleave(num_experts, ref_weight, ref_scale):
    """Pack per-expert int4 weights into int8 bytes and interleave the scales.

    Returns (w_q, w_scale):
      w_q:     (E, N, K/2) int8, two int4 values per byte, on CUDA.
      w_scale: (E, K/g, N*align) where groups of `alignment` scale columns are
               interleaved per output row ([E, N, K/g] -> [E, K/g, N*align]).
    """
    n, k = ref_weight.shape[1], ref_weight.shape[2]
    packed = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
    w_q = packed.view((num_experts, n, k // 2)).view(torch.int8).contiguous()

    # Scales are regrouped in chunks of 4 only when K is a multiple of 512;
    # otherwise the layout is left ungrouped (alignment of 1).
    alignment = 4 if k % 512 == 0 else 1
    e_dim, n_dim, g_dim = ref_scale.shape
    interleaved = ref_scale.reshape(
        e_dim, n_dim, g_dim // alignment, alignment
    )  # [E, N, K/4, 4]
    interleaved = interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
    w_scale = interleaved.reshape(
        e_dim, g_dim // alignment, n_dim * alignment
    ).contiguous()  # [E, K/4, N*4]
    return w_q, w_scale
@pytest.mark.skipif(
    not is_hopper(),
    reason="cutlass_w4a8_moe_mm is only supported on sm90",
)
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
def test_int4_fp8_grouped_gemm_single_expert(batch_size):
    """Single-expert sanity check: cutlass_w4a8_moe_mm must match the
    dequantized float reference (ref_grouped_gemm) when every token is
    routed to the one expert."""
    # Test parameters
    num_experts = 1
    m = batch_size  # batch size
    k = 512  # input dimension
    n = 1024  # output dimension
    torch.manual_seed(0)
    dtype = torch.bfloat16
    device = "cuda"
    # debug=True swaps random inputs for all-ones, which makes failures
    # human-readable when diagnosing the kernel.
    debug = False
    print(f"\nTesting with batch_size={batch_size}")
    # Create input tensors with ones
    if debug:
        a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
        ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
        ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
    else:
        a = torch.randn(m, k, dtype=dtype, device=device)
        ref_w = torch.randint(
            -8, 8, (num_experts, n, k), dtype=torch.int8, device=device
        )
        affine_coeff = 0.005
        ref_w_scale = (
            torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
            * affine_coeff
        )
    w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)
    # Create expert offsets and problem sizes
    expert_offsets = torch.tensor([0, m], dtype=torch.int32, device=device)
    problem_sizes = torch.tensor([[n, m, k]], dtype=torch.int32, device=device)
    a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
    c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
    b_strides = a_strides
    s_strides = c_strides
    # Quantize input
    a_q, a_scale = _per_tensor_quant_fp8(a)
    # Create output tensor
    c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
    # The kernel takes only the start offsets (drop the trailing total).
    cutlass_w4a8_moe_mm(
        c,
        a_q,
        w,
        a_scale,
        w_scale,
        expert_offsets[:-1],
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        s_strides,
        128,
        8,
    )
    c = c.to(dtype)
    # Reference implementation: every token selects expert 0.
    experts_selection_result = torch.full((m,), 0)
    c_ref = ref_grouped_gemm(
        c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
    )
    # Compare results
    try:
        torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
    except AssertionError as e:
        # torch.set_printoptions(threshold=10_000)
        print(f" FAILURE: tensors are NOT close.")
        print(f" Ref tensor: {c_ref.flatten()}")
        print(f" Cutlass tensor: {c.flatten()}")
        print(
            f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(
            f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(f" AssertionError: {e}")
        raise
def _per_tensor_quant_fp8(
    x: torch.Tensor,
    dtype: torch.dtype = torch.float8_e4m3fn,
):
    """Quantize `x` to fp8 with a single per-tensor scale.

    Returns:
        (x_q, x_s): the fp8 tensor (same shape as `x`) and its float32
        scale of shape (1,).
    """
    assert x.is_contiguous(), "`x` is not contiguous"
    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    x_s = torch.empty(
        1,
        device=x.device,
        dtype=torch.float32,
    )
    # is_static=False: the kernel computes the scale dynamically from x
    # (presumably from its absmax -- confirm in the kernel source).
    sgl_per_tensor_quant_fp8(x, x_q, x_s, is_static=False)
    return x_q, x_s
@pytest.mark.skipif(
    not is_hopper(),
    reason="cutlass_w4a8_moe_mm is only supported on sm90",
)
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32])
@pytest.mark.parametrize("k", [512, 1024, 2048, 4096, 7168])
@pytest.mark.parametrize("n", [256, 512, 1024, 2048])
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
    """Multi-expert grouped GEMM: tokens are randomly routed, sorted by
    expert, run through cutlass_w4a8_moe_mm, un-permuted, and compared
    against the dequantized float reference."""
    torch.manual_seed(0)
    dtype = torch.bfloat16
    device = "cuda"
    # debug=True swaps random inputs for all-ones to ease manual inspection.
    debug = False
    print(
        f"\nTesting with batch_size={batch_size}, k={k}, n={n}, num_experts={num_experts}"
    )
    if debug:
        a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
        ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
        ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
    else:
        a = torch.randn(batch_size, k, dtype=dtype, device=device)
        ref_w = torch.randint(
            -8, 8, (num_experts, n, k), dtype=torch.int8, device=device
        )
        affine_coeff = 0.005
        ref_w_scale = (
            torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
            * affine_coeff
        )
    w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)
    # random select experts
    experts_selection_result = torch.randint(
        0, num_experts, (batch_size,), device=device
    )
    # Sorting by expert id groups each expert's tokens contiguously, as the
    # grouped GEMM expects.
    permutation = torch.argsort(experts_selection_result)
    expert_token_counts = torch.bincount(
        experts_selection_result, minlength=num_experts
    )
    # Create problem sizes and offsets for active experts
    problem_sizes = []
    for i in range(num_experts):
        problem_sizes.append([n, expert_token_counts[i].item(), k])
    problem_sizes = torch.tensor(problem_sizes, dtype=torch.int32, device=device)
    expert_offsets = []
    offset = 0
    for i in range(num_experts):
        expert_offsets.append(offset)
        offset += problem_sizes[i][1].item()
    expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)
    # Permute input and quantize
    a_q, a_scale = _per_tensor_quant_fp8(a)
    a_q_perm = a_q[permutation]
    # Create stride tensors
    a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
    c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
    b_strides = a_strides
    s_strides = c_strides
    c_perm = torch.empty((batch_size, n), dtype=torch.bfloat16, device=device)
    cutlass_w4a8_moe_mm(
        c_perm,
        a_q_perm,
        w,
        a_scale,
        w_scale,
        expert_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        s_strides,
        128,
        8,
    )
    # Un-permute the result
    c = torch.empty_like(c_perm)
    c[permutation] = c_perm
    c = c.to(dtype)
    c_ref = ref_grouped_gemm(
        c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
    )
    # Compare results
    try:
        torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
    except AssertionError as e:
        print(f" FAILURE: tensors are NOT close.")
        print(
            f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(
            f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
        )
        print(f" AssertionError: {e}")
        raise
def ref_grouped_gemm(
    c, a_q, a_scale, w, w_scale, num_experts, experts_selection_result
):
    """Float reference for the grouped MoE GEMM.

    Each expert's weights are dequantized with its group-wise scales
    (group size 128 along K) and applied to the tokens routed to it.
    Returns a tensor shaped/typed like `c`; rows whose expert has no
    tokens are left at zero.
    """
    out_dtype = torch.bfloat16
    c_ref = torch.zeros_like(c)
    for expert_id in range(num_experts):
        rows = torch.where(experts_selection_result == expert_id)[0]
        if rows.numel() == 0:
            continue
        # Expand per-group scales to per-column, then dequantize the weights.
        scale_full = w_scale[expert_id].repeat_interleave(128, dim=1).to(torch.float32)
        dequant_w = w[expert_id].to(torch.float32) * scale_full
        partial = torch.matmul(a_q[rows].to(torch.float32), dequant_w.t()) * a_scale
        c_ref[rows] = partial.to(out_dtype)
    return c_ref
# Allow running this test file directly as a script.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,32 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import dsv3_fused_a_gemm
@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
def test_dsv3_fused_a_gemm(num_tokens):
    """Compare dsv3_fused_a_gemm against F.linear for 1..16 tokens.

    The weight is created as a (kHdOut, kHdIn) buffer and transposed, so the
    kernel receives a non-contiguous (kHdIn, kHdOut) view.
    """
    kHdIn = 7168
    kHdOut = 2112
    mat_a = torch.randn(
        (num_tokens, kHdIn), dtype=torch.bfloat16, device="cuda"
    ).contiguous()
    mat_b = torch.randn((kHdOut, kHdIn), dtype=torch.bfloat16, device="cuda").transpose(
        0, 1
    )
    ref = F.linear(mat_a, mat_b.T)
    # Fix: the previous pre-allocated `output` buffer was dead code -- the
    # kernel returns a fresh tensor, so that allocation has been removed.
    output = dsv3_fused_a_gemm(mat_a, mat_b)
    assert torch.allclose(
        output, ref, rtol=1e-2, atol=1e-3
    ), "Fused GEMM output mismatch with torch.nn.functional.linear reference"
# Allow running this test file directly as a script.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,35 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import dsv3_router_gemm
@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
@pytest.mark.parametrize("num_experts", [256, 384])
def test_dsv3_router_gemm(num_tokens, num_experts):
    """Check dsv3_router_gemm against F.linear for bf16 and fp32 outputs."""
    hidden_dim = 7168
    hidden_states = torch.randn(
        (num_tokens, hidden_dim), dtype=torch.bfloat16, device="cuda"
    ).contiguous()
    router_weight = torch.randn(
        (num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda"
    ).contiguous()
    # Reference logits in both output dtypes.
    bf16_ref = F.linear(hidden_states, router_weight)
    float_ref = bf16_ref.to(torch.float32)
    bf16_output = dsv3_router_gemm(
        hidden_states, router_weight, out_dtype=torch.bfloat16
    )
    float_output = dsv3_router_gemm(
        hidden_states, router_weight, out_dtype=torch.float32
    )
    assert torch.allclose(
        bf16_output, bf16_ref, rtol=1e-2, atol=1e-3
    ), "Router GEMM output in bf16 dtype mismatch with torch.nn.functional.linear reference"
    assert torch.allclose(
        float_output, float_ref, rtol=1e-2, atol=1e-3
    ), "Router GEMM output in float32 dtype mismatch with torch.nn.functional.linear reference"
# Allow running this test file directly as a script.
if __name__ == "__main__":
    pytest.main([__file__])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,877 @@
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/b31ae1e4cd22cf5f820a2995b74b7cd3bd54355a/tests/cute/test_flash_attn.py
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
import itertools
import math
from functools import partial
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from sgl_kernel.flash_attn import flash_attn_varlen_func
from utils import is_hopper
flash_attn_varlen_func = partial(flash_attn_varlen_func, ver=4)
def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    if unused_mask is not None:
        combined_mask = attention_mask + unused_mask
    else:
        combined_mask = attention_mask
    lengths_per_seq = combined_mask.sum(dim=-1, dtype=torch.int32)
    used_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
    token_indices = torch.nonzero(combined_mask.flatten(), as_tuple=False).flatten()
    longest = lengths_per_seq.max().item()
    cu_seqlens = F.pad(torch.cumsum(lengths_per_seq, dim=0, dtype=torch.int32), (1, 0))
    # Gather with integer indices instead of a boolean mask: bool-mask
    # indexing makes PyTorch expand the mask and call nonzero internally,
    # wasting memory; integer gathering is faster and leaner.
    flattened = rearrange(hidden_states, "b s ... -> (b s) ...")
    return (
        flattened[token_indices],
        token_indices,
        cu_seqlens,
        longest,
        used_lengths,
    )
def pad_input(hidden_states, indices, batch, seqlen):
    """
    Inverse of `unpad_input`: scatter the packed tokens back into a padded
    (batch, seqlen, ...) tensor, zero-filling the positions not in `indices`.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    feature_dims = hidden_states.shape[1:]
    padded = torch.zeros(
        (batch * seqlen),
        *feature_dims,
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    padded[indices] = hidden_states
    return rearrange(padded, "(b s) ... -> b s ...", b=batch)
def generate_random_padding_mask(
    max_seqlen, batch_size, device, mode="random", zero_lengths=False
):
    """Build a (batch_size, max_seqlen) boolean mask of valid positions.

    mode="full": every sequence uses the full length; "random": lengths drawn
    from [max_seqlen - 20, max_seqlen] (floor 1, or 0 when zero_lengths);
    "third": lengths drawn from [max_seqlen // 3, max_seqlen].
    With zero_lengths, every fifth sequence and the last one are forced empty.
    """
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full(
            (batch_size, 1), max_seqlen, device=device, dtype=torch.int32
        )
    elif mode == "random":
        low = max(0 if zero_lengths else 1, max_seqlen - 20)
        lengths = torch.randint(low, max_seqlen + 1, (batch_size, 1), device=device)
    elif mode == "third":
        lengths = torch.randint(
            max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device
        )

    if zero_lengths:
        # Generate zero-lengths every 5 batches and the last batch.
        for idx in range(batch_size):
            if idx % 5 == 0:
                lengths[idx] = 0
        lengths[-1] = 0
    positions = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size)
    return positions < lengths
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    qv=None,
    kvpacked=False,
    qkvpacked=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """
    Build the unpadded (varlen) views of q/k/v plus the bookkeeping a varlen
    attention kernel needs, in one of three layouts selected by the flags:
    qkvpacked (single stacked QKV), kvpacked (separate Q, stacked KV), or
    fully separate (default). Each branch returns a different tuple shape;
    the default branch returns 17 items including cu_seqlens, seqused,
    max_seqlen values, the original (grad-enabled) tensors, and re-padding
    closures.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d_v)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    d_v = v.shape[-1]
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
    # Unused-token masks are only supported in the fully-separate layout.
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q, query_padding_mask, query_unused_mask
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
        qv_unpad = (
            rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
        )
    else:
        # No padding: flatten directly and synthesize uniform cu_seqlens.
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        # Stacked-QKV layout requires identical masks and head counts.
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(
                dqkv_unpad, indices_q, batch_size, seqlen_q
            )
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(
                dkv_unpad, indices_k, batch_size, seqlen_k
            )
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(
                dk_unpad, indices_k, batch_size, seqlen_k
            )
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(
                dk_unpad, "(b s) h d -> b s h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            qv_unpad.detach() if qv is not None else None,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            qv.detach() if qv is not None else None,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(None, None),
    sink_token_length=0,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    """Return a boolean mask that is True where attention must be BLOCKED
    under a sliding local window (window_size = (left, right)).

    Row/column positions are aligned to the *ends* of the sequences via the
    `row_idx + sk - sq` offset, so variable-length batches mask correctly.
    Key columns below `sink_token_length` are never masked on the left side
    (attention sinks). window_size[0] of None means no left limit.
    """
    row_idx = rearrange(
        torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1"
    )
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        # Shift key positions left by the per-batch left padding; positions
        # inside the pad get a huge index (2**32) so they are always masked.
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    # Effective key/query lengths (per batch when padding masks are given).
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] is None:
        # Right-limit only: block everything beyond the right window edge.
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        # Block right of the window edge OR left of the window start, except
        # for the sink tokens which always stay visible.
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            torch.logical_and(
                col_idx < row_idx + sk - sq - window_size[0],
                col_idx >= sink_token_length,
            ),
        )
def construct_chunk_mask(
    seqlen_q,
    seqlen_k,
    attention_chunk,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    """Return a boolean mask that is True where attention must be BLOCKED
    under chunked attention: each query may only attend to keys inside its
    own `attention_chunk`-sized chunk (end-aligned like construct_local_mask).
    """
    row_idx = rearrange(
        torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1"
    )
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        # Same left-pad shifting as in construct_local_mask: padded key
        # positions are pushed to 2**32 so they are always blocked.
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
    # Subtract remainder instead of divide and then multiply to take care of negative values
    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk
    return torch.logical_or(
        col_idx < col_limit_left_chunk,
        col_idx >= col_limit_left_chunk + attention_chunk,
    )
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(None, None),
    attention_chunk=0,
    sink_token_length=0,
    learnable_sink=None,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """
    Pure-PyTorch attention reference covering the kernel's feature set:
    GQA/MQA head repetition, fp8 descales, local/chunked windows, attention
    sinks, softcap, optional extra qv term, and dropout via an explicit mask.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads, head_dim)
        v: (batch_size, seqlen_k, nheads, head_dim_v)
        qv: (batch_size, seqlen_q, nheads, head_dim_v)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim_v)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        # Causal is just a local window with a right extent of 0.
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
        qv = qv.float() if qv is not None else None
    if q_descale is not None:
        # Per-(batch, kv-head) descale, repeated over the grouped query heads.
        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
        q = (q.float() * q_descale).to(q.dtype)
        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
    if k_descale is not None:
        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
    if v_descale is not None:
        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # Repeat KV heads so GQA/MQA reduces to plain multi-head attention.
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    dv = v.shape[-1]
    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if qv is not None:
        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
    if softcap > 0:
        # Soft capping: squash scores smoothly into (-softcap, softcap).
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
        )
    local_mask = None
    if window_size[0] is not None or window_size[1] is not None:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            sink_token_length,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
    if attention_chunk > 0:
        chunk_mask = construct_chunk_mask(
            seqlen_q,
            seqlen_k,
            attention_chunk,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
        # Chunking composes with the local window: blocked if either blocks.
        local_mask = (
            torch.logical_or(local_mask, chunk_mask)
            if local_mask is not None
            else chunk_mask
        )
    if local_mask is not None:
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    if learnable_sink is None:
        attention = torch.softmax(scores, dim=-1).to(v.dtype)
    else:
        # Softmax with a per-head learnable sink logit: the sink contributes
        # to the normalizer but produces no value, so rows can "attend to
        # nothing" in a learned proportion.
        scores_fp32 = scores.to(torch.float32)
        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
        learnable_sink = rearrange(learnable_sink, "h -> h 1 1")
        logits_or_sinks_max = torch.maximum(learnable_sink, logits_max)
        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
            learnable_sink - logits_or_sinks_max
        )
        attention = (unnormalized_scores / normalizer).to(v.dtype)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
        )
    # Without this we might get NaN in dv
    if key_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0
        )
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if local_mask is not None:
        attention = attention.masked_fill(
            torch.all(local_mask, dim=-1, keepdim=True), 0.0
        )
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    if intermediate_dtype is not None:
        # Round-trip through a lower-precision dtype to model kernels that
        # keep the probabilities in reduced precision (e.g. fp8).
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
@pytest.mark.skipif(
is_hopper(),
reason="skip on hopper",
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mqa"])
@pytest.mark.parametrize("has_learnable_sink", [False, True])
# @pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
@pytest.mark.parametrize("add_unused_qkv", [False])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
@pytest.mark.parametrize("d", [128, 192])
# @pytest.mark.parametrize("d", [192])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
# (1, 1),
# (1, 3),
# (2, 1),
(511, 1),
(3, 513),
(64, 128),
(128, 128),
(256, 256),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(307, 256),
(640, 128),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
def test_flash_attn_varlen_output(
seqlen_q,
seqlen_k,
d,
add_unused_qkv,
causal,
local,
softcap,
deterministic,
has_qv,
has_learnable_sink,
mha_type,
dtype,
):
if (
causal or local
): # Right now we only support causal attention with seqlen_k == seqlen_q
seqlen_k = seqlen_q
device = "cuda"
# set seed
torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
batch_size = 49 if seqlen_q <= 1024 else 7
nheads = 6
# batch_size = 1
# nheads = 1
nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
# dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
if dtype == torch.float8_e4m3fn:
dv_vals = [d]
# attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0]
attention_chunk_vals = [0]
for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
q_ref = torch.randn(
batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
)
if softcap > 0.0:
# Ensure the values of qk are at least within softcap range.
q_ref = (q_ref * softcap / 4).detach().requires_grad_()
q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
k_ref = (
torch.randn(
batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
.requires_grad_()
)
v_ref = (
torch.randn(
batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
.requires_grad_()
)
if has_qv:
qv_ref = (
torch.randn(
batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
)
else:
qv_ref = None
# Put window_size after QKV randn so that window_size changes from test to test
window_size = (
(None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
)
if has_learnable_sink:
learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
else:
learnable_sink = None
if dtype == torch.float8_e4m3fn:
q_descale, k_descale, v_descale = [
torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
* 2
for _ in range(3)
]
else:
q_descale, k_descale, v_descale = None, None, None
q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
qv = qv_ref.detach() if has_qv else None
query_padding_mask = generate_random_padding_mask(
seqlen_q, batch_size, device, mode="random", zero_lengths=False
)
# TODO: test zero_lengths
key_padding_mask = generate_random_padding_mask(
# seqlen_k, batch_size, device, mode="random", zero_lengths=True
seqlen_k,
batch_size,
device,
mode="random",
zero_lengths=False,
)
def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
if add_unused:
another_mask = generate_random_padding_mask(max_seq_len, bs, device)
attn_mask = torch.logical_and(padding_mask, another_mask)
unused_mask = torch.logical_xor(
torch.logical_or(padding_mask, another_mask), attn_mask
)
else:
attn_mask = padding_mask
unused_mask = None
return attn_mask, unused_mask
query_padding_mask, query_unused_mask = _gen_unused_masks(
query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
)
# query_padding_mask[:] = True
# query_unused_mask = None
key_padding_mask, key_unused_mask = _gen_unused_masks(
key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
)
if causal or local:
key_padding_mask = query_padding_mask
(
q_unpad,
k_unpad,
v_unpad,
qv_unpad,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
qv,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(
q,
k,
v,
query_padding_mask,
key_padding_mask,
qv=qv,
kvpacked=False,
query_unused_mask=query_unused_mask,
key_unused_mask=key_unused_mask,
)
q_unpad, k_unpad, v_unpad = [
x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
]
out_ref, attn_ref = attention_ref(
q_ref,
k_ref,
v_ref,
query_padding_mask,
key_padding_mask,
causal=causal,
qv=qv_ref,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
window_size=window_size,
attention_chunk=attention_chunk,
learnable_sink=learnable_sink,
softcap=softcap,
)
out_pt, attn_pt = attention_ref(
q_ref,
k_ref,
v_ref,
query_padding_mask,
key_padding_mask,
causal=causal,
qv=qv_ref,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
window_size=window_size,
attention_chunk=attention_chunk,
learnable_sink=learnable_sink,
softcap=softcap,
upcast=False,
reorder_ops=True,
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
)
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if query_unused_mask is not None:
q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
# Numerical error if we just do any arithmetic on out_ref
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
rtol = 2 if softcap == 0.0 else 3
pack_gqa_vals = [False, True, None]
# num_splits_vals = [1, 3]
num_splits_vals = [1]
for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
out_unpad, lse = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=None,
max_seqlen_k=None,
# seqused_q=seqused_q,
# seqused_k=seqused_k,
causal=causal,
# qv=qv_unpad,
# q_descale=q_descale,
# k_descale=k_descale, v_descale=v_descale,
window_size=window_size,
# attention_chunk=attention_chunk,
sinks=learnable_sink,
softcap=softcap,
pack_gqa=pack_gqa,
return_softmax_lse=True,
)
out = output_pad_fn(out_unpad)
if query_unused_mask is not None:
out.masked_fill_(q_zero_masking, 0.0)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
# if not causal:
# print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
# breakpoint()
# Check that FlashAttention's numerical error is at most 3x the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= rtol * (
out_pt - out_ref
).abs().max().item() + fwd_atol
if (
dtype != torch.float8_e4m3fn
and not has_qv
and not dv > 256
and not attention_chunk != 0
and dv == d
and not has_learnable_sink
and False
):
g_unpad = torch.randn_like(out_unpad)
do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
# import flash_attn_3_cuda
# dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(
# g_unpad,
# q_unpad,
# k_unpad,
# v_unpad,
# out_unpad,
# lse,
# None,
# None,
# None,
# cu_seqlens_q,
# cu_seqlens_k,
# None, None,
# max_seqlen_q,
# max_seqlen_k,
# d ** (-0.5),
# causal,
# window_size[0], window_size[1],
# softcap,
# deterministic,
# 0, # sm_margin
# )
dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
)
dq = dq_pad_fn(dq_unpad)
dk = dk_pad_fn(dk_unpad)
dv = dk_pad_fn(dv_unpad)
if key_unused_mask is not None:
k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
dk.masked_fill_(k_zero_masking, 0.0)
dv.masked_fill_(k_zero_masking, 0.0)
if query_unused_mask is not None:
dq.masked_fill_(q_zero_masking, 0.0)
# print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
# assert (softmax_d - do_o).abs().max().item() <= 1e-5
# assert dq_accum.abs().max().item() == 0.0
g = output_pad_fn(g_unpad)
# qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()
# qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
# dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
# P = torch.softmax(qk, -1)
# dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))
# dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
# dV = torch.einsum('bhts,bthd->bshd', P, g.float())
# dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
# dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
dq_ref, dk_ref, dv_ref = torch.autograd.grad(
out_ref, (q_ref, k_ref, v_ref), g
)
dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# breakpoint()
dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4
)
assert (dq - dq_ref).abs().max().item() <= rtol * (
dq_pt - dq_ref
).abs().max().item() + dq_atol
dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4
)
assert (dk - dk_ref).abs().max().item() <= rtol * (
dk_pt - dk_ref
).abs().max().item() + dk_atol
dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4
)
assert (dv - dv_ref).abs().max().item() <= rtol * (
dv_pt - dv_ref
).abs().max().item() + dv_atol
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,154 @@
import pytest
import torch
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
skip_condition = torch.cuda.get_device_capability() < (10, 0)
DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
kE2M1ToFloatArray = [
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
]
def e2m1_to_fp32(int4_value):
    """Decode one 4-bit E2M1 code (1 sign / 2 exponent / 1 mantissa bit) to a float."""
    # Low 3 bits select the magnitude from the shared decode table.
    magnitude = kE2M1ToFloatArray[int4_value & 0x7]
    # Bit 3 is the sign bit.
    return -magnitude if int4_value & 0x8 else magnitude
def break_fp4_bytes(a, dtype):
    """Unpack an (m, n) uint8 tensor of packed fp4 (E2M1) pairs into (m, 2n) float32.

    Each byte holds two E2M1 codes; the low nibble is the first logical
    element and the high nibble the second ([0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]).
    The ``dtype`` argument is kept for interface compatibility; as before, the
    result is float32.

    Rewritten to decode via a 16-entry lookup table with tensor indexing
    instead of a per-element Python loop, which was O(n) interpreter work and
    prohibitively slow for large tensors.
    """
    assert a.dtype == torch.uint8
    m, n = a.shape
    flat = a.flatten()
    # Signed E2M1 decode table indexed by the full 4-bit code (bit 3 = sign).
    lut = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        device=a.device,
    )
    low = lut[(flat & 0x0F).long()]
    high = lut[((flat & 0xF0) >> 4).long()]
    # Interleave (low, high) pairs back into the logical element order.
    return torch.stack((low, high), dim=-1).reshape(m, n * 2)
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Rearrange a swizzled scale-factor tensor into a row-major layout.

    The kernel stores scales in 128-row x (block_size * 4)-column tiles with a
    (32, 4, 4) intra-tile swizzle; this undoes that layout and trims padding.
    """
    sf_m, sf_k = a_sf_swizzled.shape  # shape sanity check (must be 2-D)
    rows_per_tile = 128
    cols_per_tile = block_size * 4
    num_row_tiles = -(m // -rows_per_tile)  # ceil division
    num_col_tiles = -(k // -cols_per_tile)
    # Undo the (32, 4, 4) intra-tile swizzle.
    tiled = a_sf_swizzled.reshape(1, num_row_tiles, num_col_tiles, 32, 4, 4)
    linear = tiled.permute(0, 1, 4, 3, 2, 5).reshape(
        num_row_tiles * rows_per_tile,
        num_col_tiles * cols_per_tile // block_size,
    )
    # Drop padding rows; the column slice clamps when k exceeds the width.
    return linear[0:m, 0:k]
def dequantize_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize the fp4 tensor back to high precision.

    NOTE(review): despite the ``dtype``/``device`` parameters, the returned
    tensor is float32 (callers cast afterwards) and stays on the input's
    device — confirm before relying on them.
    """
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
    m, packed_k = tensor_fp4.shape
    k = packed_k * 2
    # Decode packed nibbles to floats, then group by scale block.
    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
    # Reinterpret raw scale bytes as fp8-e4m3 and undo the swizzled layout.
    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
    # Per-block scale in real units: stored scale divided by the global scale.
    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
    # scale the tensor
    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
    return out
def get_ref_results(
    a_fp4,
    b_fp4,
    a_sf,
    b_sf,
    a_global_scale,
    b_global_scale,
    m,
    n,
    dtype,
    block_size,
    device,
):
    """Dequantize both fp4 operands and return the reference product a @ b.T.

    NOTE(review): ``m`` and ``n`` are accepted but unused; shapes are taken
    from the packed tensors themselves.
    """
    _, m_k = a_fp4.shape
    _, n_k = b_fp4.shape
    # Both operands must share the same packed K dimension.
    assert m_k == n_k
    a_in_dtype = dequantize_to_dtype(
        a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
    )
    b_in_dtype = dequantize_to_dtype(
        b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
    )
    return torch.matmul(a_in_dtype, b_in_dtype.t())
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@torch.inference_mode()
def test_nvfp4_gemm(
    dtype: torch.dtype,
    shape: tuple[int, int, int],
) -> None:
    """Check cutlass_scaled_fp4_mm against a dequantize-then-matmul reference."""
    # Shapes are (m, n, packed_k); each packed byte holds two fp4 values.
    m, n, packed_k = shape
    k = packed_k * 2
    block_size = 16
    a_dtype = torch.randn((m, k), dtype=dtype, device="cuda")
    b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
    # Global scales map the tensor amax onto the fp8(scale) * fp4(data) range.
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    # alpha undoes both global scales after the fp4 x fp4 accumulation.
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
    expected_out = get_ref_results(
        a_fp4,
        b_fp4,
        a_scale_interleaved,
        b_scale_interleaved,
        a_global_scale,
        b_global_scale,
        m,
        n,
        dtype,
        block_size,
        "cuda",
    )
    out = cutlass_scaled_fp4_mm(
        a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
    )
    # Loose tolerances: fp4 quantization error dominates.
    torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,261 @@
import pytest
import torch
from sgl_kernel import (
scaled_fp4_grouped_quant,
scaled_fp4_quant,
silu_and_mul,
silu_and_mul_scaled_fp4_grouped_quant,
)
skip_condition = torch.cuda.get_device_capability() < (10, 0)
DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
PAD_SHAPES = [
(90, 64),
(150, 64),
(128, 48),
(128, 80),
(150, 80),
(90, 48),
(90, 128),
(150, 128),
(150, 48),
(90, 80),
]
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
# E2M1 to float
# 0111 -> 6
# 0110 -> 4
# 0101 -> 3
# 0100 -> 2
# 0011 -> 1.5
# 0010 -> 1
# 0001 -> 0.5
# 0000 -> 0
E2M1_TO_FLOAT32 = [
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]
BLOCK_SIZE = 16
def cast_from_fp4(x, m, n):
    """Unpack a uint8 tensor of fp4 (E2M1) pairs into an (m, n) float32 tensor.

    The fp4 values are packed in uint8 as [v_1st | v_2nd]: the low nibble is
    the first logical element, the high nibble the second.

    Rewritten to decode via a 16-entry lookup table with tensor indexing
    instead of a per-element Python loop (which was O(n) interpreter work);
    the result now also follows ``x``'s device instead of the default device.
    """
    # Signed E2M1 decode table indexed by the full 4-bit code (bit 3 = sign);
    # same contents as the module-level E2M1_TO_FLOAT32 table.
    lut = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        device=x.device,
    )
    v_2nd = (x & 0xF).long()
    v_1st = ((x >> 4) & 0xF).long()
    codes = torch.stack((v_2nd, v_1st), dim=-1)
    return lut[codes].reshape(m, n).to(torch.float32)
def cast_to_fp4(x):
    """Round each element of ``x`` to the nearest representable E2M1 value.

    Interval endpoints follow round-to-nearest with ties resolved exactly as
    the original mask chain did (e.g. 0.25 -> 0.0 but 0.75 -> 1.0).
    """
    sign = torch.sign(x)
    mag = torch.abs(x)
    # Assign magnitudes from largest to smallest; later (smaller) intervals
    # overwrite earlier ones, and all masks test the unmodified magnitudes.
    quant = torch.full_like(mag, 6.0)
    quant[mag <= 5.0] = 4.0
    quant[mag < 3.5] = 3.0
    quant[mag <= 2.5] = 2.0
    quant[mag < 1.75] = 1.5
    quant[mag <= 1.25] = 1.0
    quant[mag < 0.75] = 0.5
    quant[mag <= 0.25] = 0.0
    return quant * sign
def get_reciprocal(x):
    """Reciprocal of ``x`` that maps 0 to 0 instead of inf (tensor or scalar)."""
    if isinstance(x, torch.Tensor):
        # Elementwise: zeros stay zero, everything else is inverted.
        return torch.where(x == 0, torch.zeros_like(x), 1.0 / x)
    if isinstance(x, (float, int)):
        return 0.0 if x == 0 else 1.0 / x
    raise TypeError("Input must be a float, int, or a torch.Tensor.")
def ref_nvfp4_quant(x, global_scale):
    """Reference NVFP4 quantization: returns (fp4 codes as floats, fp8-rounded scales).

    Scales are computed per contiguous BLOCK_SIZE-element block along the last
    dimension and round-tripped through fp8-e4m3 so the reference matches the
    precision of the scales a kernel would actually store.
    """
    assert global_scale.dtype == torch.float32
    assert x.ndim == 2
    m, n = x.shape
    x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
    # Per-block absolute max drives the block scale.
    vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
    # Round-trip through fp8-e4m3: this is the precision the scale is stored at.
    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
    # Inverse of the effective (un-globalized) block scale; 0 maps to 0.
    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
    scaled_x = x.to(torch.float32) * output_scale
    # Clamp into the representable E2M1 magnitude range before rounding.
    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
    return cast_to_fp4(clipped_x), scale.squeeze(-1)
def recover_swizzled_scales(scale, m, n):
    """Convert swizzled scale factors to a linear (m, n // BLOCK_SIZE) layout.

    The kernel pads rows to multiples of 128 and scale columns to multiples
    of 4, storing them in a (32, 4, 4) intra-tile swizzle; this undoes that
    layout and trims the padding.
    """
    rounded_m = ((m + 128 - 1) // 128) * 128
    scale_n = n // BLOCK_SIZE
    rounded_n = ((scale_n + 4 - 1) // 4) * 4
    # Recover the swizzled scaling factor to linear layout
    tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    result = torch.reshape(tmp, (rounded_m, rounded_n)).to(torch.float32)
    return result[:m, :scale_n]
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@torch.inference_mode()
def test_quantize_to_fp4(
    dtype: torch.dtype,
    shape: tuple[int, int],
) -> None:
    """Compare scaled_fp4_quant's codes and scales against the pure-torch reference."""
    torch.manual_seed(42)
    torch.set_default_device("cuda:0")
    m, n = shape
    x = torch.randn((m, n), dtype=dtype)
    # Global scale maps the tensor amax onto the fp8(scale) * fp4(data) range.
    tensor_amax = torch.abs(x).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
    out, out_scale = scaled_fp4_quant(x, global_scale)
    # Kernel output is packed/swizzled; linearize both before comparing.
    scale_ans = recover_swizzled_scales(out_scale, m, n)
    out_ans = cast_from_fp4(out, m, n)
    torch.testing.assert_close(out_ans, out_ref)
    torch.testing.assert_close(scale_ans, scale_ref)
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
@torch.inference_mode()
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
    """Same check as test_quantize_to_fp4, but for shapes that require padding."""
    torch.manual_seed(42)
    dtype = torch.float16
    torch.set_default_device("cuda:0")
    m, n = pad_shape
    x = torch.randn((m, n), dtype=dtype)
    tensor_amax = torch.abs(x).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
    out, out_scale = scaled_fp4_quant(x, global_scale)
    # Linearizing the swizzled scales also drops the pad rows/columns.
    scale_ans = recover_swizzled_scales(out_scale, m, n)
    out_ans = cast_from_fp4(out, m, n)
    torch.testing.assert_close(out_ans, out_ref)
    torch.testing.assert_close(scale_ans, scale_ref)
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("shape", [(2, 512, 2048), (2, 100, 128), (2, 128, 96)])
def test_quantize_to_fp4_grouped(shape):
    """Grouped (per-expert, masked) fp4 quant must match per-group scaled_fp4_quant."""
    torch.manual_seed(42)
    torch.set_default_device("cuda:0")
    l, m, k = shape
    x = torch.randn((l, m, k), dtype=torch.bfloat16)
    # Each group only has its first mask[i] rows valid.
    max_m = m // 2
    assert max_m <= m
    mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
    tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
    x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    output, output_scales = scaled_fp4_grouped_quant(
        x,
        x_sf_global,
        mask,
    )
    # output in logical (m, k, l), but its physical layout is (l, m, k).
    # So permute first to (l, m, k).
    output = output.permute(2, 0, 1)
    # output_scale in logical (32, 4, rm, 4, rk, l), but its physical layout is (l, rm, rk, 32, 4, 4).
    # So permute first to (l, rm, rk, 32, 4, 4).
    padded_m = ((m + 128 - 1) // 128) * 128
    output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
    for i in range(l):
        # Per-group reference: quantize each expert slice independently.
        a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i])
        torch.testing.assert_close(a_fp4[: mask[i]], output[i][: mask[i]])
        # Recover swizzled scales to linear layout and drop padded values, so
        # no extra checks on padding are needed.
        scale_ref = recover_swizzled_scales(a_scale_interleaved, m, k)
        scale_ans = recover_swizzled_scales(output_scales[i], m, k)
        torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])
@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048), (6, 6144, 2048)])
def test_silu_and_mul_quantize_to_fp4_grouped(shape):
    """Fused silu_and_mul + grouped fp4 quant must match the two-step pipeline."""
    torch.manual_seed(42)
    torch.set_default_device("cuda:0")
    l, m, k = shape
    # Input is 2k wide: silu_and_mul halves the last dimension.
    x = torch.randn((l, m, k * 2), dtype=torch.bfloat16)
    max_m = m // 2
    assert max_m <= m
    mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
    # Reference path: activation first, then grouped quantization.
    ref_y = silu_and_mul(x)
    tensor_amax = ref_y.abs().amax(dim=(1, 2)).to(torch.float32)
    y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    ref_output, ref_output_scales = scaled_fp4_grouped_quant(
        ref_y,
        y_sf_global,
        mask,
    )
    output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(
        x,
        y_sf_global,
        mask,
    )
    # output in logical (m, k, l), but its physical layout is (l, m, k).
    # So permute first to (l, m, k).
    output = output.permute(2, 0, 1)
    ref_output = ref_output.permute(2, 0, 1)
    # output_scale in logical (32, 4, rm, 4, rk, l), but its physical layout is (l, rm, rk, 32, 4, 4).
    # So permute first to (l, rm, rk, 32, 4, 4).
    padded_m = ((m + 128 - 1) // 128) * 128
    output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
    ref_output_scales = ref_output_scales.permute(5, 2, 4, 0, 1, 3).view(
        l, padded_m, -1
    )
    for i in range(l):
        torch.testing.assert_close(ref_output[i, : mask[i]], output[i, : mask[i]])
        # We need to recover the swizzled scales to linear layout before applying mask slice.
        scale_ref = recover_swizzled_scales(ref_output_scales[i], m, k)
        scale_ans = recover_swizzled_scales(output_scales[i], m, k)
        torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,93 @@
import os
import random
from typing import Optional, Type
import pytest
import torch
from sgl_kernel import fp8_blockwise_scaled_mm
def cdiv(a: int, b: int) -> int:
    """Ceiling division of ``a`` by ``b`` (equivalent to -(a // -b))."""
    return -(-a // b)
def scale_shape(shape, group_shape):
    """Per-dimension number of scale groups: ceil(shape[i] / group_shape[i])."""
    assert len(shape) == len(group_shape)
    # One scale per (possibly partial) group along every dimension.
    return tuple(cdiv(dim, group) for dim, group in zip(shape, group_shape))
def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: Type[torch.dtype],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference GEMM with N-dimensional group (block) scaling.

    A scale dimension whose extent divides the data's extent (but is neither
    1 nor equal to it) is stretched by repeating each element
    target // extent times along that dimension — an extension of numpy-style
    broadcasting to non-unit extents. Unit extents are left to pytorch's
    implicit broadcasting.
    """

    def _stretch(scale: torch.Tensor, target_shape) -> torch.Tensor:
        # Repeat each element along every dimension that needs stretching.
        for axis, target in enumerate(target_shape):
            extent = scale.shape[axis]
            if extent not in (1, target):
                assert target % extent == 0
                repeat = target // extent
                scale = (
                    scale.unsqueeze(axis + 1)
                    .expand(*scale.shape[: axis + 1], repeat, *scale.shape[axis + 1 :])
                    .flatten(axis, axis + 1)
                )
        return scale

    scaled_a = _stretch(scale_a, a.shape) * a.to(dtype=torch.float32)
    scaled_b = _stretch(scale_b, b.shape) * b.to(dtype=torch.float32)
    output = torch.mm(scaled_a, scaled_b).to(out_dtype)
    if bias is not None:
        output = output + bias
    return output
def _test_accuracy_once(M, N, K, out_dtype, device):
    """Compare fp8_blockwise_scaled_mm with the pure-torch blockwise baseline."""
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min
    # Random operands spanning the full finite e4m3 range.
    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
    # A is scaled per (1, 128) row-group, B per (128, 128) block.
    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
    scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
    scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
    # Force the column-major memory layout the kernel expects.
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()
    o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
    o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
    # Loose tolerances: fp8 quantization error dominates.
    rtol = 0.02
    atol = 1
    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
@pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 14080])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 14080, 16384])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, out_dtype):
    """Sweep GEMM shapes (including non-multiples of 128 in M) on CUDA."""
    _test_accuracy_once(M, N, K, out_dtype, "cuda")
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,221 @@
import random
from typing import Tuple
import pytest
import torch
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
def cdiv(a: int, b: int) -> int:
    """Integer ceiling division."""
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)
def scale_shape(shape, group_shape):
    """Per-dimension number of scale groups: ceil(shape[i] / group_shape[i]).

    Added the rank check that the identically-named helper in the sibling
    blockwise-mm test already has; without it a shorter ``group_shape`` would
    silently drop trailing dimensions from the result.
    """
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
dtype=torch.float8_e4m3fn
)
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
def calc_diff(x, y):
    """Return 1 - (2<x,y> / (<x,x> + <y,y>)) computed in float64.

    0 for identical tensors, 2 for exactly opposite ones; a scale-aware
    similarity-based error measure (same formula DeepGEMM uses).
    """
    xd, yd = x.double(), y.double()
    denominator = (xd * xd + yd * yd).sum()
    similarity = 2 * (xd * yd).sum() / denominator
    return 1 - similarity
def ceil_div(x: int, y: int) -> int:
    """Integer division of ``x`` by ``y``, rounded up."""
    # Adding y - 1 bumps any non-exact quotient to the next integer.
    offset = y - 1
    return (x + offset) // y
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to fp8-e4m3 with one scale per 128x128 block.

    Returns the fp8 data (padding trimmed) and float scales (amax / 448) of
    shape (ceil(m / 128), ceil(n / 128)).
    """
    assert x.dim() == 2
    rows, cols = x.shape
    # Zero-pad both dimensions up to multiples of the 128x128 block.
    padded = torch.zeros(
        (ceil_div(rows, 128) * 128, ceil_div(cols, 128) * 128),
        dtype=x.dtype,
        device=x.device,
    )
    padded[:rows, :cols] = x
    blocks = padded.view(-1, 128, padded.size(1) // 128, 128)
    # One absolute max per block, floored to avoid division blow-ups.
    amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    quantized = (blocks * (448.0 / amax)).to(torch.float8_e4m3fn)
    return (
        quantized.view_as(padded)[:rows, :cols].contiguous(),
        (amax / 448.0).view(blocks.size(0), blocks.size(2)),
    )
def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: type[torch.dtype],
) -> torch.Tensor:
    """Reference GEMM with group (block) scaling.

    Scale dimensions smaller than the data's (but not 1) are stretched by
    repeating each element; unit extents broadcast implicitly.
    """

    def _expand_groups(scale: torch.Tensor, target_shape) -> torch.Tensor:
        for dim, want in enumerate(target_shape):
            have = scale.shape[dim]
            if have != want and have != 1:
                assert want % have == 0
                # repeat_interleave duplicates each group entry across its span.
                scale = scale.repeat_interleave(want // have, dim=dim)
        return scale

    lhs = _expand_groups(scale_a, a.shape) * a.to(dtype=torch.float32)
    rhs = _expand_groups(scale_b, b.shape) * b.to(dtype=torch.float32)
    return torch.mm(lhs, rhs).to(out_dtype)
def is_sm100_supported(device=None) -> bool:
    """True iff ``device`` is an SM100 (compute capability 10.x) GPU on CUDA >= 12.8.

    Fixed two defects in the original check: the CUDA version was compared as
    a string (lexicographically "12.10" < "12.8") and a CPU-only build
    (torch.version.cuda is None / no device) raised instead of returning False.
    """
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    cuda_version = tuple(int(part) for part in torch.version.cuda.split(".")[:2])
    return (torch.cuda.get_device_capability(device)[0] == 10) and (
        cuda_version >= (12, 8)
    )
def is_sm90_supported(device=None) -> bool:
    """True iff ``device`` is an SM90 (compute capability 9.x) GPU on CUDA >= 12.3.

    Fixed two defects in the original check: the CUDA version was compared as
    a string (lexicographically "12.10" < "12.3" is False but "12.10" < "12.8"
    orders wrongly in general) and a CPU-only build raised instead of
    returning False.
    """
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    cuda_version = tuple(int(part) for part in torch.version.cuda.split(".")[:2])
    return (torch.cuda.get_device_capability(device)[0] == 9) and (
        cuda_version >= (12, 3)
    )
@pytest.mark.skipif(
    not (is_sm100_supported() or is_sm90_supported()),
    reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
)
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128])
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
    """Grouped fp8 blockwise GEMM vs. per-expert full-precision matmuls.

    Builds one random problem per expert, quantizes A per token-group and B
    per 128x128 block, packs everything into the stacked layout the kernel
    expects, and checks each expert's output slice by similarity diff.
    """
    device = "cuda"
    # NOTE(review): unused; presumably the intended K/N alignment — confirm.
    alignment = 128
    # Shared N/K for all experts; only M varies per expert.
    n_g = random.randint(1, 64) * 128
    k_g = random.randint(1, 64) * 128
    expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
    problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
    layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
    layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
    a_tensors = []
    b_tensors = []
    a_scales_tensors = []
    b_scales_tensors = []
    baseline_tensors = []
    # Build one random problem (and its full-precision baseline) per expert.
    for g in range(num_experts):
        m_g = random.randint(1, 256)
        expert_offsets[g + 1] = expert_offsets[g] + m_g
        problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
        a = torch.randn((m_g, k_g), device=device, dtype=out_dtype)  # (M, K):(K, 1)
        b = torch.randn((n_g, k_g), device=device, dtype=out_dtype).t()  # (K, N):(1, K)
        a_g, a_scale = per_token_cast_to_fp8(
            a
        )  # ag -- (M, K):(K, 1), a_scale() -- (M, k):(k, 1)
        b_g, b_scale = per_block_cast_to_fp8(
            b
        )  # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
        a_tensors.append(a_g)
        b_tensors.append(b_g)
        a_scales_tensors.append(a_scale)
        b_scales_tensors.append(b_scale)
        baseline = torch.mm(a, b)
        baseline_tensors.append(baseline)
    # Stacked kernel inputs: A rows are concatenated, B is one slab per expert.
    a_stack = torch.empty(
        (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
    )
    b_stack = torch.empty(
        (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
    )
    a_scale_stack = torch.empty(
        (expert_offsets[-1], (k_g // 128)), device=device, dtype=torch.float32
    )
    b_scale_stack = torch.empty(
        (num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32
    )
    for g in range(num_experts):
        # Matrix A is Row-Major.
        a_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_tensors[
            g
        ]  # a_stack[expert_offsets[g] : expert_offsets[g + 1], :] -- (M, K):(K, 1)
        b_stack[g] = b_tensors[g].t()  # b_stack[g] -- (N, K):(K, 1)
        # We need K-Major scale factor
        a_scale_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_scales_tensors[
            g
        ]
        b_scale_stack[g] = b_scales_tensors[
            g
        ].t()  # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
    b_stack = b_stack.transpose(1, 2)  # Transpose Matrix B to Column-Major.
    b_scale_stack = b_scale_stack.transpose(1, 2)
    c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
    a_strides = torch.full(
        (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
    )
    c_strides = torch.full(
        (num_experts,), c_out.stride(0), device=device, dtype=torch.int64
    )
    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
    # Per-expert pointer tables; filled in by the kernel launcher.
    a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    # NOTE(review): a_strides is passed for both A and B stride arguments —
    # presumably because both leading strides equal k_g here; confirm.
    fp8_blockwise_scaled_grouped_mm(
        c_out,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a_stack,
        b_stack,
        a_scale_stack,
        b_scale_stack,
        a_strides,
        a_strides,
        c_strides,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        expert_offsets[:-1],
        workspace,
    )
    # Validate each expert's output slice against its baseline.
    for g in range(num_experts):
        baseline = baseline_tensors[g]
        actual = c_out[expert_offsets[g] : expert_offsets[g + 1]]
        diff = calc_diff(actual, baseline)
        assert diff < 0.001
        print(
            f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
        )
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,49 @@
import pytest
import torch
from sgl_kernel import fp8_scaled_mm
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    """Reference scaled matmul: row-scale then column-scale (a @ b), plus bias.

    scale_a holds one factor per output row, scale_b one per output column.
    """
    product = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float32)
    scaled = product * scale_a.view(-1, 1) * scale_b.view(1, -1)
    result = scaled.to(out_dtype)
    return result if bias is None else result + bias.view(1, -1)
def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
    """Compare fp8_scaled_mm (per-row/per-column scales) against the torch reference."""
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min
    # Random operands spanning the full finite e4m3 range.
    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    # One scale per output row (A) and per output column (B).
    scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
    scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
    if with_bias:
        bias = torch.randn((N,), device=device, dtype=out_dtype)
    else:
        bias = None
    # B enters the GEMM as (K, N).
    b_fp8 = b_fp8.t()
    o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
    o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
    # Loose tolerances: fp8 quantization error dominates.
    rtol = 0.02
    atol = 1
    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("with_bias", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, with_bias, out_dtype):
    """Sweep fp8_scaled_mm shapes with and without bias on CUDA."""
    _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,131 @@
import pytest
import torch
from sgl_kernel import gptq_gemm
from sglang.srt.layers.quantization.utils import pack_cols, pack_rows
def torch_dequantize(q_weight, q_zeros, scales, g_idx, use_shuffle, bit, K, N):
    """Reference GPTQ dequantization for 4-bit nibble-packed weights.

    Unpacks the packed weights and zero points, applies the GPTQ off-by-one
    zero-point shift and per-group scales, and returns the dense (K, N)
    weight matrix in ``scales.dtype``.

    NOTE(review): ``g_idx`` and ``use_shuffle`` are accepted for signature
    parity but ignored; group indices are recomputed assuming sequential
    groups — the original behaves the same way.
    """
    assert bit == 4, "Reference dequantization only supports 4-bit"
    group_size = K // scales.shape[0]
    pack_factor = 32 // bit
    # Unpack q_weight: (K // pack_factor, N) -> (K, N), one nibble per row.
    weights = torch.empty(
        q_weight.shape[0] * pack_factor,
        q_weight.shape[1],
        dtype=torch.uint8,
        device=q_weight.device,
    )
    # Unpack q_zeros: (num_groups, N // pack_factor) -> (num_groups, N).
    zeros = torch.empty(
        q_zeros.shape[0],
        q_zeros.shape[1] * pack_factor,
        dtype=torch.uint8,
        device=q_zeros.device,
    )
    for nibble in range(pack_factor):
        weights[nibble::pack_factor, :] = (q_weight >> (nibble * 4)) & 0x0F
        zeros[:, nibble::pack_factor] = (q_zeros >> (nibble * 4)) & 0x0F
    # GPTQ stores zero points off by one.
    zeros = (zeros + 1).to(scales.dtype)
    scale_zeros = zeros * scales  # (num_groups, N)
    group_of_row = torch.tensor(
        [row // group_size for row in range(K)],
        dtype=torch.int32,
        device=q_weight.device,
    )
    # Broadcast the per-group quantities to every row they cover.
    row_scales = scales[group_of_row]  # (K, N)
    row_scale_zeros = scale_zeros[group_of_row]  # (K, N)
    # dequant: weight * scale - scale_zeros
    dequantized = weights.to(scales.dtype) * row_scales - row_scale_zeros
    return dequantized.reshape(K, N)
def torch_gptq_gemm(
    a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
):
    """Reference GPTQ GEMM: dequantize the packed weight, then plain matmul."""
    in_features = a.shape[1]
    out_features = b_q_weight.shape[1]
    dense_b = torch_dequantize(
        b_q_weight,
        b_gptq_qzeros,
        b_gptq_scales,
        b_g_idx,
        use_shuffle,
        bit,
        in_features,
        out_features,
    )
    return torch.matmul(a, dense_b)
def _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, device="cuda"):
    """Quantize a random weight GPTQ-style, pack it, and compare the
    gptq_gemm kernel against the torch reference on the same packed data."""
    b_fp = torch.randn(K, N, dtype=dtype, device=device)
    assert K % group_size == 0, "K must be divisible by group_size"
    num_groups = K // group_size
    if use_shuffle:
        # The shuffled (act-order) path is not exercised by this test.
        return
    else:
        g_idx = torch.tensor(
            [i // group_size for i in range(K)], dtype=torch.int32, device=device
        )
        # NOTE(review): b_fp[g_idx] gathers ROWS BY GROUP ID (0..num_groups-1),
        # i.e. it repeats the first num_groups rows rather than permuting all K
        # rows. Kernel and reference still see identical packed tensors, but
        # confirm this degenerate weight content is intended.
        b_shuffled = b_fp[g_idx]
    b_grouped = b_shuffled.reshape(num_groups, group_size, N)
    # Asymmetric min/max quantization per group.
    b_max = torch.max(b_grouped, dim=1, keepdim=True)[0]
    b_min = torch.min(b_grouped, dim=1, keepdim=True)[0]
    scales = (b_max - b_min) / (2**bit - 1)
    scales = scales.clamp(min=1e-6)
    zeros_float = (-b_min / scales).round()
    q_b = (
        (b_grouped / scales + zeros_float).round().clamp(0, 2**bit - 1).to(torch.uint8)
    )
    # GPTQ stores zero-points biased by -1 (the unpacker adds 1 back).
    q_zeros_unpacked = zeros_float.to(torch.uint8) - 1
    b_q_weight = pack_rows(q_b.reshape(K, N), bit, K, N)
    q_zeros_unpacked = q_zeros_unpacked.reshape(num_groups, N)
    b_gptq_qzeros = pack_cols(q_zeros_unpacked, bit, num_groups, N)
    b_gptq_scales = scales.squeeze(1)
    a = torch.randn(M, K, dtype=dtype, device=device)
    c_ref = torch_gptq_gemm(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
    )
    c_out = gptq_gemm(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
    )
    rtol = 4e-2
    atol = 4e-2
    torch.testing.assert_close(c_ref, c_out, rtol=rtol, atol=atol)
    print(
        f"✅ Test passed: M={M}, N={N}, K={K}, bit={bit}, group_size={group_size}, use_shuffle={use_shuffle}, dtype={dtype}"
    )
@pytest.mark.parametrize("M", [1, 8, 128])
@pytest.mark.parametrize("N", [2048, 4096])
@pytest.mark.parametrize("K", [2048, 4096])
@pytest.mark.parametrize("bit", [4])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("use_shuffle", [False])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gptq_gemm(M, N, K, bit, group_size, use_shuffle, dtype):
    """Parametrized accuracy test for the gptq_gemm CUDA kernel."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, "cuda")
# Allow running this test file directly via `python`, delegating to pytest.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,48 @@
import pytest
import torch
from sgl_kernel import int8_scaled_mm
from utils import is_sm10x
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Clamp to the int8 value range, round, and cast to ``torch.int8``."""
    clamped = tensor.clamp(min=-128, max=127)
    return clamped.round().to(torch.int8)
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    """Reference scaled int8 matmul.

    Computes (a @ b) * scale_a[:, None] * scale_b[None, :] in fp32, adds the
    optional bias, and casts to ``out_dtype``.
    """
    acc = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    acc = acc * scale_a.view(-1, 1) * scale_b.view(1, -1)
    if bias is not None:
        acc = acc + bias
    return acc.to(out_dtype)
def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
    """Compare int8_scaled_mm against the torch reference for one shape.

    Builds random int8 operands with per-row / per-column fp32 scales (and an
    optional bias) on ``device``, then checks kernel vs. reference output.
    """
    a = to_int8(torch.randn((M, K), device=device) * 5)
    b = to_int8(torch.randn((N, K), device=device).t() * 5)
    # Fix: use the `device` argument everywhere. The scales/bias were
    # hard-coded to "cuda", which broke calls with any other device string.
    scale_a = torch.randn((M,), device=device, dtype=torch.float32)
    scale_b = torch.randn((N,), device=device, dtype=torch.float32)
    if with_bias:
        bias = torch.randn((N,), device=device, dtype=out_dtype) * 10
    else:
        bias = None
    o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    torch.testing.assert_close(o, o1)
@pytest.mark.skipif(
    is_sm10x(),
    reason="int8_scaled_mm is only supported on sm90 and lower",
)
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("with_bias", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
def test_accuracy(M, N, K, with_bias, out_dtype):
    """Sweep int8_scaled_mm over shapes/dtypes against the torch reference."""
    _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")
# Allow running this test file directly via `python`, delegating to pytest.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,485 @@
import pytest
import torch
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_direct_lf_pf,
transfer_kv_all_layer_mla,
transfer_kv_direct,
transfer_kv_per_layer,
transfer_kv_per_layer_direct_pf_lf,
transfer_kv_per_layer_mla,
)
def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
    """Reference scatter-copy: dst_pool[dst_indices] <- src_pool[src_indices]."""
    gathered = src_pool[src_indices].to(dst_pool.device)
    dst_pool[dst_indices] = gathered
def ref_copy_with_indices_pf_direct(
    src_pool, dst_pool, src_indices, dst_indices, page_size, layer_id, lf_to_pf=False
):
    """Reference page-granular copy between layer-first (lf) and
    page-first (pf) KV pool layouts for one layer.

    lf layout: pool[layer][item, ...]; pf layout: pool[page][layer][slot, ...].
    """
    for start in range(0, len(src_indices), page_size):
        chunk = slice(start, start + page_size)
        if lf_to_pf:
            dst_page = dst_indices[start] // page_size
            dst_pool[dst_page][layer_id] = src_pool[layer_id][
                src_indices[chunk]
            ].to(dst_pool.device)
        else:
            src_page = src_indices[start] // page_size
            dst_pool[layer_id][dst_indices[chunk]] = src_pool[src_page][
                layer_id
            ].to(dst_pool.device)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
@pytest.mark.parametrize("page_size", [1, 16, 64])
@pytest.mark.parametrize("item_size", [256])
@pytest.mark.parametrize("total_items_in_pool", [10240])
@pytest.mark.parametrize("is_mla", [False, True])
@pytest.mark.parametrize("all_layers", [False, True])
def test_transfer_kv(
    dtype: torch.dtype,
    num_items_to_transfer: int,
    item_size: int,
    page_size: int,
    total_items_in_pool: int,
    is_mla: bool,
    all_layers: bool,
):
    """
    Tests the per-layer transfer functions, treating tensors as memory pools.

    Pinned-host source pools are copied into CUDA destination pools via the
    sgl_kernel transfer kernels and via transfer_kv_direct, and both results
    are compared against a pure-PyTorch reference scatter-copy.
    """
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    device = "cuda"
    torch.cuda.manual_seed(42)
    num_layers = 4  # A small number of layers for pool creation
    total_pages_in_pool = total_items_in_pool // page_size
    num_pages_to_transfer = num_items_to_transfer // page_size
    # Fewer items than one page: nothing to move; restore state and bail out.
    if num_pages_to_transfer == 0:
        torch.set_default_dtype(original_dtype)
        return
    # Draw disjoint source/destination pages from one random permutation.
    page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
    src_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[:num_pages_to_transfer]
        ]
    )
    src_indices_device = src_indices_host.to(device)
    dst_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
        ]
    )
    dst_indices_device = dst_indices_host.to(device)
    # Prepare memory pools based on whether it's an MLA case.
    if is_mla:
        src_pool_host = torch.randn(
            num_layers, total_items_in_pool, item_size
        ).pin_memory()
        dst_pool_ref = torch.zeros_like(src_pool_host).to(device)
        dst_pool_kernel = torch.zeros_like(dst_pool_ref)
        dst_pool_direct = torch.zeros_like(dst_pool_ref)
    else:
        src_k_pool = torch.randn(
            num_layers, total_items_in_pool, item_size
        ).pin_memory()
        src_v_pool = torch.randn(
            num_layers, total_items_in_pool, item_size
        ).pin_memory()
        dst_k_pool_ref = torch.zeros_like(src_k_pool).to(device)
        dst_v_pool_ref = torch.zeros_like(src_v_pool).to(device)
        dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
        dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
        dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
        dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
    torch.cuda.synchronize()
    # We will test the per-layer function on the first layer (index 0) of the pool.
    layer_idx_to_test = 0
    if is_mla:
        if not all_layers:
            ref_copy_with_indices(
                src_pool_host[layer_idx_to_test],
                dst_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            transfer_kv_per_layer_mla(
                src_pool_host[layer_idx_to_test],
                dst_pool_kernel[layer_idx_to_test],
                src_indices_device,
                dst_indices_device,
                item_size=item_size * dtype.itemsize,
            )
            transfer_kv_direct(
                [src_pool_host[layer_idx_to_test]],
                [dst_pool_direct[layer_idx_to_test]],
                src_indices_host,
                dst_indices_device,
                page_size=page_size,
            )
        else:
            for layer_id in range(num_layers):
                ref_copy_with_indices(
                    src_pool_host[layer_id],
                    dst_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
            # The all-layer kernel takes device tensors of raw per-layer
            # base pointers rather than the tensors themselves.
            src_layers_device = torch.tensor(
                [src_pool_host[layer_id].data_ptr() for layer_id in range(num_layers)],
                dtype=torch.uint64,
                device=device,
            )
            dst_layers_device = torch.tensor(
                [
                    dst_pool_kernel[layer_id].data_ptr()
                    for layer_id in range(num_layers)
                ],
                dtype=torch.uint64,
                device=device,
            )
            transfer_kv_all_layer_mla(
                src_layers_device,
                dst_layers_device,
                src_indices_device,
                dst_indices_device,
                item_size=item_size * dtype.itemsize,
                num_layers=num_layers,
            )
            transfer_kv_direct(
                [src_pool_host[layer_id] for layer_id in range(num_layers)],
                [dst_pool_direct[layer_id] for layer_id in range(num_layers)],
                src_indices_host,
                dst_indices_device,
                page_size=page_size,
            )
        torch.cuda.synchronize()
        torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
        torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
    else:
        if not all_layers:
            ref_copy_with_indices(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            ref_copy_with_indices(
                src_v_pool[layer_idx_to_test],
                dst_v_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            transfer_kv_per_layer(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_kernel[layer_idx_to_test],
                src_v_pool[layer_idx_to_test],
                dst_v_pool_kernel[layer_idx_to_test],
                src_indices_device,
                dst_indices_device,
                item_size=item_size * dtype.itemsize,
            )
            transfer_kv_direct(
                [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
                [
                    dst_k_pool_direct[layer_idx_to_test],
                    dst_v_pool_direct[layer_idx_to_test],
                ],
                src_indices_host,
                dst_indices_device,
                page_size=page_size,
            )
        else:
            for layer_id in range(num_layers):
                ref_copy_with_indices(
                    src_k_pool[layer_id],
                    dst_k_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
                ref_copy_with_indices(
                    src_v_pool[layer_id],
                    dst_v_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
            src_k_layers_device = torch.tensor(
                [src_k_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
                dtype=torch.uint64,
                device=device,
            )
            src_v_layers_device = torch.tensor(
                [src_v_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
                dtype=torch.uint64,
                device=device,
            )
            dst_k_layers_device = torch.tensor(
                [
                    dst_k_pool_kernel[layer_id].data_ptr()
                    for layer_id in range(num_layers)
                ],
                dtype=torch.uint64,
                device=device,
            )
            dst_v_layers_device = torch.tensor(
                [
                    dst_v_pool_kernel[layer_id].data_ptr()
                    for layer_id in range(num_layers)
                ],
                dtype=torch.uint64,
                device=device,
            )
            transfer_kv_all_layer(
                src_k_layers_device,
                dst_k_layers_device,
                src_v_layers_device,
                dst_v_layers_device,
                src_indices_device,
                dst_indices_device,
                item_size=item_size * dtype.itemsize,
                num_layers=num_layers,
            )
            transfer_kv_direct(
                [src_k_pool[layer_id] for layer_id in range(num_layers)]
                + [src_v_pool[layer_id] for layer_id in range(num_layers)],
                [dst_k_pool_direct[layer_id] for layer_id in range(num_layers)]
                + [dst_v_pool_direct[layer_id] for layer_id in range(num_layers)],
                src_indices_host,
                dst_indices_device,
                page_size=page_size,
            )
        torch.cuda.synchronize()
        torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
        torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
        torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
        torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
    torch.set_default_dtype(original_dtype)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [128, 1024, 8192])
@pytest.mark.parametrize("page_size", [16, 64, 128])
@pytest.mark.parametrize("item_size", [256])
@pytest.mark.parametrize("total_items_in_pool", [20480])
@pytest.mark.parametrize("is_mla", [False, True])
@pytest.mark.parametrize("lf_to_pf", [False, True])
def test_transfer_kv_pf_direct(
    dtype: torch.dtype,
    num_items_to_transfer: int,
    item_size: int,
    page_size: int,
    total_items_in_pool: int,
    is_mla: bool,
    lf_to_pf: bool,
):
    """Test direct layer-first <-> page-first transfers against the
    pure-PyTorch page-copy reference, for both MLA and K/V pool layouts."""
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    device = "cuda"
    torch.cuda.manual_seed(42)
    num_layers = 4
    total_pages_in_pool = total_items_in_pool // page_size
    num_pages_to_transfer = num_items_to_transfer // page_size
    # Fewer items than one page: nothing to move; restore state and bail out.
    if num_pages_to_transfer == 0:
        torch.set_default_dtype(original_dtype)
        return
    # Draw disjoint source/destination pages from one random permutation.
    page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
    src_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[:num_pages_to_transfer]
        ]
    )
    src_indices_device = src_indices_host.to(device)
    dst_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
        ]
    )
    dst_indices_device = dst_indices_host.to(device)
    # We will test the per-layer function on the first layer (index 0) of the pool.
    layer_idx_to_test = 0
    if lf_to_pf:
        if is_mla:
            src_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
                device
            )
            src_pool_ptrs = [src_pool[i] for i in range(num_layers)]
            dst_pool_ref = torch.zeros(
                total_pages_in_pool, num_layers, page_size, item_size
            ).pin_memory()
            dst_pool_direct = torch.zeros_like(dst_pool_ref)
            torch.cuda.synchronize()
            transfer_kv_all_layer_direct_lf_pf(
                src_pool_ptrs,
                [dst_pool_direct],
                src_indices_host,
                dst_indices_host,
                page_size,
            )
            for i in range(num_layers):
                ref_copy_with_indices_pf_direct(
                    src_pool,
                    dst_pool_ref,
                    src_indices_device,
                    dst_indices_host,
                    page_size,
                    i,
                    lf_to_pf=True,
                )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
        else:
            src_k_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
                device
            )
            src_k_pool_ptrs = [src_k_pool[i] for i in range(num_layers)]
            src_v_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
                device
            )
            src_v_pool_ptrs = [src_v_pool[i] for i in range(num_layers)]
            dst_k_pool_ref = torch.zeros(
                total_pages_in_pool, num_layers, page_size, item_size
            ).pin_memory()
            dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
            dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
            dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
            torch.cuda.synchronize()
            transfer_kv_all_layer_direct_lf_pf(
                src_k_pool_ptrs + src_v_pool_ptrs,
                [dst_k_pool_direct, dst_v_pool_direct],
                src_indices_host,
                dst_indices_host,
                page_size,
            )
            for i in range(num_layers):
                ref_copy_with_indices_pf_direct(
                    src_k_pool,
                    dst_k_pool_ref,
                    src_indices_device,
                    dst_indices_host,
                    page_size,
                    i,
                    lf_to_pf=True,
                )
                ref_copy_with_indices_pf_direct(
                    src_v_pool,
                    dst_v_pool_ref,
                    src_indices_device,
                    dst_indices_host,
                    page_size,
                    i,
                    lf_to_pf=True,
                )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
            torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
    else:
        if is_mla:
            src_pool = torch.randn(
                total_pages_in_pool, num_layers, page_size, item_size
            ).pin_memory()
            dst_pool_ref = torch.zeros(num_layers, total_items_in_pool, item_size).to(
                device
            )
            dst_pool_direct = torch.zeros_like(dst_pool_ref)
            dst_pool_direct_ptrs = [dst_pool_direct[i] for i in range(num_layers)]
            torch.cuda.synchronize()
            transfer_kv_per_layer_direct_pf_lf(
                [src_pool],
                [dst_pool_direct_ptrs[layer_idx_to_test]],
                src_indices_host,
                dst_indices_host,
                layer_idx_to_test,
                page_size,
            )
            ref_copy_with_indices_pf_direct(
                src_pool,
                dst_pool_ref,
                src_indices_host,
                dst_indices_device,
                page_size,
                layer_idx_to_test,
                lf_to_pf=False,
            )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
        else:
            src_k_pool = torch.randn(
                total_pages_in_pool, num_layers, page_size, item_size
            ).pin_memory()
            src_v_pool = torch.randn(
                total_pages_in_pool, num_layers, page_size, item_size
            ).pin_memory()
            dst_k_pool_ref = torch.zeros(num_layers, total_items_in_pool, item_size).to(
                device
            )
            dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
            dst_k_pool_direct_ptrs = [dst_k_pool_direct[i] for i in range(num_layers)]
            dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
            dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
            dst_v_pool_direct_ptrs = [dst_v_pool_direct[i] for i in range(num_layers)]
            torch.cuda.synchronize()
            transfer_kv_per_layer_direct_pf_lf(
                [src_k_pool, src_v_pool],
                [
                    dst_k_pool_direct_ptrs[layer_idx_to_test],
                    dst_v_pool_direct_ptrs[layer_idx_to_test],
                ],
                src_indices_host,
                dst_indices_host,
                layer_idx_to_test,
                page_size,
            )
            ref_copy_with_indices_pf_direct(
                src_k_pool,
                dst_k_pool_ref,
                src_indices_host,
                dst_indices_device,
                page_size,
                layer_idx_to_test,
                lf_to_pf=False,
            )
            ref_copy_with_indices_pf_direct(
                src_v_pool,
                dst_v_pool_ref,
                src_indices_host,
                dst_indices_device,
                page_size,
                layer_idx_to_test,
                lf_to_pf=False,
            )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
            torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
    torch.set_default_dtype(original_dtype)
# Allow running this test file directly via `python`, delegating to pytest.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,84 @@
import pytest
import torch
from sgl_kernel import lightning_attention_decode
def naive_lightning_attention_decode(q, k, v, past_kv, slope):
    """Naive reference for lightning attention decode.

    Per position i the fp32 state is updated as
    kv <- exp(-slope) * kv + k_i^T @ v_i, then out_i = q_i @ kv.
    Returns (outputs cast back to q.dtype, final kv state).
    """
    in_dtype = q.dtype
    decay = torch.exp(-slope)  # [h, 1, 1]
    state = past_kv
    num_steps = q.shape[2]
    outputs = []
    for step in range(num_steps):
        k_i = k[:, :, step : step + 1]
        v_i = v[:, :, step : step + 1]
        state = decay * state.to(torch.float32) + torch.einsum(
            "... n d, ... n e -> ... d e",
            k_i,
            v_i,
        )
        q_i = q[:, :, step : step + 1].to(torch.float32)
        outputs.append(
            torch.einsum(
                "... n e, ... e d -> ... n d",
                q_i,
                state.to(torch.float32),
            )
        )
    merged = torch.cat(outputs, dim=-2)
    return merged.to(in_dtype), state
# (batch_size, num_heads, dim, embed_dim) shapes to exercise.
# The duplicate (4, 32, 64, 64) entry was removed so each shape runs once.
configs = [
    (1, 8, 64, 64),
    (2, 8, 64, 64),
    (1, 32, 32, 64),
    (2, 32, 32, 64),
    (4, 32, 64, 64),
    (16, 64, 96, 96),
    (64, 64, 96, 96),
]
# Input/output dtypes to test against the CUDA kernel.
dtypes = [torch.float32, torch.float16, torch.bfloat16]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs)
def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim):
    """Compare the lightning_attention_decode CUDA kernel to the naive
    PyTorch reference within a 1e-2 tolerance."""
    device = torch.device("cuda")
    q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
    k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
    v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype)
    # NOTE(review): past_kv/slope use torch's default dtype (float32) rather
    # than `dtype` -- presumably intentional (fp32 state); confirm.
    past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)
    ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope)
    output = torch.empty_like(ref_output)
    new_kv = torch.empty_like(ref_new_kv)
    lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
    rtol = 1e-2
    atol = 1e-2
    torch.testing.assert_close(
        output,
        ref_output,
        rtol=rtol,
        atol=atol,
    )
    torch.testing.assert_close(
        new_kv,
        ref_new_kv,
        rtol=rtol,
        atol=atol,
    )
# Allow running this test file directly via `python`, delegating to pytest.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,121 @@
import pytest
import torch
from sgl_kernel import gptq_marlin_gemm
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
# (m_factor, n_factor, k_factor) multipliers applied to the chunk sizes to
# build the tested GEMM problem shapes.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]
# uint4 for awq
# uint4b8 for gptq
@pytest.mark.parametrize("k_chunk", [128])
@pytest.mark.parametrize("n_chunk", [64, 256])
@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("is_k_full", [False, True])
@pytest.mark.parametrize("use_atomic_add", [False, True])
@pytest.mark.parametrize("use_fp32_reduce", [False, True])
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    """Compare gptq_marlin_gemm against a dense fp16 matmul on the
    dequantized reference weight (mean relative error below 4%)."""
    m_factor, n_factor, k_factor = mnk_factors
    # Zero-point (AWQ-style) types vs. symmetric-bias (GPTQ-style) types.
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    # Silently skip parameter combinations the kernel does not support.
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if has_zp:
            return
    if size_k % group_size != 0:
        return
    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        # GPTQ style, unsigned + symmetric bias
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order
        )
        marlin_zp = None
        marlin_s2 = None
    workspace = marlin_make_workspace(w_ref.device)
    # marlin gemm
    output = gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    # ref gemm
    output_ref = torch.matmul(a_input, w_ref)
    torch.cuda.synchronize()
    # Mean relative error -- tolerant of fp16 accumulation noise.
    max_diff = torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )
    assert max_diff < 0.04
# Allow running this test file directly; short tracebacks for readability.
if __name__ == "__main__":
    import subprocess

    subprocess.call(["pytest", "--tb=short", str(__file__)])

View File

@@ -0,0 +1,148 @@
import numpy as np
import pytest
import torch
from sgl_kernel import awq_marlin_repack, gptq_marlin_repack
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import (
gptq_quantize_weights,
pack_cols,
pack_rows,
quantize_weights,
sort_weights,
)
from sglang.test.test_marlin_utils import get_weight_perm, marlin_weights
# Marlin layout tile dimension. NOTE(review): unused in this file's visible
# tests -- confirm whether it can be removed.
GPTQ_MARLIN_TILE = 16
# Chunk sizes multiplied by MNK_FACTORS below to build the tested shapes.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]
# (m_factor, n_factor, k_factor) multipliers for the chunk sizes.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]
def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack an unpacked (size_k, size_n) quantized weight into AWQ int32 layout.

    Columns are interleaved in small groups (matching AWQ's dequantize order)
    before being packed along the column dimension.
    """
    assert q_w.shape == (size_k, size_n)
    interleave_by_bits = {
        4: np.array([0, 2, 4, 6, 1, 3, 5, 7]),
        8: np.array([0, 2, 1, 3]),
    }
    if num_bits not in interleave_by_bits:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
    perm = interleave_by_bits[num_bits]
    interleaved = q_w.reshape((-1, len(perm)))[:, perm].ravel()
    interleaved = interleaved.reshape((-1, size_n)).contiguous()
    return pack_cols(interleaved, num_bits, size_k, size_n)
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)])
@pytest.mark.parametrize("group_size", [16, 32])
def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size):
    """Check awq_marlin_repack against a CPU-built Marlin-layout reference."""
    tile_k, tile_n = 16, 64
    size_k = k_tiles * tile_k
    size_n = n_tiles * tile_n
    pack_factor = 32 // num_bits
    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
    # Fix: quantize with the type matching `num_bits`. This was hard-coded to
    # uint4, so the 8-bit cases packed 4-bit values with an 8-bit layout.
    quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
    w_ref, q_w, s, zp = quantize_weights(
        b_weight, quant_type, group_size, zero_points=True
    )
    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)
    weight_perm = get_weight_perm(num_bits)
    q_w_marlin = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    out_gpu = awq_marlin_repack(q_w_awq, size_k, size_n, num_bits)
    assert out_gpu.is_cuda and out_gpu.dtype == torch.int32
    expected_cols = size_n * tile_k // pack_factor
    assert list(out_gpu.shape) == [size_k // tile_k, expected_cols]
    torch.cuda.synchronize()
    torch.testing.assert_close(out_gpu, q_w_marlin)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", [scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(
    k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
):
    """Check gptq_marlin_repack against a CPU-built Marlin-layout reference."""
    m_factor, n_factor, k_factor = mnk_factors
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k
    if size_k % group_size != 0:
        pytest.skip("size_k must be divisible by group_size")
    # Create input
    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order
    )
    q_w_gptq = pack_rows(q_w, quant_type.size_bits, size_k, size_n)
    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
    marlin_layout_perm = get_weight_perm(quant_type.size_bits)
    q_w_marlin_ref = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, marlin_layout_perm
    )
    # Run Marlin repack GPU kernel
    q_w_marlin = gptq_marlin_repack(
        q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits
    )
    torch.cuda.synchronize()
    torch.testing.assert_close(q_w_marlin, q_w_marlin_ref)
# Allow running this test file directly; short tracebacks for readability.
if __name__ == "__main__":
    import subprocess

    subprocess.call(["pytest", "--tb=short", str(__file__)])

View File

@@ -0,0 +1,142 @@
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/flashinfer/triton/kernels/cascade.py
from typing import List
import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import merge_state
def check_input(x: torch.Tensor):
    """Require *x* to be a contiguous CUDA tensor."""
    if not x.is_cuda:
        raise AssertionError(f"{str(x)} must be a CUDA Tensor")
    if not x.is_contiguous():
        raise AssertionError(f"{str(x)} must be contiguous")
def check_dim(d, x: torch.Tensor):
    """Require *x* to have exactly *d* dimensions."""
    if x.dim() != d:
        raise AssertionError(f"{str(x)} must be a {d}D tensor")
def check_shape(a: torch.Tensor, b: torch.Tensor):
    """Require *a* and *b* to have identical rank and per-dimension sizes."""
    assert a.dim() == b.dim(), "tensors should have same dim"
    for dim_a, dim_b in zip(a.size(), b.size()):
        assert (
            dim_a == dim_b
        ), f"tensors shape mismatch, {a.size()} and {b.size()}"
def check_device(tensors: List[torch.Tensor]):
    """Require every tensor in the list to share the first tensor's device."""
    expected = tensors[0].device
    for tensor in tensors:
        assert (
            tensor.device == expected
        ), f"All tensors should be on the same device, but got {expected} and {tensor.device}"
@triton.jit
def state_merge(o, m, d, other_o, other_m, other_d):
    # Merge two (output, running-max, denominator) softmax states in
    # base-2 log space; rescale both sides by exp2(m - m_max) for stability.
    m_max = tl.maximum(m, other_m)
    d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max)
    o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max)
    return o, m_max, d
@triton.jit
def state_normalize(o, m, d):
    # Divide the accumulated output by the softmax denominator.
    o = o / d
    return o, m, d
@triton.jit
def state_get_lse(o, m, d):
    # Base-2 log-sum-exp of a state: max plus log2 of the denominator.
    return m + tl.log2(d)
@triton.jit
def merge_state_kernel(
    v_a_ptr,
    s_a_ptr,
    v_b_ptr,
    s_b_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    # One program per sequence position; scalar loops over the head
    # dimension (bdx) and the heads (bdy).
    pos = tl.program_id(axis=0)
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx)
            s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx)
            offsets = (pos * num_heads + head_idx) * head_dim + tx
            v_a = tl.load(v_a_ptr + offsets)
            v_b = tl.load(v_b_ptr + offsets)
            # d=1 on both sides: each input state is already normalized.
            v_merged, s_max, d = state_merge(
                o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1
            )
            v_merged, s_max, d = state_normalize(v_merged, s_max, d)
            v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx
            tl.store(v_merged_ptr + v_merged_offset, v_merged)
            if s_merged_ptr:
                # Merged lse in base 2 (matches state_get_lse).
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx,
                    tl.log2(d) + s_max,
                )
def merge_state_triton(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
    """Merge two attention states with the Triton reference kernel.

    v_*: [seq_len, num_heads, head_dim] attention outputs;
    s_*: [seq_len, num_heads] log-sum-exp values (base 2, per the kernel's
    exp2/log2 math). Returns (v_merged, s_merged).
    """
    check_input(v_a)
    check_input(s_a)
    check_input(v_b)
    check_input(s_b)
    check_device([v_a, s_a, v_b, s_b])
    check_dim(3, v_a)
    check_dim(2, s_a)
    check_dim(3, v_b)
    check_dim(2, s_b)
    check_shape(v_a, v_b)
    check_shape(s_a, s_b)
    assert v_a.size(0) == s_a.size(0)
    assert v_a.size(1) == s_b.size(1)
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    seq_len = v_a.size(0)
    num_heads = v_a.size(1)
    head_dim = v_a.size(2)
    v_merged = torch.empty_like(v_a).to(s_a.device)
    # NOTE(review): s_merged is allocated with the default dtype rather than
    # s_a.dtype explicitly -- equivalent today since s_a was cast to float32.
    s_merged = torch.empty((seq_len, num_heads)).to(s_a.device)
    bdx = head_dim
    bdy = num_heads
    merge_state_kernel[lambda meta: (seq_len,)](
        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
    )
    return v_merged, s_merged
@pytest.mark.parametrize("seq_len", [2048])
@pytest.mark.parametrize("num_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_merge_state(seq_len, num_heads, head_dim):
    """Compare the CUDA merge_state kernel with the Triton reference."""
    va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
    sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
    vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
    sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
    v_merged, s_merged = merge_state_triton(va, sa, vb, sb)
    v_merged_std, s_merged_std = merge_state(va, sa, vb, sb)
    assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
    assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
# Allow running this test file directly via `python`, delegating to pytest.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,400 @@
from typing import Optional
import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import merge_state, merge_state_v2
# Triton reference kernel: merges prefix/suffix attention outputs and LSEs
# (natural-log domain); one program per (token, head).
@triton.jit
def merge_state_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)
    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
    # inf -> -inf so that exp() gives that side zero weight
    # (presumably inf marks an empty/invalid state -- confirm with callers).
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    out_se = tl.exp(p_lse) + tl.exp(s_lse)
    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)
    # Mask out the padding beyond HEAD_SIZE (padded to a power of two).
    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(
        prefix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )
    s_out = tl.load(
        suffix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )
    p_scale = tl.exp(p_lse) / out_se
    s_scale = tl.exp(s_lse) / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(
        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
        out,
        mask=head_mask,
    )
def merge_state_triton(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
):
    """Merge prefix/suffix attention states with the Triton kernel.

    Returns (output, output_lse); buffers are allocated when not provided.
    """
    # Fix: allocate missing buffers BEFORE reading shapes. The previous code
    # read `output.shape` first, crashing with AttributeError whenever the
    # caller relied on the `output=None` default. (The `-> None` return
    # annotation was also wrong -- the function returns a tuple.)
    # Avoid creating new tensors if they are already provided
    if output is None:
        output = torch.empty_like(prefix_output)
    if output_lse is None:
        output_lse = torch.empty_like(prefix_lse)
    num_tokens = output.shape[0]
    num_query_heads = output.shape[1]
    head_size = output.shape[2]
    padded_head_size = triton.next_power_of_2(head_size)
    merge_state_kernel[(num_tokens, num_query_heads)](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        output_lse is not None,
    )
    return output, output_lse
# Naive PyTorch implementation of merging attention states.
def merge_state_torch(
    prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS]
    suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS]
    output: Optional[torch.Tensor] = None,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse: Optional[torch.Tensor] = None,  # [NUM_TOKENS, NUM_HEADS]
):
    """Merge prefix/suffix attention states in pure PyTorch.

    Fixes vs. the previous version: the caller's `prefix_lse`/`suffix_lse`
    tensors are no longer mutated in place (they were clobbered by the
    inf -> -inf rewrite), and caller-provided `output`/`output_lse` buffers
    are actually written into instead of being shadowed by fresh tensors.
    """
    # Avoid creating new tensors if they are already provided
    if output is None:
        output = torch.empty_like(prefix_output)
    if output_lse is None:
        output_lse = torch.empty_like(prefix_lse)
    # Work on copies so the callers' lse tensors are left untouched.
    p_lse = prefix_lse.clone()
    s_lse = suffix_lse.clone()
    # inf -> -inf so that side gets zero weight after exp().
    p_lse[p_lse == torch.inf] = -torch.inf
    s_lse[s_lse == torch.inf] = -torch.inf
    max_lse = torch.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    p_lse_exp = torch.exp(p_lse)
    s_lse_exp = torch.exp(s_lse)
    out_se = p_lse_exp + s_lse_exp
    output_lse.copy_(torch.log(out_se) + max_lse)
    p_scale = (p_lse_exp / out_se).unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
    s_scale = (s_lse_exp / out_se).unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
    output.copy_(prefix_output * p_scale + suffix_output * s_scale)
    return output, output_lse
# Benchmark/test sweep dimensions.
NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536]
NUM_QUERY_HEADS = [8, 16, 32]
HEAD_SIZES = [32, 48, 64, 128, 256]
DTYPES = [torch.half, torch.bfloat16]
# Rows of (tokens, heads, head_size, dtype, device, t_torch, t_triton, t_v1,
# t_v2) appended by the tests and rendered by generate_markdown_table().
all_case_info: list[tuple] = []
def generate_markdown_table():
    """Print the collected benchmark rows as a GitHub-flavored markdown table.

    Reads the module-level ``all_case_info`` list filled in by
    test_merge_attn_states. Rows whose v2 time is 0 (kernel skipped or
    failed inside perf_kernel_fn) report a 0x speedup instead of raising
    ZeroDivisionError.
    """
    # No `global` declaration needed: the list is only read, never rebound.
    table_header = (
        "| tokens | heads | headsize | dtype "
        "| device | torch | triton | v1 | v2 | speedup(vs triton) | speedup(vs v1)|"
    )
    table_separator = (
        "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |"
    )

    def shortly_dtype(dtype: torch.dtype) -> str:
        """'torch.float16' -> 'float16'."""
        return str(dtype).removeprefix("torch.")

    def shortly_device(device: str) -> str:
        """Strip the vendor prefix from a CUDA device name."""
        return device.removeprefix("NVIDIA").strip()

    print(table_header)
    print(table_separator)
    for info in all_case_info:
        (
            num_tokens,
            num_heads,
            head_size,
            dtype,
            device,
            time_torch,
            time_triton,
            time_v1,
            time_v2,
        ) = info
        dtype = shortly_dtype(dtype)
        device = shortly_device(device)
        # perf_kernel_fn reports 0 for skipped/failed kernels; avoid
        # dividing by zero in that case.
        improved_triton = time_triton / time_v2 if time_v2 else 0.0
        improved_v1 = time_v1 / time_v2 if time_v2 else 0.0
        print(
            f"| {num_tokens} | {num_heads} | {head_size} "
            f"| {dtype} | {device} | {time_torch:.4f}ms "
            f"| {time_triton:.4f}ms "
            f"| {time_v1:.4f}ms "
            f"| {time_v2:.4f}ms "
            f"| {improved_triton:.4f}x "
            f"| {improved_v1:.4f}x |"
        )
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(
    num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
):
    """Benchmark and cross-check four merge-attn-states implementations.

    Runs the torch, Triton, CUDA v1 and CUDA v2 variants on identical random
    inputs (with some lse entries forced to +inf), times each, validates the
    CUDA v2 results against the Triton reference, and records the timings in
    the module-level `all_case_info` for the final markdown summary table.
    """
    if not torch.cuda.is_available():
        pytest.skip(
            "Currently only support compare triton merge_attn_states "
            "with custom cuda merge_attn_states kernel"
        )
    NUM_TOKENS = num_tokens
    NUM_HEADS = num_query_heads
    HEAD_SIZE = head_size
    print(
        f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
        f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
        f"Device: {torch.cuda.get_device_name()}"
    )
    # prefix_lse and suffix_lse contain inf and normal values
    prefix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
    suffix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
    # Generate boolean masks (~10% of entries each become +inf)
    mask_prefix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1
    mask_suffix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1
    # Ensure that the same position is not True at the same time
    combined_mask = torch.logical_and(mask_prefix, mask_suffix)
    mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
    mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)
    prefix_lse[mask_prefix] = float("inf")
    suffix_lse[mask_suffix] = float("inf")
    # Other input tensors (need to be initialized but
    # no actual calculation needed)
    output = torch.zeros(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )
    output_lse = torch.zeros(
        (NUM_TOKENS, NUM_HEADS), dtype=torch.float32, device="cuda"
    )
    prefix_output = torch.randn(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )
    suffix_output = torch.randn(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )
    warmup_times = 2
    repeat_times = 20
    def perf_kernel_fn(
        output_fn: torch.Tensor,
        output_lse_fn: torch.Tensor,
        kernel_fn: callable,
        fn_type: str = "torch",
    ):
        """Time `kernel_fn` over `repeat_times` runs via CUDA events.

        Returns (avg_time_ms, output, output_lse); a 0 time marks a
        skipped or failed kernel.
        """
        # Avoid inplace inf -> -inf, we have to use prefix_lse
        # and suffix_lse for other kernel.
        if fn_type == "torch":
            prefix_lse_ = prefix_lse.clone()
            suffix_lse_ = suffix_lse.clone()
        else:
            prefix_lse_ = prefix_lse
            suffix_lse_ = suffix_lse
        if fn_type == "cuda_v1":
            # merge_state v1 kernel not support float32
            if output_dtype not in (torch.half, torch.bfloat16):
                return 0, output_fn, output_lse_fn
        total_time = 0
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        try:
            for _ in range(warmup_times):
                output_fn, output_lse_fn = kernel_fn(
                    prefix_output,
                    prefix_lse_,
                    suffix_output,
                    suffix_lse_,
                    output_fn,
                    output_lse_fn,
                )
            torch.cuda.synchronize()
            for _ in range(repeat_times):
                start.record()
                output_fn, output_lse_fn = kernel_fn(
                    prefix_output,
                    prefix_lse_,
                    suffix_output,
                    suffix_lse_,
                    output_fn,
                    output_lse_fn,
                )
                end.record()
                torch.cuda.synchronize()
                total_time += start.elapsed_time(end)
            avg_time = total_time / repeat_times
            return avg_time, output_fn, output_lse_fn
        except Exception as e:
            # NOTE(review): failures are silently reported as a 0 ms timing;
            # presumably acceptable for this ad-hoc benchmark — confirm.
            return 0, output_fn, output_lse_fn
    # 0. Run the Torch kernel
    output_torch = output.clone()
    output_lse_torch = output_lse.clone()
    time_torch, output_torch, output_lse_torch = perf_kernel_fn(
        output_torch, output_lse_torch, merge_state_torch, fn_type="torch"
    )
    # 1. Run the Triton kernel
    output_ref_triton = output.clone()
    output_lse_ref_triton = output_lse.clone()
    time_triton, output_ref_triton, output_lse_ref_triton = perf_kernel_fn(
        output_ref_triton,
        output_lse_ref_triton,
        merge_state_triton,
        fn_type="triton",
    )
    # 2. Run the merge_state V1 kernel
    output_v1 = output.clone()
    output_lse_v1 = output_lse.clone()
    time_v1, output_v1, output_lse_v1 = perf_kernel_fn(
        output_v1, output_lse_v1, merge_state, fn_type="cuda_v1"
    )
    # 3. Run the merge_state V2 kernel
    output_v2 = output.clone()
    output_lse_v2 = output_lse.clone()
    time_v2, output_v2, output_lse_v2 = perf_kernel_fn(
        output_v2, output_lse_v2, merge_state_v2, fn_type="cuda_v2"
    )
    # 4. Performance compare
    improved = time_triton / time_v2
    print(f"  Torch time: {time_torch:.6f}ms")
    print(f" Triton time: {time_triton:.6f}ms")
    print(f"CUDA v1 time: {time_v1:.6f}ms")
    print(f"CUDA v2 time: {time_v2:.6f}ms, Performance: {improved:.5f}x")
    print("-" * 100)
    # 5. Correctness compare
    # Liger Kernel: Efficient Triton Kernels for LLM Training
    # https://arxiv.org/pdf/2410.10989, 3.3 Correctness
    # use rtol = 1e-2 for bfloat16.
    rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3
    def diff(a: torch.Tensor, b: torch.Tensor):
        """Max absolute element-wise difference, computed in float32."""
        max_diff = torch.max(torch.abs(a.float() - b.float()))
        return max_diff
    # Use Triton output as reference because we want to replace
    # the Triton kernel with custom CUDA kernel for merge attn
    # states operation.
    output_ref = output_ref_triton
    output_lse_ref = output_lse_ref_triton
    torch.testing.assert_close(
        output_v2.float(), output_ref.float(), atol=1e-3, rtol=rtol
    )
    print("Output all match, max abs diff:")
    print(f"(Triton vs Torch)  : {diff(output_torch, output_ref)}")
    print(f"(CUDA v2 vs Torch) : {diff(output_torch, output_v2)}")
    print(f"(CUDA v2 vs Triton): {diff(output_ref, output_v2)}")
    print("-" * 100)
    torch.testing.assert_close(
        output_lse_v2.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
    )
    print("Output LSE all match, max abs diff:")
    print(f"(Triton vs Torch)  : {diff(output_lse_torch, output_lse_ref)}")
    print(f"(CUDA v2 vs Torch) : {diff(output_lse_torch, output_lse_v2)}")
    print(f"(CUDA v2 vs Triton): {diff(output_lse_ref, output_lse_v2)}")
    print("-" * 100)
    print(
        "All output values test passed! All inf values "
        "are correctly replaced with -inf."
    )
    print("-" * 100)
    # Record this case's timings; once every parametrized combination has
    # run, emit the aggregated markdown table.
    device = torch.cuda.get_device_name()
    all_case_info.append(
        (
            NUM_TOKENS,
            NUM_HEADS,
            HEAD_SIZE,
            output_dtype,
            device,
            time_torch,
            time_triton,
            time_v1,
            time_v2,
        )
    )
    if len(all_case_info) == (
        len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
    ):
        generate_markdown_table()
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,250 @@
import itertools
import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_align_block_size
def ceil_div(a, b):
    """Ceiling of a / b in integer arithmetic (i.e. (a + b - 1) // b)."""
    numerator = a + b - 1
    return numerator // b
@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    # Stage 1: each program scans its own slice of the flattened topk_ids
    # and counts how many assignments each expert received, writing into
    # its private row (pid + 1) of the (num_experts + 1, num_experts)
    # tokens_cnts matrix. Row 0 is left as the zero base for the prefix
    # sum computed in stage 2.
    pid = tl.program_id(0)
    start_idx = pid * tokens_per_thread
    # Offset of this program's private counter row.
    off_c = (pid + 1) * num_experts
    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    # Stage 2: one program per expert (pid = expert id). Replace the
    # per-program counts in column `pid` with a running prefix sum down the
    # rows, so row i holds the number of tokens for this expert seen by
    # programs 0..i-1 of stage 1, and the last row holds the total.
    pid = tl.program_id(0)
    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    # Stage 3: single program. Turn per-expert totals (last row of
    # tokens_cnts) into cumulative, block_size-padded start offsets in
    # `cumsum`, and store the grand total (padded token count) into
    # total_tokens_post_pad.
    last_cumsum = 0
    # Offset of the last row of tokens_cnts, which holds per-expert totals.
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        # Round each expert's segment up to a multiple of block_size.
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    # Stage 4: one program per expert. First label every block in this
    # expert's padded [start, end) range with the expert id, then scatter
    # the flat topk indices from this program's stage-1 slice into their
    # final sorted positions using the stage-2 prefix counts as per-expert
    # ranks.
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)
    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)
    # Re-use `start_idx` for this program's slice of topk_ids (same
    # partitioning as stage 1).
    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts
    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Triton reference for moe_align_block_size.

    Groups the flattened top-k expert assignments into block_size-aligned,
    per-expert segments: fills `sorted_token_ids`, `expert_ids` and
    `num_tokens_post_pad` in place via a four-stage kernel pipeline.
    """
    total_topk = topk_ids.numel()
    per_expert_grid = (num_experts,)
    # Counter matrix: row 0 stays zero (prefix-sum base); row r + 1 is
    # written by stage-1 program r.
    token_counts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    chunk = ceil_div(total_topk, num_experts)
    moe_align_block_size_stage1[per_expert_grid](
        topk_ids,
        token_counts,
        num_experts,
        total_topk,
        chunk,
    )
    moe_align_block_size_stage2[per_expert_grid](
        token_counts,
        num_experts,
    )
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        token_counts,
        cumsum,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[per_expert_grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        token_counts,
        cumsum,
        num_experts,
        block_size,
        total_topk,
        chunk,
    )
@pytest.mark.parametrize(
    "block_size,num_tokens,topk,num_experts,pad_sorted_token_ids",
    list(
        itertools.product(
            [32, 64, 128, 256],  # block_size
            [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],  # num_tokens
            [1, 2, 4, 8, 16, 32, 64],  # topk
            [64, 160, 256, 257, 260, 264],  # num_experts
            [True, False],  # pad_sorted_token_ids
        )
    ),
)
def test_moe_align_block_size_compare_implementations(
    block_size, num_tokens, topk, num_experts, pad_sorted_token_ids
):
    """Cross-check the CUDA moe_align_block_size against the Triton version.

    Compares expert ids, the padded token count, and (for one chosen expert)
    the set of sorted token ids produced by both implementations.

    NOTE(review): both kernels are invoked with `num_experts + 1` experts
    while topk_ids only contains ids < num_experts — presumably deliberate
    (e.g. an extra padding expert); confirm with the kernel authors.
    """
    # Distinct expert ids per token: argsort of random scores, keep top-k.
    topk_ids = torch.argsort(torch.rand(num_tokens, num_experts, device="cuda"), dim=1)[
        :, :topk
    ]
    # Worst case: every expert's segment needs (block_size - 1) padding.
    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
    sorted_ids_cuda = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    # When the kernel does not pad, pre-fill with the out-of-range sentinel
    # value numel() ourselves.
    if not pad_sorted_token_ids:
        sorted_ids_cuda.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids_cuda = torch.zeros(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad_cuda = torch.empty(
        (1), dtype=torch.int32, device=topk_ids.device
    )
    cumsum_buffer = torch.empty(
        num_experts + 2, dtype=torch.int32, device=topk_ids.device
    )
    # The Triton reference always expects the sentinel pre-fill.
    sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
    sorted_ids_triton.fill_(topk_ids.numel())
    expert_ids_triton = torch.zeros_like(expert_ids_cuda)
    num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
    moe_align_block_size(
        topk_ids,
        num_experts + 1,
        block_size,
        sorted_ids_cuda,
        expert_ids_cuda,
        num_tokens_post_pad_cuda,
        cumsum_buffer,
        pad_sorted_token_ids,
    )
    moe_align_block_size_triton(
        topk_ids,
        num_experts + 1,
        block_size,
        sorted_ids_triton,
        expert_ids_triton,
        num_tokens_post_pad_triton,
    )
    assert torch.allclose(expert_ids_cuda, expert_ids_triton, atol=0, rtol=0), (
        f"Expert IDs mismatch for block_size={block_size}, "
        f"num_tokens={num_tokens}, topk={topk}\n"
        f"CUDA expert_ids: {expert_ids_cuda}\n"
        f"Triton expert_ids: {expert_ids_triton}"
    )
    assert torch.allclose(
        num_tokens_post_pad_cuda, num_tokens_post_pad_triton, atol=0, rtol=0
    ), (
        f"Num tokens post pad mismatch for block_size={block_size}, "
        f"num_tokens={num_tokens}, topk={topk}\n"
        f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
        f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}"
    )
    # Select an expert to check
    expert_idx = expert_ids_cuda.max().item()
    # Get the first and last block id where expert_ids_cuda == expert_idx
    matching_indices = torch.where(expert_ids_cuda == expert_idx)[0]
    block_sorted_start = matching_indices[0].item() * block_size
    block_sorted_end = min(
        (matching_indices[-1].item() + 1) * block_size, num_tokens_post_pad_cuda.item()
    )
    # Order within a block is implementation-defined, so compare the sorted
    # contents of the selected expert's range.
    selected_sorted_ids_cuda = sorted_ids_cuda[
        block_sorted_start:block_sorted_end
    ].sort()[0]
    selected_sorted_ids_triton = sorted_ids_triton[
        block_sorted_start:block_sorted_end
    ].sort()[0]
    assert torch.allclose(
        selected_sorted_ids_cuda,
        selected_sorted_ids_triton,
        atol=0,
        rtol=0,
    ), (
        f"Sorted IDs mismatch for block_size={block_size}, "
        f"num_tokens={num_tokens}, topk={topk}\n"
        f"CUDA sorted_ids: {selected_sorted_ids_cuda}\n"
        f"Triton sorted_ids: {selected_sorted_ids_triton}"
    )
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,103 @@
import pytest
import torch
from sgl_kernel import moe_fused_gate
from sglang.srt.layers.moe.topk import biased_grouped_topk
@pytest.mark.parametrize(
    "seq_length",
    list(range(1, 10))
    + [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
)
@pytest.mark.parametrize(
    "params",
    # (num_experts, num_expert_group, topk_group, topk)
    [
        (128, 4, 2, 4),
        (256, 8, 4, 8),  # deepseek v3
        (512, 16, 8, 16),
    ],
)
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
def test_moe_fused_gate_combined(
    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
):
    """Compare sgl_kernel.moe_fused_gate against biased_grouped_topk.

    Checks that the fused CUDA gate produces the same expert indices and
    (sorted) routing weights as the Python/Triton reference, including the
    fused-shared-experts extension.
    """
    num_experts, num_expert_group, topk_group, topk = params
    dtype = torch.float32
    # Seed per case so every parametrization is reproducible.
    torch.manual_seed(seq_length)
    tensor = torch.rand((seq_length, num_experts), dtype=dtype, device="cuda")
    scores = tensor.clone()
    bias = torch.rand(num_experts, dtype=dtype, device="cuda")
    # Shared experts occupy extra trailing topk slots.
    topk = topk + num_fused_shared_experts
    output, indices = moe_fused_gate(
        tensor,
        bias,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        topk=topk,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=2.5,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
    ref_output, ref_indices = biased_grouped_topk(
        scores,
        scores,
        bias,
        topk=topk,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=2.5,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
    # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
    if num_fused_shared_experts > 0:
        original_indices = indices.clone()
        original_ref_indices = ref_indices.clone()
        indices = indices[:, :-1]
        ref_indices = ref_indices[:, :-1]
        # Shared-expert ids live in [num_experts, num_experts + n_shared).
        valid_min = num_experts
        valid_max = num_experts + num_fused_shared_experts
        shared_indices = original_indices[:, -1]
        shared_ref_indices = original_ref_indices[:, -1]
        # NOTE(review): these `is not None` guards are always true for
        # tensor slices; presumably defensive leftovers.
        if shared_indices is not None:
            assert torch.all(
                (shared_indices >= valid_min) & (shared_indices < valid_max)
            ), f"Shared expert indices out of range: found values outside [{valid_min}, {valid_max})"
        if shared_ref_indices is not None:
            assert torch.all(
                (shared_ref_indices >= valid_min) & (shared_ref_indices < valid_max)
            ), f"Shared expert reference indices out of range: found values outside [{valid_min}, {valid_max})"
    # Per-row slot order may differ between implementations, so compare the
    # row-wise sorted indices and weights.
    idx_check = torch.allclose(
        ref_indices.sort()[0].to(torch.int32),
        indices.sort()[0].to(torch.int32),
        rtol=1e-04,
        atol=1e-05,
    )
    output_check = torch.allclose(
        ref_output.sort()[0].to(torch.float32),
        output.sort()[0].to(torch.float32),
        rtol=1e-02,
        atol=1e-03,
    )
    assert idx_check, (
        f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
        f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
    )
    assert output_check, (
        f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
        f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
    )
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,139 @@
import itertools
import pytest
import torch
from sgl_kernel import topk_softmax
@pytest.mark.parametrize(
    "num_tokens, num_experts, topk",
    list(
        itertools.product(
            [1, 16, 128, 512, 1024, 2048],  # num_tokens
            [4, 8, 16, 32, 64, 128, 256],  # num_experts
            [1, 2, 4],  # topk
        )
    ),
)
def test_topk_softmax(num_tokens, num_experts, topk):
    """Check sgl_kernel.topk_softmax against torch softmax + topk."""
    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=torch.float32, device="cuda"
    )
    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
    topk_softmax(
        topk_weights,
        topk_indices,
        gating_output,
    )
    # Native torch implementation
    softmax_output = torch.softmax(gating_output, dim=-1)
    topk_weights_ref, topk_indices_ref = torch.topk(softmax_output, topk, dim=-1)
    # Verify the top-k weights and indices match the torch native ones.
    # Bug fix: the weights message previously printed the reference
    # *indices* instead of the reference weights.
    assert torch.allclose(
        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
    ), f"Weights mismatch: torch={topk_weights_ref} vs SGLang={topk_weights}"
    assert torch.allclose(
        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
    ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"
@pytest.mark.parametrize(
    "num_tokens, num_experts, topk, dtype",
    list(
        itertools.product(
            [1, 16, 128, 512, 1024, 2048],  # num_tokens
            [4, 8, 16, 32, 64, 128, 256],  # num_experts
            [1, 2, 4],  # topk
            [torch.float16, torch.bfloat16, torch.float32],  # dtype
        )
    ),
)
def test_topk_softmax_dtype_regression(num_tokens, num_experts, topk, dtype):
    """topk_softmax on fp16/bf16 gating must match running it on float32."""
    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
    topk_softmax(
        topk_weights,
        topk_indices,
        gating_output,
    )
    # Reference: the same kernel fed a float32 copy of the gating logits.
    topk_weights_ref = torch.empty(
        (num_tokens, topk), dtype=torch.float32, device="cuda"
    )
    topk_indices_ref = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
    topk_softmax(
        topk_weights_ref,
        topk_indices_ref,
        gating_output.float(),
    )
    # Bug fix: the weights message previously printed the reference
    # *indices* instead of the reference weights.
    assert torch.allclose(
        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
    ), f"Weights mismatch: SGLang old interface={topk_weights_ref} vs SGLang new interface={topk_weights}"
    assert torch.allclose(
        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
    ), f"Indices mismatch: SGLang old interface={topk_indices_ref}, SGLang new interface={topk_indices}"
@pytest.mark.parametrize(
    "num_tokens, num_experts, topk",
    list(
        itertools.product(
            [1, 16, 128, 512, 1024, 2048],  # num_tokens
            [4, 8, 16, 32, 64, 128, 256],  # num_experts
            [1, 2, 4],  # topk
        )
    ),
)
def test_topk_softmax_renormalize(num_tokens, num_experts, topk):
    """Fused renormalize=True must equal manual post-hoc renormalization."""
    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=torch.bfloat16, device="cuda"
    )
    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
    topk_softmax(
        topk_weights,
        topk_indices,
        gating_output,
        renormalize=True,
    )
    # Reference: run without fused renormalization, then renormalize by hand.
    # (Removed an unused `token_expert_indices_ref` buffer that was
    # allocated but never passed to the kernel.)
    topk_weights_ref = torch.empty(
        (num_tokens, topk), dtype=torch.float32, device="cuda"
    )
    topk_indices_ref = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
    topk_softmax(
        topk_weights_ref,
        topk_indices_ref,
        gating_output,
    )
    topk_weights_ref = topk_weights_ref / topk_weights_ref.sum(dim=-1, keepdim=True)
    # Bug fix: the weights message previously printed the reference
    # *indices* instead of the reference weights.
    assert torch.allclose(
        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
    ), f"Weights mismatch: SGLang w/o fused renormalize={topk_weights_ref} vs SGLang w/ fused renormalize={topk_weights}"
    assert torch.allclose(
        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
    ), f"Indices mismatch: SGLang w/o fused renormalize={topk_indices_ref}, SGLang w/ fused renormalize={topk_indices}"
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,146 @@
import multiprocessing as mp
import os
import socket
import unittest
from enum import IntEnum
from typing import Any
import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
class MscclContextSelection(IntEnum):
    """Algorithm selector passed to mscclpp_init_context.

    Values must stay in sync with the C++ side of sgl_kernel.allreduce.
    """

    # One-shot LL allreduce within a single node.
    MSCCL1SHOT1NODELL = 1
    # One-shot LL allreduce spanning two nodes.
    MSCCL1SHOT2NODELL = 2
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    """Per-rank worker: verify mscclpp allreduce against NCCL all_reduce.

    Initializes a NCCL process group plus an mscclpp context, then for each
    size/dtype combination checks custom_ops.mscclpp_allreduce against
    torch.distributed.all_reduce on identical random inputs.
    """
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD
    # The mscclpp unique id is a CPU object, so broadcast it over gloo.
    cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
    if rank == 0:
        unique_id = [custom_ops.mscclpp_generate_unique_id()]
    else:
        unique_id = [None]
    dist.broadcast_object_list(
        unique_id, src=0, device=torch.device("cpu"), group=cpu_group
    )
    unique_id = unique_id[0]
    # Topology tables: node index and IB device index for every rank,
    # assuming 8 GPUs per node.
    rank_to_node, rank_to_ib = list(range(world_size)), list(range(world_size))
    for r in range(world_size):
        rank_to_node[r] = r // 8
        # Bug fix: use the loop variable `r`, not this process's `rank` —
        # the original filled every entry with the local rank's IB index.
        rank_to_ib[r] = r % 8
    MAX_BYTES = 2**20
    scratch = torch.empty(
        MAX_BYTES * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    put_buffer = torch.empty(
        MAX_BYTES, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    print(f"[{rank}] start mscclpp_context init")
    nranks_per_node = torch.cuda.device_count()
    selection = int(MscclContextSelection.MSCCL1SHOT1NODELL)
    mscclpp_context = custom_ops.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        selection,
    )
    try:
        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                # Skip sizes that exceed the pre-allocated buffers.
                if sz * dtype.itemsize > MAX_BYTES:
                    continue
                if rank == 0:
                    print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
                for _ in range(test_loop):
                    # Small integer values keep the reduction exact in
                    # low-precision dtypes.
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)
                    custom_ops.mscclpp_allreduce(
                        mscclpp_context, inp1, out1, nthreads=512, nblocks=21
                    )
                    dist.all_reduce(inp1_ref, group=group)
                    torch.testing.assert_close(out1, inp1_ref)
    finally:
        # Synchronize before teardown so no rank destroys the group early.
        dist.barrier(group=group)
        dist.destroy_process_group(group=group)
def get_open_port() -> int:
    """Return a free TCP port chosen by the OS (IPv4 first, IPv6 fallback)."""
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", 0))
            _, port = sock.getsockname()
            return port
    except OSError:
        # No IPv4 loopback available; retry on the IPv6 loopback.
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
            sock.bind(("::1", 0))
            return sock.getsockname()[1]
def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    """Spawn one process per rank running `test_target`; assert all exit 0.

    Each worker is invoked as test_target(world_size, rank, port, *target_args)
    where `port` is a freshly allocated rendezvous port.
    """
    mp.set_start_method("spawn", force=True)
    rendezvous_port = get_open_port()
    workers = []
    for rank in range(world_size):
        worker = mp.Process(
            target=test_target,
            args=(world_size, rank, rendezvous_port) + target_args,
            name=f"Worker-{rank}",
        )
        worker.start()
        workers.append(worker)
    for rank, worker in enumerate(workers):
        worker.join()
        assert (
            worker.exitcode == 0
        ), f"Process {rank} failed with exit code {worker.exitcode}"
class TestMSCCLAllReduce(unittest.TestCase):
    """End-to-end correctness test for the mscclpp custom allreduce."""

    # Message sizes (element counts) to exercise.
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
    ]
    # Tensor-parallel world sizes to test; requires that many local GPUs.
    world_sizes = [8]
    def test_correctness(self):
        """Spawn world_size workers and compare against NCCL all_reduce."""
        for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            # Skip (rather than fail) when the machine has too few GPUs.
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}, found {available_gpus} and now ray is not supported here"
                )
                continue
            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"custom allreduce tp = {world_size}: OK")
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,133 @@
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py
import pytest
import sgl_kernel
import torch
def llama_rms_norm(x, w, eps=1e-6):
    """LLaMA-style RMSNorm reference: x / rms(x) * w, math in float32."""
    in_dtype = x.dtype
    xf = x.float()
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = xf * inv_rms * w.float()
    return normed.to(in_dtype)
def gemma_rms_norm(x, w, eps=1e-6):
    """Gemma-style RMSNorm reference: weight applied as (1 + w)."""
    in_dtype = x.dtype
    xf = x.float()
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = xf * inv_rms * (1.0 + w.float())
    return normed.to(in_dtype)
def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
    """Residual add then Gemma RMSNorm; returns (normed, new_residual).

    The residual add happens in the input dtype (matching the original);
    only the normalization itself is done in float32.
    """
    in_dtype = x.dtype
    summed = x + residual  # new residual, kept in the input dtype
    xf = summed.float()
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = (xf * inv_rms * (1.0 + w.float())).to(in_dtype)
    return normed, summed
def fused_add_rms_norm(x, residual, weight, eps):
    """Residual add (in float32) then RMSNorm; returns (out, new_residual)."""
    in_dtype = x.dtype
    acc = x.to(torch.float32) + residual.to(torch.float32)
    new_residual = acc.to(in_dtype)
    inv_rms = torch.rsqrt(acc.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (acc * inv_rms * weight.float()).to(in_dtype)
    return out, new_residual
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
    """sgl_kernel.rmsnorm must match the llama_rms_norm reference."""
    hidden = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    weight = torch.randn(hidden_size).to(0).to(dtype)
    expected = llama_rms_norm(hidden, weight)
    if specify_out:
        # Exercise the pre-allocated-output code path.
        result = torch.empty_like(hidden)
        sgl_kernel.rmsnorm(hidden, weight, out=result)
    else:
        result = sgl_kernel.rmsnorm(hidden, weight)
    torch.testing.assert_close(expected, result, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    """In-place sgl_kernel.fused_add_rmsnorm must match the torch reference."""
    eps = 1e-6
    hidden = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    res = torch.randn_like(hidden)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
    expected_x, expected_res = fused_add_rms_norm(
        hidden.clone(), res.clone(), weight, eps
    )
    # The kernel mutates its inputs in place, so work on copies.
    fused_x = hidden.clone()
    fused_res = res.clone()
    sgl_kernel.fused_add_rmsnorm(fused_x, fused_res, weight, eps)
    torch.testing.assert_close(fused_x, expected_x, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(fused_res, expected_res, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
    """sgl_kernel.gemma_rmsnorm must match the gemma_rms_norm reference."""
    hidden = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    weight = torch.randn(hidden_size).to(0).to(dtype)
    expected = gemma_rms_norm(hidden, weight)
    if specify_out:
        # Exercise the pre-allocated-output code path.
        result = torch.empty_like(hidden)
        sgl_kernel.gemma_rmsnorm(hidden, weight, out=result)
    else:
        result = sgl_kernel.gemma_rmsnorm(hidden, weight)
    torch.testing.assert_close(expected, result, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    """In-place sgl_kernel.gemma_fused_add_rmsnorm must match the reference."""
    eps = 1e-6
    hidden = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    res = torch.randn_like(hidden)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
    expected_x, expected_res = gemma_fused_add_rms_norm(
        hidden.clone(), res.clone(), weight, eps
    )
    # The kernel mutates its inputs in place, so work on copies.
    fused_x = hidden.clone()
    fused_res = res.clone()
    sgl_kernel.gemma_fused_add_rmsnorm(fused_x, fused_res, weight, eps)
    torch.testing.assert_close(fused_x, expected_x, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(fused_res, expected_res, rtol=1e-3, atol=1e-3)
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,67 @@
import itertools
from typing import Optional, Tuple
import pytest
import torch
from sgl_kernel import sgl_per_tensor_quant_fp8
from sglang.srt.utils import is_hip
# Platform-aware fp8 dtype: ROCm builds use the "fnuz" e4m3 variant,
# CUDA builds use the standard OCP float8_e4m3fn format.
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-tensor fp8 quantization via the sgl CUDA kernel.

    If `scale` is provided it is used as a static scale; otherwise the kernel
    computes a dynamic scale and writes it into the returned scale tensor.
    Returns (quantized_output, scale).

    Fix: a local ``fp8_type_`` annotation hard-coded ``float8_e4m3fn``,
    shadowing the HIP-aware module-level ``fp8_type_`` and breaking the
    ROCm (e4m3fnuz) path; the module-level dtype is now used.
    """
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = True
    if scale is None:
        # Dynamic quantization: the kernel fills in the scale.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        is_static = False
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
    return output, scale
def torch_scaled_fp8_quant(tensor, inv_scale):
# The reference implementation that fully aligns to
# the kernel being tested.
finfo = torch.finfo(torch.float8_e4m3fn)
scale = inv_scale.reciprocal()
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
qweight = qweight.to(torch.float8_e4m3fn)
return qweight
@pytest.mark.parametrize(
    "num_tokens,hidden_dim",
    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
)
def test_per_tensor_quant_compare_implementations(
    num_tokens: int,
    hidden_dim: int,
):
    """CUDA per-tensor fp8 quant must match the torch reference for both
    dynamic and static scales."""
    device = torch.device("cuda")
    activations = torch.rand(
        (num_tokens, hidden_dim), dtype=torch.float16, device=device
    )
    # Dynamic scale: the kernel computes the scale itself.
    kernel_out, kernel_scale = sglang_scaled_fp8_quant(activations)
    reference = torch_scaled_fp8_quant(activations, kernel_scale)
    torch.testing.assert_close(
        kernel_out.float(), reference.float(), rtol=1e-3, atol=1e-3
    )
    # Static scale: a user-supplied scale is used as-is.
    static_scale = torch.rand(1, dtype=torch.float32, device=device)
    kernel_out, kernel_scale = sglang_scaled_fp8_quant(activations, static_scale)
    reference = torch_scaled_fp8_quant(activations, static_scale)
    torch.testing.assert_close(
        kernel_out.float(), reference.float(), rtol=1e-3, atol=1e-3
    )
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,97 @@
import itertools
import pytest
import torch
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.layers.quantization.utils import assert_fp8_all_close
from sglang.srt.utils import is_hip
# Platform-aware fp8 dtype: ROCm builds use the "fnuz" e4m3 variant,
# CUDA builds use the standard OCP float8_e4m3fn format.
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
@pytest.mark.parametrize(
    "num_tokens, hidden_dim, group_size, dst_dtype, flags",
    list(
        itertools.product(
            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
            [256, 512, 1024, 2048, 4096],  # hidden_dim
            [8, 16, 32, 64, 128],  # group_size
            # TODO test int8
            [fp8_type_],  # dtype
            # Scale-layout flag combinations, from plain row-major to the
            # TMA-aligned, UE8M0-encoded column-major variant.
            [
                dict(
                    column_major_scales=False,
                    scale_tma_aligned=False,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=False,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=True,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=True,
                    scale_ue8m0=True,
                ),
            ],
        )
    ),
)
def test_per_token_group_quant_with_column_major(
    num_tokens,
    hidden_dim,
    group_size,
    dst_dtype,
    flags,
):
    """Cross-check the sglang per-token-group 8-bit quant against Triton.

    Both implementations receive identical kwargs; quantized values and
    scale tensors must agree for every scale-layout flag combination.
    """
    # scale_ue8m0 requires group_size 128 and hidden_dim divisible by 512.
    if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
        pytest.skip()
        return
    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
        pytest.skip("scale_ue8m0 only supported on Blackwell")
        return
    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
    execute_kwargs = dict(
        x=x,
        group_size=group_size,
        eps=1e-10,
        dst_dtype=dst_dtype,
        **flags,
    )
    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
    # Debugging aids, kept for convenience:
    # torch.set_printoptions(profile="full")
    # print(f"{x_q_triton=}")
    # print(f"{x_s_triton=}")
    # print(f"{x_q_sglang=}")
    # print(f"{x_s_sglang=}")
    # torch.set_printoptions(profile="default")
    assert_fp8_all_close(x_q_triton, x_q_sglang)
    # Scales may come back in different memory layouts; compare contiguous
    # copies.
    torch.testing.assert_close(
        x_s_triton.contiguous(),
        x_s_sglang.contiguous(),
        rtol=1e-3,
        atol=1e-5,
        msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
    )
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,57 @@
import itertools
from typing import Optional, Tuple
import pytest
import torch
from sgl_kernel import sgl_per_token_quant_fp8
from sglang.srt.utils import is_hip
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def torch_per_token_quant_fp8(tensor, inv_scale):
# The reference implementation that fully aligns to
# the kernel being tested.
finfo = torch.finfo(torch.float8_e4m3fn)
inv_scale = inv_scale.view(-1, 1)
scale = inv_scale.reciprocal()
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
qweight = qweight.to(torch.float8_e4m3fn)
return qweight
def sglang_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` per token with the sgl-kernel op.

    Returns the fp8 tensor and a per-token scale of shape ``(num_tokens, 1)``.
    """
    num_tokens = input.size(0)
    per_token_scale = torch.zeros(num_tokens, device=input.device, dtype=torch.float32)
    quantized = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    # The kernel fills both output tensors in place.
    sgl_per_token_quant_fp8(input, quantized, per_token_scale)
    return quantized, per_token_scale.reshape(-1, 1)
@pytest.mark.parametrize(
    "num_tokens,hidden_dim",
    list(itertools.product([128, 256, 512], [512, 1368, 2048, 4096])),
)
def test_per_token_quant_compare_implementations(
    num_tokens: int,
    hidden_dim: int,
):
    """Check the sgl-kernel fp8 quantization against the torch reference,
    feeding the kernel's own scales back into the reference."""
    x = torch.rand(
        (num_tokens, hidden_dim), dtype=torch.float16, device=torch.device("cuda")
    )
    kernel_out, kernel_scale = sglang_per_token_quant_fp8(x)
    reference_out = torch_per_token_quant_fp8(x, kernel_scale)
    torch.testing.assert_close(
        kernel_out.float(), reference_out.float(), rtol=1e-3, atol=1e-3
    )
# Allow running this test file directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,118 @@
import pytest
import torch
from sgl_kernel import qserve_w4a8_per_chn_gemm
# Adapted from https://github.com/mit-han-lab/omniserve/blob/main/omniserve/modeling/layers/quantized_linear/w4a8_linear.py
def convert_to_qserve_format(qweight, scale, zero):
    """Repack per-channel INT4 weights into the qserve kernel layout.

    qweight: (out_features, in_features) uint4-range values in an int tensor.
    scale / zero: per-output-channel quantization scale and zero point.
    Returns (packed_weight, fp16 scale, fp16 scale * zero).
    """
    assert qweight.min() >= 0 and qweight.max() <= 15, "Quantized weight out of range"
    in_features = qweight.shape[1]
    out_features = qweight.shape[0]
    assert in_features % 32 == 0, "Input features must be divisible by 32"
    assert out_features % 32 == 0, "Output features must be divisible by 32"
    # ---- Repack the weight ---- #
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    # NOTE(review): the permutation order mirrors the kernel's tile layout —
    # presumably the tensor-core fragment ordering; confirm against the kernel.
    qweight_unpack_reorder = (
        qweight.reshape(
            out_features // 32,
            2,
            2,
            8,
            in_features // 32,
            2,
            4,
            4,
        )
        .permute(0, 4, 3, 6, 1, 5, 2, 7)
        .contiguous()
    )
    qweight_unpack_reorder = (
        qweight_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7, 4)
        .contiguous()
        .to(torch.int8)
    )
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    # Fuse pairs of 4-bit values into one byte: element 1 in the high nibble,
    # element 0 in the low nibble.
    qweight_unpack_repacked = (
        qweight_unpack_reorder[..., 1] << 4
    ) + qweight_unpack_reorder[..., 0]
    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
        out_features // 32, in_features // 32, 32, 16
    )
    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
        out_features, in_features // 2
    ).contiguous()
    # ---- Pack the scales ---- #
    scale = scale.reshape(out_features).to(torch.float16).contiguous()
    # The kernel consumes zero * scale ("szero") rather than the raw zero point.
    szero = zero.reshape(out_features).to(torch.float16).contiguous() * scale
    return qweight_unpack_repacked, scale, szero
# INT4 Quantization
def asym_quantize_tensor(tensor):
    """Asymmetric INT4 quantization along the last dimension.

    Returns ``(q, scale, zero)``: uint4-range values stored in int8, the fp16
    per-row scale, and the int8 zero point.
    """
    q_min, q_max = 0, 15
    lo = tensor.min(dim=-1, keepdim=True)[0]
    hi = tensor.max(dim=-1, keepdim=True)[0]
    scale = (hi - lo) / (q_max - q_min)
    zero = q_min - torch.round(lo / scale)
    quantized = torch.round(tensor / scale) + zero
    quantized = torch.clamp(quantized, q_min, q_max).to(torch.int8)
    return quantized, scale.to(torch.float16), zero.to(torch.int8)
# INT8 Quantization
def sym_quantize_tensor(tensor):
    """Symmetric INT8 quantization along the last dim; returns (q, fp16 scale)."""
    scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 127
    quantized = torch.round(tensor / scale).clamp(-128, 127).to(torch.int8)
    return quantized, scale.to(torch.float16)
def torch_w4a8_per_chn_gemm(a, b, a_scale, b_scale, b_zero, out_dtype):
print(a.shape)
print(b.shape)
print(b_zero.shape)
o = torch.matmul(
a.to(torch.float16), (b.to(torch.float16) - b_zero.to(torch.float16)).t()
)
o = o * a_scale.view(-1, 1) * b_scale.view(1, -1)
return o.to(out_dtype)
def _test_accuracy_once(M, N, K, out_dtype, device):
    """Run one accuracy comparison of the qserve per-channel kernel against
    the torch reference for a single (M, N, K) shape."""
    # Scale inputs down to avoid accumulator overflow.
    a = 0.01 * torch.randn((M, K), device=device, dtype=torch.float32)
    b = 0.01 * torch.randn((N, K), device=device, dtype=torch.float32)
    a_q, a_scale = sym_quantize_tensor(a)  # symmetric int8 activations
    b_q, b_scale, b_zero = asym_quantize_tensor(b)  # asymmetric int4 weights
    # Repack weights/scales into the layout the kernel expects.
    b_q_format, b_scale_format, b_szero_format = convert_to_qserve_format(
        b_q, b_scale, b_zero
    )
    # Per-row activation sums, used by the kernel's zero-point correction.
    a_sum = a.sum(dim=-1, keepdim=True).to(torch.float16)
    out = qserve_w4a8_per_chn_gemm(
        a_q, b_q_format, b_scale_format, a_scale, b_szero_format, a_sum
    )
    ref_out = torch_w4a8_per_chn_gemm(a_q, b_q, a_scale, b_scale, b_zero, out_dtype)
    torch.testing.assert_close(out, ref_out, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("out_dtype", [torch.float16])
def test_accuracy(M, N, K, out_dtype):
    # Sweep GEMM shapes on CUDA; _test_accuracy_once performs the comparison.
    _test_accuracy_once(M, N, K, out_dtype, "cuda")
# Allow running this test file directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,183 @@
import pytest
import torch
from sgl_kernel import qserve_w4a8_per_group_gemm
# Adapted from https://github.com/mit-han-lab/omniserve/blob/main/omniserve/modeling/layers/quantized_linear/w4a8_linear.py
def convert_to_qserve_format(qweight, chn_scale, scale_i8, zero_i8, group_size):
    """Repack progressive-group INT4 weights into the qserve kernel layout.

    qweight: (out_features, in_features) uint4-range values in an int tensor.
    chn_scale: per-output-channel fp scale; scale_i8 / zero_i8: per-group int8
    scale and zero point. Returns (packed_weight, chn_scale, packed_scale_i8,
    packed_zero_i8 * packed_scale_i8).

    FIX(minor): renamed the misspelled local ``qweigth_unpack_repacked`` and
    added the final ``.contiguous()`` for consistency with the per-channel
    variant of this helper.
    """
    assert qweight.min() >= 0 and qweight.max() <= 15, "Quantized weight out of range"
    in_features = qweight.shape[1]
    out_features = qweight.shape[0]
    assert in_features % 32 == 0, "Input features must be divisible by 32"
    assert out_features % 32 == 0, "Output features must be divisible by 32"
    assert group_size == 128, "Group size must be 128"
    assert (
        in_features % group_size == 0
    ), "Input features must be divisible by group_size"
    # ---- Repack the weight ---- #
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    qweight_unpack_reorder = (
        qweight.reshape(
            out_features // 32,
            2,
            2,
            8,
            in_features // 32,
            2,
            4,
            4,
        )
        .permute(0, 4, 3, 6, 1, 5, 2, 7)
        .contiguous()
    )
    qweight_unpack_reorder = (
        qweight_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7, 4)
        .contiguous()
        .to(torch.int8)
    )
    # [16, 0, 17, 1, ...]: fuse pairs of 4-bit values into one byte,
    # element 1 in the high nibble, element 0 in the low nibble.
    qweight_unpack_repacked = (
        qweight_unpack_reorder[..., 1] << 4
    ) + qweight_unpack_reorder[..., 0]
    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
        out_features // 32, in_features // 32, 32, 16
    )
    qweight_unpack_repacked = qweight_unpack_repacked.reshape(
        out_features, in_features // 2
    ).contiguous()
    # ---- Pack the scales ---- #
    chn_scale = chn_scale.reshape(out_features)
    scale_i8 = (
        scale_i8.reshape(out_features, in_features // group_size)
        .transpose(0, 1)
        .contiguous()
    )
    scale_i8 = scale_i8.reshape(in_features // group_size, out_features // 32, 32)
    scale_i8 = (
        scale_i8.reshape(in_features // group_size, out_features // 32, 4, 8)
        .transpose(-2, -1)
        .contiguous()
    )
    scale_i8 = scale_i8.reshape(in_features // group_size, out_features).contiguous()
    # ---- Pack the zeros ---- #
    zero_i8 = -zero_i8
    # zero_i8 = zero_i8.int() # convert to 2-complement
    zero_i8 = (
        zero_i8.reshape(out_features, in_features // group_size)
        .transpose(0, 1)
        .contiguous()
    )
    zero_i8 = zero_i8.reshape(in_features // group_size, out_features // 32, 32)
    # for the last dimension, organize as 0, 8, 16, 24, 1, 9, 17, 25, ... following the requirement of tensor core gemm
    zero_i8 = (
        zero_i8.reshape(in_features // group_size, out_features // 32, 4, 8)
        .transpose(-2, -1)
        .contiguous()
    )
    # The kernel consumes zero * scale, matching the per-channel "szero".
    zero_i8 = (
        zero_i8.reshape(in_features // group_size, out_features).contiguous() * scale_i8
    )
    return qweight_unpack_repacked, chn_scale, scale_i8, zero_i8
# Progressive Group INT4 Quantization
def progressive_group_quantize_tensor(tensor, group_size):
    """Two-level quantization: per-channel int8 first, then per-group uint4.

    Returns ``(q, chn_scale, scale_i8, zero_i8)``: uint4-range values stored
    in int8, the fp16 per-channel scale, and the int8 per-group scale and
    zero point of the intermediate int8 tensor.
    """
    assert group_size == 128, "Group size must be 128"
    assert (
        tensor.shape[-1] % group_size == 0
    ), "Input features must be divisible by group_size"
    # Channel scale
    # NOTE(HandH1998): use protective quantization range
    chn_scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 119
    tensor_i8 = torch.clamp(torch.round(tensor / chn_scale), -119, 119)
    # Group scale
    tensor_i8 = tensor_i8.reshape(-1, group_size)
    tensor_i8_min = tensor_i8.min(dim=-1, keepdim=True)[0]
    tensor_i8_max = tensor_i8.max(dim=-1, keepdim=True)[0]
    q_min = 0
    q_max = 15
    # NOTE(review): if a group is constant, scale_i8 rounds to 0 and the
    # divisions below produce inf/nan — acceptable only for the random test
    # inputs used here.
    scale_i8 = torch.round((tensor_i8_max - tensor_i8_min) / (q_max - q_min))
    zero_i8 = q_min - torch.round(tensor_i8_min / scale_i8)
    tensor_q = (
        torch.clamp(torch.round(tensor_i8 / scale_i8) + zero_i8, q_min, q_max)
        .reshape(tensor.shape[0], -1)
        .to(torch.int8)
    )
    return (
        tensor_q,
        chn_scale.to(torch.float16),
        scale_i8.reshape(tensor.shape[0], -1).to(torch.int8),
        zero_i8.reshape(tensor.shape[0], -1).to(torch.int8),
    )
# INT8 Quantization
def sym_quantize_tensor(tensor):
    """Symmetric INT8 quantization along the last dim; returns (q, fp16 scale)."""
    scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 127
    quantized = torch.round(tensor / scale).clamp(-128, 127).to(torch.int8)
    return quantized, scale.to(torch.float16)
def torch_w4a8_per_group_gemm(
    a, b, a_scale, b_chn_scale, b_scale_i8, b_zero_i8, group_size, out_dtype
):
    """Reference W4A8 per-group GEMM: dequantize b group-wise, then matmul.

    a: int8 activations (M, K); b: uint4-range weights (N, K) stored as int8;
    b_scale_i8 / b_zero_i8 are the per-group int8 scale and zero point.
    """
    assert group_size == 128, "Group size must be 128"
    # Dequantize the int4 weights one group_size-wide group at a time.
    group_zero = b_zero_i8.reshape(-1, 1).to(torch.float32)
    group_scale = b_scale_i8.reshape(-1, 1).to(torch.float32)
    b_dq = (b.reshape(-1, group_size).to(torch.float32) - group_zero) * group_scale
    b_dq = b_dq.reshape(b.shape[0], b.shape[1])
    out = torch.matmul(a.to(torch.float32), b_dq.t())
    # Apply per-row activation scale and per-column channel scale.
    out = out * a_scale.view(-1, 1) * b_chn_scale.view(1, -1)
    return out.to(out_dtype)
def _test_accuracy_once(M, N, K, group_size, out_dtype, device):
    """Run one accuracy comparison of the qserve per-group kernel against
    the torch reference for a single (M, N, K) shape."""
    # Scale inputs down to avoid accumulator overflow.
    a = 0.01 * torch.randn((M, K), device=device, dtype=torch.float32)
    b = 0.01 * torch.randn((N, K), device=device, dtype=torch.float32)
    a_q, a_scale = sym_quantize_tensor(a)  # symmetric int8 activations
    # Two-level (per-channel + per-group) int4 weight quantization.
    b_q, b_chn_scale, b_scale_i8, b_zero_i8 = progressive_group_quantize_tensor(
        b, group_size
    )
    # Repack weights/scales into the layout the kernel expects.
    b_q_format, b_chn_scale_format, b_scale_i8_format, b_zero_i8_format = (
        convert_to_qserve_format(b_q, b_chn_scale, b_scale_i8, b_zero_i8, group_size)
    )
    out = qserve_w4a8_per_group_gemm(
        a_q,
        b_q_format,
        b_zero_i8_format,
        b_scale_i8_format,
        b_chn_scale_format,
        a_scale,
    )
    ref_out = torch_w4a8_per_group_gemm(
        a_q, b_q, a_scale, b_chn_scale, b_scale_i8, b_zero_i8, group_size, out_dtype
    )
    torch.testing.assert_close(out, ref_out, rtol=1e-3, atol=1e-5)
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("out_dtype", [torch.float16])
def test_accuracy(M, N, K, group_size, out_dtype):
    # Sweep GEMM shapes on CUDA; _test_accuracy_once performs the comparison.
    _test_accuracy_once(M, N, K, group_size, out_dtype, "cuda")
# Allow running this test file directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,138 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import pytest
import torch
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
from sgl_kernel.testing.rotary_embedding import (
FlashInferRotaryEmbedding,
MHATokenToKVPool,
RotaryEmbedding,
create_inputs,
)
@pytest.mark.parametrize(
    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads, save_kv_cache",
    [
        # GPT-OSS cases
        *[
            (
                64,
                64,
                4096,
                8000,
                True,
                torch.bfloat16,
                "cuda",
                batch_size,
                seq_len,
                64,
                8,
                save_kv_cache,
            )
            for batch_size, seq_len in (
                (1, 1),
                (32, 1),
                (128, 1),
                (512, 1),
                (2, 512),
                (4, 4096),
            )
            for save_kv_cache in (False, True)
        ],
        # Other cases
        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1, False),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2, False),
        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False),
        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
    ],
)
def test_correctness(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    device: str,
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    save_kv_cache: bool,
):
    """Compare the fused CUDA RoPE (with optional fused KV-cache write)
    against the native reference implementation.

    BUG FIX: the nonzero-pattern check compared ``x_ref != 0`` with itself
    (``nonzero_flashinfer`` was also computed from ``x_ref``), so the final
    assertion could never fail. It now compares ref vs flashinfer buffers.
    """
    config = dict(
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position_embeddings=max_position_embeddings,
        base=base,
        is_neox_style=is_neox_style,
        dtype=dtype,
    )
    rope_ref = RotaryEmbedding(**config).to(device)
    rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
    inputs = create_inputs(
        head_size=head_size,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        dtype=dtype,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
    )
    if save_kv_cache:
        pool_ref = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
        pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
    # Clone so the reference and fused paths see identical untouched inputs.
    query_ref, key_ref = inputs["query"].clone(), inputs["key"].clone()
    query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()
    query_ref_out, key_ref_out = rope_ref.forward_native(
        inputs["pos_ids"], query_ref, key_ref
    )
    if save_kv_cache:
        # Reference path writes the KV cache explicitly after applying RoPE.
        pool_ref.set_kv_buffer(
            loc=inputs["out_cache_loc"],
            cache_k=key_ref_out.view(-1, num_kv_heads, head_size),
            cache_v=inputs["value"].view(-1, num_kv_heads, head_size),
        )
    # Fused path writes the KV cache inside the RoPE kernel when requested.
    query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
        inputs["pos_ids"],
        query_flashinfer,
        key_flashinfer,
        fused_set_kv_buffer_arg=(
            FusedSetKVBufferArg(
                value=inputs["value"],
                k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
                v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
                k_scale=None,
                v_scale=None,
                cache_loc=inputs["out_cache_loc"],
            )
            if save_kv_cache
            else None
        ),
    )
    torch.testing.assert_close(
        query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)
    if save_kv_cache:
        for field in ["k_buffer", "v_buffer"]:
            x_ref = getattr(pool_ref, field)[0]
            x_flashinfer = getattr(pool_flashinfer, field)[0]
            torch.testing.assert_close(x_ref, x_flashinfer, atol=1e-2, rtol=1e-2)
            # Compare which cache slots were actually written on each path.
            nonzero_ref = x_ref != 0
            nonzero_flashinfer = x_flashinfer != 0
            assert torch.all(nonzero_ref == nonzero_flashinfer)
# Allow running this test file directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,185 @@
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py
import pytest
import sgl_kernel
import torch
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_top_k_first_alignment(
    batch_size, vocab_size, k, p
):
    """Sampling from logits must match sampling from softmax(logits) when both
    start from the same generator state (top_k_first filter order)."""
    torch.manual_seed(42)
    logits = 5 * torch.randn(batch_size, vocab_size, device="cuda:0")
    gen_logits = torch.Generator("cuda:0")
    gen_probs = gen_logits.clone_state()
    from_logits = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="top_k_first", generator=gen_logits
    )
    from_probs = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
        torch.softmax(logits, dim=-1),
        k,
        p,
        filter_apply_order="top_k_first",
        generator=gen_probs,
    )
    assert torch.all(from_logits == from_probs)
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_joint_alignment(
    batch_size, vocab_size, k, p
):
    """Sampling from logits must match sampling from softmax(logits) when both
    start from the same generator state (joint filter order)."""
    torch.manual_seed(42)
    logits = 5 * torch.randn(batch_size, vocab_size, device="cuda:0")
    gen_logits = torch.Generator("cuda:0")
    gen_probs = gen_logits.clone_state()
    from_logits = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="joint", generator=gen_logits
    )
    from_probs = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
        torch.softmax(logits, dim=-1),
        k,
        p,
        filter_apply_order="joint",
        generator=gen_probs,
    )
    assert torch.all(from_logits == from_probs)
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
    """Every sample drawn with the joint top-k/top-p filter must land inside
    the intersection of the top-p mask and the top-k mask."""
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")
    eps = 1e-4
    raw = torch.rand(batch_size, vocab_size, device="cuda:0")
    probs = raw / raw.sum(dim=-1, keepdim=True)
    # top-p mask, built from the ascending cumulative distribution.
    asc_probs, asc_idx = torch.sort(probs, descending=False)
    cdf = torch.cumsum(asc_probs, dim=-1)
    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
    mask_top_p.scatter_add_(1, asc_idx, (cdf > (1 - p) - eps).int())
    # top-k mask, built from the k-th largest probability per row.
    desc_probs, _ = torch.sort(probs, descending=True)
    pivot = desc_probs[:, k - 1]
    mask_top_k = (probs >= pivot.unsqueeze(-1)).int()
    # overall mask: intersection of the two filters.
    mask = torch.minimum(mask_top_p, mask_top_k)
    top_p_tensor = torch.full((batch_size,), p, device="cuda:0")
    top_k_tensor = torch.full((batch_size,), k, device="cuda:0")
    num_trails = 1000
    for _ in range(num_trails):
        samples = sgl_kernel.top_k_top_p_sampling_from_probs(
            probs,
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), probs[
            torch.arange(batch_size), samples
        ]
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
def test_top_p_renorm_probs(batch_size, vocab_size, p):
    """top_p_renorm_prob must zero out the tail below the nucleus and
    renormalize the surviving probabilities to sum to one."""
    torch.manual_seed(42)
    raw = torch.rand(batch_size, vocab_size, device="cuda:0")
    probs = raw / raw.sum(dim=-1, keepdim=True)
    # Keep-mask from the ascending cumulative distribution.
    asc_probs, asc_idx = torch.sort(probs, descending=False)
    cdf = torch.cumsum(asc_probs, dim=-1)
    keep = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
    keep.scatter_add_(1, asc_idx, (cdf >= (1 - p)).int())
    expected = probs.clone()
    expected[keep == 0] = 0
    expected = expected / expected.sum(dim=-1, keepdim=True)
    actual = sgl_kernel.top_p_renorm_prob(probs, p)
    torch.testing.assert_close(
        expected,
        actual,
        rtol=1e-3,
        atol=1e-3,
    )
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_renorm_probs(batch_size, vocab_size, k):
    """top_k_renorm_prob must keep only the k largest probabilities per row,
    renormalized to sum to one."""
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    raw = torch.rand(batch_size, vocab_size, device="cuda:0")
    probs = raw / raw.sum(dim=-1, keepdim=True)
    # Keep-mask from the k-th largest probability per row.
    desc_probs, _ = torch.sort(probs, descending=True)
    pivot = desc_probs[:, k - 1]
    keep = (probs >= pivot.unsqueeze(-1)).int()
    expected = probs.clone()
    expected[keep == 0] = 0
    expected = expected / expected.sum(dim=-1, keepdim=True)
    actual = sgl_kernel.top_k_renorm_prob(probs, k)
    for row in range(batch_size):
        torch.testing.assert_close(
            expected[row],
            actual[row],
            rtol=1e-3,
            atol=1e-3,
        )
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
def test_min_p_sampling(batch_size, vocab_size, p):
    """Every min-p sample must come from a token whose probability is at
    least ``p * max_prob`` of its row.

    FIX(minor): removed an exact duplicate of the final assertion — it was
    repeated verbatim and added no coverage.
    """
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    # scale min-p threshold by each row's top probability
    top_probs = sorted_prob[:, -1].unsqueeze(-1)
    scaled_p = p * top_probs
    # min-p mask: tokens that survive the filter
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
    mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
    min_p_tensor = torch.full((batch_size,), p, device="cuda:0")
    num_trails = 1000
    for _ in range(num_trails):
        samples = sgl_kernel.min_p_sampling_from_probs(
            normalized_prob,
            min_p_tensor,
        )
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
            torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
        ]
# Allow running this test file directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,492 @@
import math
from typing import List, Optional, Tuple
import pytest
import torch
from einops import rearrange, repeat
from sgl_kernel.sparse_flash_attn import (
convert_vertical_slash_indexes,
convert_vertical_slash_indexes_mergehead,
sparse_attn_func,
)
from test_flash_attention import construct_local_mask, is_fa3_supported
def ref_attn(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        lse: (batch_size, nheads, seqlen_q)
    """
    # causal masking is expressed as a local window with 0 right context
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # Expand KV heads to match query heads (grouped-query attention).
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    # NOTE(review): lse is taken over the raw scores, before softcap, padding
    # masks and local-window masks are applied — presumably this matches the
    # kernel's lse for the unmasked cases exercised here; confirm for masked
    # inputs before reusing.
    lse_ref = scores.logsumexp(dim=-1)
    if softcap > 0:
        scores = scores / softcap
        scores = scores.tanh()
        scores = scores * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
        )
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(
            torch.all(local_mask, dim=-1, keepdim=True), 0.0
        )
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
        )
    # Inverted-dropout scaling, folded into v rather than the attention matrix.
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), lse_ref
def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Reference causal attention over a paged (block-table) KV cache.

    ``query`` concatenates every sequence's query tokens along dim 0;
    ``key_cache``/``value_cache`` are (num_blocks, block_size, num_kv_heads,
    head_size); ``block_tables[i]`` lists the cache blocks of sequence ``i``.
    Returns the concatenated per-sequence outputs.
    """
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape
    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
        # clone to avoid clobbering the query tensor
        q = query[start_idx : start_idx + query_len].clone()
        q *= scale
        # Gather this sequence's K/V pages, then trim the last partial block.
        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]
        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]
        # Expand KV heads to match query heads (grouped-query attention).
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        # Causal mask: query token t may attend up to kv position kv_len - query_len + t.
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            # Additionally mask positions older than the sliding window.
            sliding_window_mask = (
                torch.triu(
                    empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
                )
                .bool()
                .logical_not()
            )
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)
        outputs.append(out)
        start_idx += query_len
    return torch.cat(outputs, dim=0)
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize(
    "seq_lens",
    [
        (1, 1),
        (1, 1024),
        (1, 2048),
        (1023, 2049),
        (1023, 1023),
        (32, 32),
        (65, 65),
        (129, 129),
    ],
)
@pytest.mark.parametrize("num_heads", [1, 2, 4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32])
@torch.inference_mode()
def test_sparse_attention(
    batch_size,
    seq_lens,
    num_heads,
    head_size,
    dtype,
    NNZ_S,
) -> None:
    """Sparse attention with NNZ_S dense key blocks plus the remaining keys
    listed as explicit columns covers all of K, so it must match dense
    reference attention.

    FIXES: the trailing ``, f"..."`` after each ``assert_close`` built a
    tuple whose message was never used (``assert_close`` raises on its own);
    the diagnostics are now passed through ``msg=``. The bare ``return`` for
    an infeasible NNZ_S is now a ``pytest.skip`` so it is reported as such.
    """
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    block_size_M = 64
    block_size_N = 64
    seqlen_q, seqlen_k = seq_lens
    q = torch.randn(
        batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    k = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    v = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
    if NNZ_S * block_size_N > seqlen_k:
        pytest.skip("NNZ_S * block_size_N exceeds seqlen_k")
    # All keys not covered by the NNZ_S blocks become explicit columns.
    NNZ_V = seqlen_k - NNZ_S * block_size_N
    block_count = torch.tensor(
        [NNZ_S] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
    ).reshape(batch_size, num_heads, NUM_ROWS)
    column_count = torch.tensor(
        [NNZ_V] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
    ).reshape(batch_size, num_heads, NUM_ROWS)
    block_offset = torch.tensor(
        [[i * block_size_N for i in range(NNZ_S)]] * batch_size * NUM_ROWS * num_heads,
        dtype=torch.int32,
    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
    column_index = torch.tensor(
        [[NNZ_S * block_size_N + i for i in range(NNZ_V)]]
        * batch_size
        * NUM_ROWS
        * num_heads,
        dtype=torch.int32,
    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        return_softmax_lse=True,
    )
    ref_out, ref_lse = ref_attn(q, k, v)
    torch.testing.assert_close(
        out,
        ref_out,
        atol=2e-2,
        rtol=1e-2,
        msg=lambda m: f"{m}\nmax abs diff: {torch.max(torch.abs(out - ref_out))}",
    )
    torch.testing.assert_close(
        lse,
        ref_lse,
        atol=2e-2,
        rtol=1e-2,
        msg=lambda m: f"{m}\nmax abs diff: {torch.max(torch.abs(lse - ref_lse))}",
    )
# sparse attention utils
# origin
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes(causal):
    """Pin the output of convert_vertical_slash_indexes on a tiny
    hand-checkable input (1 batch, 1 head, 4 tokens, 2x2 blocks)."""
    # Prepare small, hand-checkable inputs
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")  # [BATCH]
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [[[1, 3]]], dtype=torch.int32, device="cuda"
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [[[2]]], dtype=torch.int32, device="cuda"
    )  # [BATCH, N_HEADS, NNZ_S]
    context_size = 4
    block_size_M = 2
    block_size_N = 2
    # Call your CUDA kernel wrapper
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            context_size,
            block_size_M,
            block_size_N,
            causal=causal,
        )
    )
    # Manually create expected outputs for this input
    # There are 2 rows (blocks): row0 (tokens 0-1), row1 (tokens 2-3)
    # Fill these expected tensors based on your CUDA kernel's logic
    # For demonstration, we assume:
    # - block_count: how many slash indices fall into each block
    # - block_offset: the value of those indices
    # - column_count: number of valid vertical indices per block
    # - column_index: the actual vertical indices
    expected_column_index = torch.tensor(
        [[[[0, 0], [0, 0]]]], dtype=torch.int32, device="cuda"
    )
    # If causal=False, update these tensors according to expected behavior
    if not causal:
        # Update these values if your kernel produces different output in non-causal mode
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]]]], dtype=torch.int32, device="cuda"
        )
    # Assert that outputs match expectations
    # NOTE(review): only column_index is asserted; block_count, block_offset
    # and column_count are computed but never checked.
    assert torch.equal(column_index, expected_column_index)
# mergehead
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes_mergehead(causal):
    """Pin the output of the merge-head variant on a tiny input with
    per-head index counts (1 batch, 2 heads, 4 tokens, 2x2 blocks)."""
    # Prepare small, hand-checkable inputs for mergehead version
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [
            [
                [1, 3],  # head 0
                [2, 0],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [
            [
                [2, 0],  # head 0
                [1, 3],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_S]
    # Per-head counts of valid vertical / slash indices.
    vertical_indices_count = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
    slash_indices_count = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
    context_size = 4
    block_size_M = 2
    block_size_N = 2
    # Call your CUDA kernel wrapper
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes_mergehead(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            vertical_indices_count,
            slash_indices_count,
            context_size,
            block_size_M,
            block_size_N,
            causal=causal,
        )
    )
    # Manually create expected outputs for this input
    # For demonstration, assume:
    # - batch=1, head=2, num_rows=2, nnz_v=2, nnz_s=2
    # Fill these expected tensors according to your kernel's behavior
    # NOTE(review): the large negative expected values below look like pinned
    # uninitialized/unwritten output slots rather than meaningful indices —
    # confirm they are deterministic and intended before relying on this test.
    expected_column_index = torch.tensor(
        [[[[1, 0], [1, 3]], [[-1079459945, -1077788999], [-1080050043, -1104625879]]]],
        dtype=torch.int32,
        device="cuda",
    )
    if not causal:
        # If non-causal mode output is different, update these values
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]], [[2, -1077788999], [2, -1104625879]]]],
            dtype=torch.int32,
            device="cuda",
        )
    # Assert that outputs match expectations
    # NOTE(review): only column_index is asserted; the other three outputs
    # are computed but never checked.
    assert torch.equal(column_index, expected_column_index)
# Skipped: this variant relies on fa2 for testing, so it is kept commented out.
# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
# [(1024, 1328), (1, 2048)],
# [(1025, 1328), (2, 2048)],
# [(1025, 2049), (2, 1281)],
# ])
# @pytest.mark.parametrize("head_size", [128])
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @torch.inference_mode()
# def test_sparse_attention_varlen(
# seq_lens,
# head_size,
# dtype,
# ) -> None:
# torch.set_default_device("cuda")
# torch.cuda.manual_seed_all(0)
# block_size_M = 64
# block_size_N = 64
# num_seqs = len(seq_lens)
# query_lens = [x[0] for x in seq_lens]
# kv_lens = [x[1] for x in seq_lens]
# num_heads = 1
# query = torch.randn(sum(query_lens),
# num_heads,
# head_size,
# dtype=dtype)
# key = torch.randn(sum(kv_lens),
# num_heads,
# head_size,
# dtype=dtype)
# value = torch.randn_like(key)
# cu_query_lens = torch.tensor([0] + query_lens,
# dtype=torch.int32).cumsum(dim=0,
# dtype=torch.int32)
# cu_kv_lens = torch.tensor([0] + kv_lens,
# dtype=torch.int32).cumsum(dim=0,
# dtype=torch.int32)
# max_query_len = max(query_lens)
# max_kv_len = max(kv_lens)
# NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M
# NNZ_S = 20
# NNZ_V = 2048
# batch_size = len(query_lens)
# block_counts = []
# column_counts = []
# block_offsets = []
# column_indices = []
# for b in range(batch_size):
# block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
# columns = kv_lens[b] - NNZ_S * block_size_N
# column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
# block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S))
# column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V))
# block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS)
# column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS)
# block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
# column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
# out, lse = sparse_attn_varlen_func(
# query,
# key,
# value,
# block_count,
# block_offset,
# column_count,
# column_index,
# cu_seqlens_q=cu_query_lens,
# cu_seqlens_k=cu_kv_lens,
# max_seqlen_q=max_query_len,
# max_seqlen_k=max_kv_len,
# return_softmax_lse=True,
# )
# max_num_blocks_per_seq = (max_kv_len + 2048 - 1) // 2048
# block_tables = torch.randint(0,
# 2048,
# (len(query_lens), max_num_blocks_per_seq),
# dtype=torch.int32)
# scale = head_size**-0.5
# ref_out, ref_lse, _ = ref_paged_attn(
# query,
# key,
# value,
# query_lens=query_lens,
# kv_lens=kv_lens,
# block_tables=block_tables,
# scale=scale
# )
# torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
# f"{torch.max(torch.abs(out - ref_out))}"
# torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
# f"{torch.max(torch.abs(lse - ref_lse))}"
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,9 @@
import torch
def is_sm10x():
    """Return True when the current CUDA device's compute capability is >= 10.0."""
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (10, 0)
def is_hopper():
    """Return True iff the current CUDA device is Hopper (compute capability 9.0)."""
    capability = torch.cuda.get_device_capability()
    return capability == (9, 0)