sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1,55 @@
import itertools
import unittest
import sgl_kernel
import torch
import torch.nn.functional as F
from utils import GeluAndMul, SiluAndMul, precision
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestActivation(CustomTestCase):
    """Compare sgl_kernel CPU activation ops against the reference
    implementations from utils (SiluAndMul / GeluAndMul)."""

    # Row counts include non-power-of-two sizes to hit tail-handling paths.
    M = [128, 129, 257]
    N = [22016, 22018]
    dtype = [torch.float16, torch.bfloat16]

    def _check(self, out, ref_out):
        """Assert kernel output matches the reference within the dtype's tolerance."""
        # Tolerance depends on the reference dtype (fp16 vs bf16).
        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def _silu_and_mul_test(self, m, n, dtype):
        x = torch.randn([m, n], dtype=dtype)
        self._check(torch.ops.sgl_kernel.silu_and_mul_cpu(x), SiluAndMul(x))

    def _gelu_and_mul_test(self, m, n, dtype):
        x = torch.randn([m, n], dtype=dtype)
        self._check(
            torch.ops.sgl_kernel.gelu_and_mul_cpu(x),
            GeluAndMul(x, approximate="none"),
        )

    def _gelu_tanh_and_mul_test(self, m, n, dtype):
        x = torch.randn([m, n], dtype=dtype)
        self._check(
            torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x),
            GeluAndMul(x, approximate="tanh"),
        )

    def test_activation(self):
        for m, n, dtype in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=m, n=n, dtype=dtype):
                self._silu_and_mul_test(m, n, dtype)
                self._gelu_and_mul_test(m, n, dtype)
                self._gelu_tanh_and_mul_test(m, n, dtype)
if __name__ == "__main__":
    # Allow running this test file directly.
    unittest.main()

View File

@@ -0,0 +1,28 @@
import re
import unittest
import sgl_kernel
import torch
kernel = torch.ops.sgl_kernel
from sglang.test.test_utils import CustomTestCase
class TestGemm(CustomTestCase):
    """Check that init_cpu_threads_env binds one OMP thread per requested core."""

    def test_binding(self):
        first_core = 1
        thread_count = 6
        wanted = [str(c) for c in range(first_core, first_core + thread_count)]

        # The kernel returns a log with one "OMP tid: <tid>, core <core>"
        # entry per bound thread.
        report = kernel.init_cpu_threads_env(",".join(wanted))
        bound_cores = re.findall(r"OMP tid: \d+, core (\d+)", report)

        self.assertEqual(len(bound_cores), thread_count)
        self.assertEqual(bound_cores, wanted)
if __name__ == "__main__":
    # Allow running this test file directly.
    unittest.main()

170
test/srt/cpu/test_decode.py Normal file
View File

@@ -0,0 +1,170 @@
import unittest
import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestDecodeAttention(CustomTestCase):
    """Compare sgl_kernel's CPU decode attention against an SDPA reference."""

    def _run_sdpa_forward_decode(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Reference decode attention: one query token per request, attending
        over that request's cached kv tokens via scaled_dot_product_attention.

        Writes results into `output` in place and also returns it.
        """
        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        start_q, start_kv = 0, 0
        for seq_idx in range(seq_lens.shape[0]):
            # During decode every sequence contributes exactly one new query token.
            seq_len_q = 1
            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv

            per_req_query = query[:, start_q:end_q, :]

            # get key and value from cache. per_req_tokens contains the kv cache
            # index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

            per_req_out = (
                scaled_dot_product_attention(
                    per_req_query.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = per_req_out
            start_q, start_kv = end_q, end_kv

        return output

    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
        """Run decode_attention_cpu once and compare with the SDPA reference.

        B: batch size; H_Q/H_KV: query/kv head counts (GQA when they differ);
        D/D_V: query-key and value head dims.
        """
        dtype = torch.bfloat16
        # This represents the number of tokens already in the sequence
        seq_len = 1024
        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0
        num_kv_splits = 8
        enable_gqa = H_Q != H_KV

        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype, device=device)

        # k_buffer and v_buffer represent all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
        v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device)

        # New key/value for the generated token; the kernel writes these into
        # the cache at `loc` itself.
        key = torch.randn(B, H_KV, D, dtype=dtype)
        value = torch.randn(B, H_KV, D_V, dtype=dtype)
        loc = torch.randint(0, 10, (B,)).to(torch.int64)
        # set kv cache
        k_buffer[loc] = key
        v_buffer[loc] = value

        # o will have the same shape as q
        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)

        # Identity layout: request i owns tokens [i*seq_len, (i+1)*seq_len).
        req_to_token = (
            torch.arange(total_tokens, device=device)
            .reshape(B, seq_len)
            .to(torch.int32)
        )
        b_req_idx = torch.arange(B, device=device).to(torch.int64)
        b_seq_len = torch.full((B,), seq_len, device=device).to(torch.int64)

        # Scratch space for the kernel's split-kv partial results.
        attn_logits = torch.empty(
            (B, H_Q, num_kv_splits, D_V + 1),
            dtype=torch.float32,
            device=device,
        )

        # k_buffer, v_buffer, query, key and value supports non-contiguous tensors
        k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        q = q.transpose(0, 1).contiguous().transpose(0, 1)
        key = key.transpose(0, 1).contiguous().transpose(0, 1)
        value = value.transpose(0, 1).contiguous().transpose(0, 1)

        torch.ops.sgl_kernel.decode_attention_cpu(
            q,
            k_buffer,
            v_buffer,
            o,
            key,
            value,
            loc,
            attn_logits,
            req_to_token,
            b_req_idx,
            b_seq_len,
            sm_scale,
            logit_cap,
        )

        self._run_sdpa_forward_decode(
            q,
            o_grouped,
            k_buffer,
            v_buffer,
            req_to_token,
            b_req_idx,
            b_seq_len,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
        )

        # Sanity-check overall direction, then the elementwise tolerance.
        cos_sim = torch.nn.functional.cosine_similarity(
            o.flatten(), o_grouped.flatten(), dim=0
        )
        self.assertGreater(cos_sim.item(), 0.99)
        torch.testing.assert_close(o, o_grouped, atol=3e-2, rtol=1e-6)

    def _test_grouped_decode_attention(self, device="cuda"):
        # (B, H_Q, H_KV, D, D_V) combinations covering MHA, GQA/MQA, odd head
        # dims, and MLA-style 576/512 shapes.
        configs = [
            (2, 16, 16, 64, 64),
            (2, 16, 1, 16, 16),
            (2, 32, 8, 33, 55),
            (2, 16, 1, 64, 64),
            (2, 64, 1, 13, 13),
            (2, 128, 1, 80, 80),
            (2, 128, 2, 512, 512),
            (1, 16, 1, 576, 512),
            (1, 16, 16, 576, 512),
            (1, 22, 1, 576, 512),
            (1, 40, 8, 128, 128),
        ]

        for B, H_Q, H_KV, D, D_V in configs:
            self._test_grouped_decode_attention_once(
                B, H_Q, H_KV, D, D_V, device=device
            )

    def test_grouped_decode_attention(self):
        self._test_grouped_decode_attention("cpu")
if __name__ == "__main__":
    # Allow running this test file directly.
    unittest.main()

190
test/srt/cpu/test_extend.py Normal file
View File

@@ -0,0 +1,190 @@
import unittest
import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestExtendAttention(CustomTestCase):
    """Compare sgl_kernel's CPU extend (prefill) attention against an SDPA reference."""

    def _run_sdpa_forward_extend(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        extend_prefix_lens: torch.Tensor,
        extend_seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Reference extend attention.

        For each request the new-token queries are placed at the tail of a
        kv-length-sized buffer so causal SDPA aligns query positions with
        their absolute positions; only the rows for the extended tokens are
        copied to `output`. Returns `output`.
        """
        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
        assert seq_lens.shape[0] == extend_seq_lens.shape[0]

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        start_q, start_kv = 0, 0
        for seq_idx in range(seq_lens.shape[0]):
            extend_seq_len_q = extend_seq_lens[seq_idx]
            prefill_seq_len_q = extend_prefix_lens[seq_idx]
            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + extend_seq_len_q
            end_kv = start_kv + seq_len_kv

            per_req_query = query[:, start_q:end_q, :]
            # Padded ("redudant") query buffer: rows [0, prefix) are left
            # uninitialized, but their outputs are discarded below.
            per_req_query_redudant = torch.empty(
                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
                dtype=per_req_query.dtype,
                device=per_req_query.device,
            )
            per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query

            # get key and value from cache. per_req_tokens contains the kv cache
            # index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

            per_req_out_redudant = (
                scaled_dot_product_attention(
                    per_req_query_redudant.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            # Keep only the outputs for the extended (new) tokens.
            output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
            start_q, start_kv = end_q, end_kv

        return output

    def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D, DV, mla=False):
        """Build a random batched kv layout, run extend_attention_cpu, and
        compare with the SDPA reference."""
        dtype = torch.bfloat16

        b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
        if mla:
            # MLA path in this test uses no cached prefix: every token is new.
            b_seq_len_prefix.zero_()
        b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
        b_seq_len = b_seq_len_prefix + b_seq_len_extend
        max_len_in_batch = torch.max(b_seq_len, 0)[0].item()

        b_req_idx = torch.arange(B, dtype=torch.int32)
        req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32)
        b_start_loc = torch.zeros((B,), dtype=torch.int32)
        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32)
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

        # Sequences are laid out back-to-back in the token buffers.
        for i in range(B):
            req_to_tokens[i, : b_seq_len[i]] = torch.arange(
                b_start_loc[i], b_start_loc[i] + b_seq_len[i]
            )

        total_token_num = torch.sum(b_seq_len).item()
        extend_token_num = torch.sum(b_seq_len_extend).item()
        # MLA buffers carry a single kv head.
        H_BUF = 1 if mla else H_KV
        k_buffer = torch.randn((total_token_num, H_BUF, D), dtype=dtype)
        v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype)

        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype)
        v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype)
        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype)
        # Copy each sequence's extended slice out of the buffers, and generate
        # fresh queries for those positions.
        for i in range(B):
            extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
            extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
            extend_start = b_start_loc_extend[i]
            extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
            k_extend[extend_start:extend_end] = k_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            v_extend[extend_start:extend_end] = v_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            q_extend[extend_start:extend_end] = torch.randn(
                (b_seq_len_extend[i], H_Q, D), dtype=dtype
            )

        # q_extend, k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
        q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1)
        k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
        v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
        k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)

        b_seq_len_extend = b_seq_len - b_seq_len_prefix
        b_start_loc_extend = torch.zeros_like(b_seq_len)
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()

        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0

        # handle index type
        b_req_idx = b_req_idx.to(torch.int64)
        b_seq_len = b_seq_len.to(torch.int64)
        enable_gqa = H_Q != H_KV

        o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
        self._run_sdpa_forward_extend(
            q_extend,
            o_ref,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_seq_len_prefix,
            b_seq_len_extend,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
            causal=True,
        )

        o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
        torch.ops.sgl_kernel.extend_attention_cpu(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_seq_len_extend,
            b_start_loc_extend,
            max_len_extend,
            sm_scale,
            logit_cap,
        )

        torch.testing.assert_close(o_ref, o_extend, atol=1e-2, rtol=1e-2)

    def test_extend_attention(self):
        # (B, N_CTX, H_Q, H_KV, D, DV), each run with and without MLA layout.
        for is_mla in [True, False]:
            self._test_extend_attention_once(1, 123, 1, 1, 128, 96, is_mla)
            self._test_extend_attention_once(1, 123, 16, 1, 128, 96, is_mla)
            self._test_extend_attention_once(4, 1230, 16, 4, 128, 96, is_mla)
if __name__ == "__main__":
    # Allow running this test file directly.
    unittest.main()

189
test/srt/cpu/test_gemm.py Normal file
View File

@@ -0,0 +1,189 @@
import itertools
import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch
import torch.nn as nn
from utils import (
convert_weight,
native_w8a8_per_token_matmul,
per_token_quant_int8,
precision,
)
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class Mod(nn.Module):
    """Minimal module wrapping one Linear layer; used to source fp8 test weights."""

    def __init__(self, input_channel, output_channel, has_bias):
        super().__init__()
        self.linear = torch.nn.Linear(input_channel, output_channel, has_bias)

    def forward(self, x):
        return self.linear(x)
class TestGemm(CustomTestCase):
    """Compare sgl_kernel CPU GEMM ops (bf16 / int8 / fp8) against references."""

    # bf16 weight_packed_linear shapes.
    M = [1, 101]
    N = [16, 32 * 13]
    K = [32 * 16]
    has_bias = [False, True]

    # int8 w8a8 shapes.
    M_int8 = [2, 128]
    N_int8 = [32 * 12]
    K_int8 = [32 * 17]

    # fp8 block-quantized shapes.
    M_fp8 = [1, 11]
    N_fp8 = [128, 224]
    K_fp8 = [512, 576]

    def _bf16_gemm(self, M, N, K, has_bias):
        """weight_packed_linear must match a float32-accumulated matmul,
        both with the raw weight and with a pre-packed weight."""
        mat1 = torch.randn(M, K, dtype=torch.bfloat16)
        mat2 = torch.randn(N, K, dtype=torch.bfloat16)
        # Reference accumulates in float32, then rounds to bf16.
        ref = torch.matmul(mat1.float(), mat2.float().t())
        if has_bias:
            bias = torch.randn(N, dtype=torch.float32)
            ref.add_(bias.bfloat16())
        ref = ref.bfloat16()
        out = torch.ops.sgl_kernel.weight_packed_linear(
            mat1, mat2, bias if has_bias else None, False
        )
        # Same op with a pre-packed weight (last arg flags the packed layout).
        packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
        out2 = torch.ops.sgl_kernel.weight_packed_linear(
            mat1, packed_mat2, bias if has_bias else None, True
        )
        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, out, atol=atol, rtol=rtol)
        torch.testing.assert_close(ref, out2, atol=atol, rtol=rtol)

    def test_bf16_gemm(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.has_bias,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                has_bias=params[3],
            ):
                self._bf16_gemm(*params)

    def _int8_gemm(self, M, N, K, has_bias):
        """int8_scaled_mm_cpu (and its fused-quant variant) must match the
        native w8a8 per-token reference."""
        dtype = torch.bfloat16
        A = torch.randn((M, K), dtype=dtype) / 10
        Aq, As = per_token_quant_int8(A)

        factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Random int8 weight with per-output-channel scales.
        B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2
        Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
        Bs = torch.rand(N) * factor_for_scale

        bias = torch.randn(N) if has_bias else None
        ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype)
        atol = rtol = precision[ref_out.dtype]

        Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
        out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
            Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
        )
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

        # test the fused version
        fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
            A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
        )
        torch.testing.assert_close(ref_out, fused_out, atol=atol, rtol=rtol)

    def test_int8_gemm(self):
        for params in itertools.product(
            self.M_int8,
            self.N_int8,
            self.K_int8,
            self.has_bias,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                has_bias=params[3],
            ):
                self._int8_gemm(*params)

    def _fp8_gemm(self, M, N, K, has_bias):
        """fp8_scaled_mm_cpu with block-wise scales must match a matmul
        against the dequantized weight."""
        prepack = True
        chunk = False
        scale_block_size_N = 64
        scale_block_size_K = 128
        assert scale_block_size_N <= N
        assert scale_block_size_K <= K
        A_dtype = torch.bfloat16
        model = Mod(K, N, has_bias).eval()
        if chunk:
            # Exercise a non-contiguous activation by slicing a wider buffer.
            data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K)
        else:
            data = torch.randn(M, K, dtype=A_dtype)
        weight = model.linear.weight  # (N, K)
        if has_bias:
            bias = model.linear.bias
        # convert_weight (utils) yields the fp8 weight, its block scales, and
        # a dequantized copy for the reference matmul.
        fp8_weight, scales, dq_weight = convert_weight(
            weight, [scale_block_size_N, scale_block_size_K], A_dtype
        )
        if has_bias:
            ref = torch.matmul(data.to(A_dtype), dq_weight.T) + bias.to(A_dtype)
        else:
            ref = torch.matmul(data.to(A_dtype), dq_weight.T)
        if prepack:
            fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)
        opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
            data,
            fp8_weight,
            scales,
            [scale_block_size_N, scale_block_size_K],
            bias if has_bias else None,
            data.dtype,
            prepack,
        )
        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, opt, atol=atol, rtol=rtol)

    def test_fp8_gemm(self):
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.has_bias,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                has_bias=params[3],
            ):
                self._fp8_gemm(*params)
if __name__ == "__main__":
    # Allow running this test file directly.
    unittest.main()

157
test/srt/cpu/test_mla.py Normal file
View File

@@ -0,0 +1,157 @@
import itertools
import unittest
import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention
from utils import precision
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestMLA(CustomTestCase):
    """Decode attention test for MLA-style caches, where V shares storage with
    the first D_V channels of K."""

    def _run_sdpa_forward_decode(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        key: torch.Tensor,
        loc: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """SDPA decode reference. Also writes `key` into k_cache at `loc`,
        mirroring the cache update the kernel performs on its own copy.
        Returns `output`, filled in place."""
        # set kv cache
        k_cache[loc] = key

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        start_q, start_kv = 0, 0
        for seq_idx in range(seq_lens.shape[0]):
            # One new query token per sequence during decode.
            seq_len_q = 1
            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv

            per_req_query = query[:, start_q:end_q, :]

            # get key and value from cache. per_req_tokens contains the kv cache
            # index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

            per_req_out = (
                scaled_dot_product_attention(
                    per_req_query.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = per_req_out
            start_q, start_kv = end_q, end_kv

        return output

    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len):
        """Run decode_attention_cpu on an MLA layout and check both the
        attention output and the kernel's in-place cache update."""
        dtype = torch.bfloat16
        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0
        num_kv_splits = 8
        enable_gqa = H_Q != H_KV

        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype)

        # k_buffer and v_buffer represent all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype)
        # MLA: v is a narrow view into k (first D_V channels, shared storage).
        v_buffer = k_buffer.narrow(2, 0, D_V)

        key = torch.randn(B, H_KV, D, dtype=dtype)
        value = key.narrow(2, 0, D_V)
        # make sure no duplicates in loc
        loc = torch.randperm(total_tokens)[:B].to(torch.int64)

        # Separate copy for the kernel so the reference cache stays untouched.
        k_buffer2 = k_buffer.clone()
        v_buffer2 = k_buffer2.narrow(2, 0, D_V)

        # o will have the same shape as q
        o = torch.zeros(B, H_Q, D_V, dtype=dtype)
        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype)

        # Identity layout: request i owns tokens [i*seq_len, (i+1)*seq_len).
        req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32)
        b_req_idx = torch.arange(B).to(torch.int64)
        b_seq_len = torch.full((B,), seq_len).to(torch.int64)

        # Scratch space for the kernel's split-kv partial results.
        attn_logits = torch.empty(
            (B, H_Q, num_kv_splits, D_V + 1),
            dtype=torch.float32,
        )

        torch.ops.sgl_kernel.decode_attention_cpu(
            q,
            k_buffer2,
            v_buffer2,
            o,
            key,
            value,
            loc,
            attn_logits,
            req_to_token,
            b_req_idx,
            b_seq_len,
            sm_scale,
            logit_cap,
        )

        self._run_sdpa_forward_decode(
            q,
            o_grouped,
            k_buffer,
            v_buffer,
            key,
            loc,
            req_to_token,
            b_req_idx,
            b_seq_len,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
        )

        cos_sim = torch.nn.functional.cosine_similarity(
            o.flatten(), o_grouped.flatten(), dim=0
        )
        atol = rtol = precision[q.dtype]
        self.assertGreater(cos_sim.item(), 0.99)
        torch.testing.assert_close(o, o_grouped, atol=atol, rtol=rtol)
        # The kernel must also have written key/value into its cache copy.
        torch.testing.assert_close(k_buffer, k_buffer2, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_buffer, v_buffer2, atol=atol, rtol=rtol)

    def test_grouped_decode_attention(self):
        # (B, H_Q, H_KV, D, D_V, seq_len) — MLA 576/512 head dims.
        configs = [
            (1, 22, 1, 576, 512, 8 * 111),
            (4, 22, 1, 576, 512, 8 * 128),
            (40, 22, 1, 576, 512, 8 * 133),
        ]

        for B, H_Q, H_KV, D, D_V, seqlen in configs:
            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen)
if __name__ == "__main__":
    # Allow running this test file directly.
    unittest.main()

265
test/srt/cpu/test_moe.py Normal file
View File

@@ -0,0 +1,265 @@
import itertools
import math
import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch
kernel = torch.ops.sgl_kernel
torch.manual_seed(1234)
from utils import (
BLOCK_K,
BLOCK_N,
factor_for_scale,
fp8_max,
fp8_min,
native_fp8_fused_moe,
precision,
scaled_weight,
torch_naive_fused_moe,
torch_w8a8_per_column_fused_moe,
)
from sglang.test.test_utils import CustomTestCase
def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
    """Run the sgl_kernel fused-MoE path: grouped top-k routing followed by
    fused_experts_cpu.

    a: (M, K) activations; w1: (E, 2N, K) gate/up weights; w2: (E, K, N) down
    weights; score: (M, E) router logits. Returns the MoE output (computed in
    place on `a` since inplace=True).
    """
    # Single routing group, no expert-group restriction.
    G = 1
    topk_group = 1
    # Fix: the previous version pre-allocated topk_weights/topk_ids with
    # torch.empty (and read a.shape) only to discard them here — removed.
    topk_weights, topk_ids = kernel.grouped_topk_cpu(
        a, score, topk, renormalize, G, topk_group, 0, None, None
    )
    packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
    packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
    inplace = True
    return kernel.fused_experts_cpu(
        a,
        packed_w1,
        packed_w2,
        topk_weights,
        topk_ids,
        inplace,
        False,
        False,
        None,
        None,
        None,
        None,
        None,
        prepack,
    )
class TestFusedExperts(CustomTestCase):
    """Compare kernel.fused_experts_cpu (bf16 / int8 / fp8) against native references."""

    # bf16 shapes: M tokens, N intermediate width, K hidden width, E experts.
    M = [2, 114]
    N = [32]
    K = [32]
    E = [4]
    topk = [2]
    renormalize = [False, True]

    # int8 w8a8 shapes.
    M_int8 = [1, 39]
    N_int8 = [128]
    K_int8 = [256]
    E_int8 = [8]
    topk_int8 = [3]

    # fp8 block-quantized shapes.
    M_fp8 = [2, 121]
    N_fp8 = [352, 512]
    K_fp8 = [256, 320]
    E_fp8 = [8]
    topk_fp8 = [4]

    def _bf16_moe(self, m, n, k, e, topk, renormalize):
        dtype = torch.bfloat16
        prepack = True
        a = torch.randn((m, k), device="cpu", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10
        score = torch.randn((m, e), device="cpu", dtype=dtype)

        torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize)
        fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)
        atol = rtol = precision[torch_output.dtype]
        torch.testing.assert_close(torch_output, fused_output, atol=atol, rtol=rtol)

    def test_bf16_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.topk,
            self.renormalize,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                e=params[3],
                topk=params[4],
                renormalize=params[5],
            ):
                self._bf16_moe(*params)

    def _int8_moe(self, M, N, K, E, topk):
        dtype = torch.bfloat16
        prepack = True
        # Initialize int8 quantization parameters
        int8_factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Input tensor
        # M * K
        a = torch.randn((M, K), dtype=dtype) / math.sqrt(K)

        # Generate int8 weights
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
        w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
        w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        # Generate scale for each column (per-column quantization)
        w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale
        w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale

        # Calculate routing
        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        ref_out = torch_w8a8_per_column_fused_moe(
            a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk
        )
        inplace = True
        packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
        packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
        # use_int8_w8a8=True, use_fp8_w8a16=False.
        out = kernel.fused_experts_cpu(
            a,
            packed_w1,
            packed_w2,
            topk_weight,
            topk_ids.to(torch.int32),
            inplace,
            True,
            False,
            w1_s,
            w2_s,
            None,
            None,
            None,
            prepack,
        )
        atol = rtol = precision[ref_out.dtype]
        # Increase the tolerance for large input shapes
        if M > 35:
            atol = rtol = 0.02
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_int8_moe(self):
        for params in itertools.product(
            self.M_int8,
            self.N_int8,
            self.K_int8,
            self.E_int8,
            self.topk_int8,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
            ):
                self._int8_moe(*params)

    def _fp8_moe(self, M, N, K, E, topk):
        dtype = torch.bfloat16
        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        w1_fp32 = torch.randn(E, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
        w2_fp32 = torch.randn(E, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # One scale per (BLOCK_N, BLOCK_K) weight tile.
        w1s = (
            torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
            * factor_for_scale
        )
        w2s = (
            torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
            * factor_for_scale
        )

        # scaled_weight (utils) produces the scaled weights used by the
        # native reference below.
        w1_scaled = scaled_weight(w1, w1s)
        w2_scaled = scaled_weight(w2, w2s)

        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        w1 = kernel.convert_weight_packed(w1)
        w2 = kernel.convert_weight_packed(w2)

        ref_out = native_fp8_fused_moe(
            a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk
        )
        out = kernel.fused_experts_cpu(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids.to(torch.int32),
            False,
            False,
            True,
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            None,
            None,
            True,
        )

        atol = rtol = precision[dtype]
        torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol)

    def test_fp8_moe(self):
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.E_fp8,
            self.topk_fp8,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
            ):
                self._fp8_moe(*params)
if __name__ == "__main__":
    # Allow running this test file directly.
    unittest.main()

90
test/srt/cpu/test_norm.py Normal file
View File

@@ -0,0 +1,90 @@
import itertools
import unittest
from typing import Optional, Tuple, Union
import sgl_kernel
import torch
from utils import make_non_contiguous, precision
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestNorm(CustomTestCase):
    """Compare sgl_kernel CPU RMS-norm ops against a native fp32 reference."""

    M = [4096, 1024]
    N = [4096, 4096 + 13]
    dtype = [torch.float16, torch.bfloat16]

    def _forward_native(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float = 1e-6,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Reference RMSNorm computed in float32, optionally fused with a
        residual add. Returns the normalized tensor, plus the updated
        residual (x + residual, in the original dtype) when one is given."""
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype) * weight
        if residual is None:
            return x
        else:
            return x, residual

    def _norm_test(self, m, n, dtype):
        x = torch.randn([m, n], dtype=dtype)
        # Exercise the kernel's non-contiguous input path.
        x = make_non_contiguous(x)
        hidden_size = x.size(-1)
        weight = torch.randn(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

        out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon)
        ref_out = self._forward_native(x, weight, variance_epsilon)
        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

        # fused_add_rmsnorm_cpu updates x and residual in place; compare both
        # against the native reference.
        ref_x = x.clone()
        residual = torch.randn([m, hidden_size], dtype=dtype)
        ref_residual = residual.clone()
        torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
            x, residual, weight, variance_epsilon
        )
        ref_x, ref_residual = self._forward_native(
            ref_x, weight, variance_epsilon, ref_residual
        )
        torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
        torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)

    def _l2norm_test(self, m, n, dtype):
        x = torch.randn([m, n], dtype=dtype)
        hidden_size = x.size(-1)
        # l2norm equals rmsnorm with a unit weight, so reuse the reference.
        fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

        out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
        ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)
        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_norm(self):
        for params in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=params[0], n=params[1], dtype=params[2]):
                self._norm_test(*params)
                self._l2norm_test(*params)
if __name__ == "__main__":
    # Allow running this test file directly.
    unittest.main()

View File

@@ -0,0 +1,432 @@
import unittest
import sgl_kernel
import torch
from utils import (
convert_weight,
native_w8a8_per_token_matmul,
per_token_quant_int8,
precision,
)
from sglang.srt.layers.rotary_embedding import _apply_rotary_emb
from sglang.test.test_utils import CustomTestCase
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
torch.manual_seed(1234)
# constants
# MLA (multi-head latent attention) projection geometry shared by all tests
# in this file.
kv_lora_rank = 512  # latent kv dimension
qk_head_dim = 192  # = qk_nope_head_dim + qk_rope_head_dim
qk_nope_head_dim = 128  # non-rotary part of each q/k head
qk_rope_head_dim = 64  # rotary part of each q/k head
rotary_dim = qk_rope_head_dim
num_heads = 22
q_lora_rank = 1536  # low-rank q projection width
hidden_size = 7168
B = 1  # batch size (number of tokens)
eps = 1e-6  # rms-norm epsilon
def layernorm(x, weight, variance_epsilon=1e-6, residual=None):
    """Reference norm computed in float32, scaled by `weight`, cast back.

    NOTE(review): despite the name this is RMSNorm (no mean subtraction),
    and the `residual` argument is accepted but ignored.
    """
    input_dtype = x.dtype
    xf = x.to(torch.float32)
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon)
    return (xf * inv_rms * weight).to(input_dtype)
def rotary_emb(q_pe, k_pe, pos, cos_sin_cache):
    """Apply rotary position embedding to the rope slices of q and k.

    `cos_sin_cache[pos]` holds concatenated cos/sin halves; the math runs in
    float32 and the results are cast back to the input dtype.
    """
    dtype_in = q_pe.dtype
    q32 = q_pe.float()
    k32 = k_pe.float()
    cos, sin = cos_sin_cache.float()[pos].chunk(2, dim=-1)
    q_rot = _apply_rotary_emb(q32[..., :rotary_dim], cos, sin, False)
    k_rot = _apply_rotary_emb(k32[..., :rotary_dim], cos, sin, False)
    return q_rot.to(dtype_in), k_rot.to(dtype_in)
def native_torch(
    q_input,
    hidden_states,
    q_a_proj_weight,
    norm_weight1,
    q_b_proj_weight,
    w_kc,
    kv_a_proj_weight,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    """Reference bf16 MLA q/k/v projection with rotary embedding.

    q goes through the low-rank a/b projections; the nope part is absorbed
    into the kv latent space through w_kc, and the rope parts of q/k are
    rotated. Writes into q_input and returns (q_input, k_input, v_input).
    """
    q = torch.matmul(hidden_states, q_a_proj_weight.t())
    q = layernorm(q, norm_weight1)
    q = torch.matmul(q, q_b_proj_weight.t()).view(-1, num_heads, qk_head_dim)
    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    # Absorb the nope part per head: (H, B, D_nope) @ w_kc.
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)
    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
    # kv_a projection yields the latent kv plus the k rope slice.
    latent_cache = torch.matmul(hidden_states, kv_a_proj_weight.t())
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input
    k_pe = k_input[..., kv_lora_rank:]
    q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
    q_input[..., kv_lora_rank:] = q_pe
    k_input[..., kv_lora_rank:] = k_pe
    return q_input, k_input, v_input
def native_torch_int8(
    q_input,
    hidden_states,
    w1_q,
    w1_s,
    norm_weight1,
    w2_q,
    w2_s,
    w_kc,
    w3_q,
    w3_s,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    """int8 (w8a8, per-token activation quant) variant of the MLA reference.

    w1/w2/w3 are the quantized q_a, q_b, and kv_a projection weights with
    their scales. Writes into q_input and returns (q_input, k_input, v_input).
    """
    # q_a projection with dynamic per-token activation quantization.
    a_q, a_s = per_token_quant_int8(hidden_states)
    q = native_w8a8_per_token_matmul(a_q, w1_q, a_s, w1_s, None, torch.bfloat16)
    q = layernorm(q, norm_weight1)
    # q_b projection, re-quantizing the normalized activations.
    a_q, a_s = per_token_quant_int8(q)
    q = native_w8a8_per_token_matmul(a_q, w2_q, a_s, w2_s, None, torch.bfloat16).view(
        -1, num_heads, qk_head_dim
    )
    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    # Absorb the nope part into the kv latent space.
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)
    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
    # kv_a projection on freshly quantized hidden states.
    a_q, a_s = per_token_quant_int8(hidden_states)
    latent_cache = native_w8a8_per_token_matmul(
        a_q, w3_q, a_s, w3_s, None, torch.bfloat16
    )
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input
    k_pe = k_input[..., kv_lora_rank:]
    q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
    q_input[..., kv_lora_rank:] = q_pe
    k_input[..., kv_lora_rank:] = k_pe
    return q_input, k_input, v_input
class TestQKVProjWithROPE(CustomTestCase):
    """Compares the fused qkv_proj_with_rope CPU kernels against eager references.

    Each test builds random MLA-style projection weights, runs the eager
    reference (native_torch / native_torch_int8), then the fused kernel and
    its fused-weight variant, and checks all three agree. Relies on
    module-level shape globals (B, hidden_size, num_heads, q_lora_rank, ...).
    """
    def test_bf16_qkv_proj_with_rope(self):
        """bf16 weights: fused kernel vs eager reference vs fused-weight kernel."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight,
            norm_weight1,
            q_b_proj_weight,
            w_kc.transpose(1, 2),
            kv_a_proj_weight,
            norm_weight2,
            pos,
            cos_sin_cache,
        )
        qa_packed = convert_weight_packed(q_a_proj_weight)
        qb_packed = convert_weight_packed(q_b_proj_weight)
        kva_packed = convert_weight_packed(kv_a_proj_weight)
        wkc_packed = convert_weight_packed(w_kc)
        fused_weight_packed = convert_weight_packed(fused_weight)
        # NOTE(review): the positional False/None arguments presumably disable the
        # int8/fp8 quantized paths here -- confirm against the kernel signature.
        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            qa_packed,
            qb_packed,
            kva_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            False,
            None,
            None,
            None,
            True,
            None,
        )
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            fused_weight_packed,
            qb_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            False,
            None,
            None,
            True,
            None,
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        # The fused-weight path must match the unfused kernel exactly.
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)
    def test_int8_qkv_proj_with_rope(self):
        """int8 W8A8 weights: fused kernel vs the eager int8 reference."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
        # Quantize each projection weight per row (per output channel).
        w1_q, w1_s = per_token_quant_int8(q_a_proj_weight)
        w2_q, w2_s = per_token_quant_int8(q_b_proj_weight)
        w3_q, w3_s = per_token_quant_int8(kv_a_proj_weight)
        q_ref, k_ref, v_ref = native_torch_int8(
            q_input,
            hidden_states,
            w1_q,
            w1_s,
            norm_weight1,
            w2_q,
            w2_s,
            w_kc.transpose(1, 2),
            w3_q,
            w3_s,
            norm_weight2,
            pos,
            cos_sin_cache,
        )
        w1_q_packed = convert_weight_packed(w1_q)
        w2_q_packed = convert_weight_packed(w2_q)
        w3_q_packed = convert_weight_packed(w3_q)
        wkc_packed = convert_weight_packed(w_kc)
        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            w1_q_packed,
            w2_q_packed,
            w3_q_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            True,
            False,
            w1_s,
            w2_s,
            w3_s,
            True,
            None,
        )
        fused_weight = torch.cat([w1_q, w3_q], dim=0)
        fused_weight_s = torch.cat([w1_s, w3_s], dim=0)
        w_fused_q_packed = convert_weight_packed(fused_weight)
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            w_fused_q_packed,
            w2_q_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            True,
            False,
            fused_weight_s,
            w2_s,
            True,
            None,
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)
    def test_fp8_qkv_proj_with_rope(self):
        """Block-wise fp8 weights: fused kernel vs reference on dequantized weights."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
        scale_block_size_N = 128
        scale_block_size_K = 128
        fp8_q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_proj_weight_dq = (
            convert_weight(
                q_a_proj_weight,
                [scale_block_size_N, scale_block_size_K],
                torch.bfloat16,
            )
        )
        fp8_q_b_proj_weight, q_b_proj_weight_scale_inv, q_b_proj_weight_dq = (
            convert_weight(
                q_b_proj_weight,
                [scale_block_size_N, scale_block_size_K],
                torch.bfloat16,
            )
        )
        (
            fp8_kv_a_proj_with_mqa_weight,
            kv_a_proj_with_mqa_weight_scale_inv,
            kv_a_proj_with_mqa_weight_dq,
        ) = convert_weight(
            kv_a_proj_weight, [scale_block_size_N, scale_block_size_K], torch.bfloat16
        )
        # The reference runs on the dequantized (_dq) weights.
        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight_dq,
            norm_weight1,
            q_b_proj_weight_dq,
            w_kc.transpose(1, 2),
            kv_a_proj_with_mqa_weight_dq,
            norm_weight2,
            pos,
            cos_sin_cache,
        )
        fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight)
        fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight)
        fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed(
            fp8_kv_a_proj_with_mqa_weight
        )
        w_kc = convert_weight_packed(w_kc)
        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            fp8_q_a_proj_weight_packed,
            fp8_q_b_proj_weight_packed,
            fp8_kv_a_proj_with_mqa_weight_packed,
            w_kc,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            True,
            q_a_proj_weight_scale_inv.float(),
            q_b_proj_weight_scale_inv.float(),
            kv_a_proj_with_mqa_weight_scale_inv.float(),
            True,
            [scale_block_size_N, scale_block_size_K],
        )
        fused_weight = torch.cat(
            [fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0
        )
        fused_weight_s = torch.cat(
            [q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0
        )
        fused_weight_packed = convert_weight_packed(fused_weight)
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            fused_weight_packed,
            fp8_q_b_proj_weight_packed,
            w_kc,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            True,
            fused_weight_s.float(),
            q_b_proj_weight_scale_inv.float(),
            True,
            [scale_block_size_N, scale_block_size_K],
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        # Due to the change in multiplication order, the error is amplified.
        # In the model, with fewer layers, this doesn't cause issues, but in
        # tests with more layers, we need to enlarge the tolerance to pass the tests.
        torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

178
test/srt/cpu/test_rope.py Normal file
View File

@@ -0,0 +1,178 @@
import unittest
import sgl_kernel
import torch
from utils import precision
from sglang.srt.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding,
RotaryEmbedding,
)
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestROPE(CustomTestCase):
    """Checks the fused rotary_embedding_cpu kernel against the native
    sglang rotary-embedding implementations (DeepSeek-style and standard)."""

    def test_deepseek_v2_rope(self):
        """DeepSeek-V2 partial-head RoPE (non-neox) vs the fused CPU kernel."""
        num_head = 16
        seq_len = 1024
        q_head_dim = 192
        qk_nope_head_dim = 128
        qk_rope_head_dim = 64
        max_pos = 256
        k_dim = 576
        rotary_dim = 64
        is_neox_style = False
        # Create cos_sin_cache
        freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
        cos = freqs.cos() * 0.7
        sin = freqs.sin() * 0.7
        cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
        positions = torch.randint(0, max_pos, (seq_len,))
        rope = DeepseekScalingRotaryEmbedding(
            qk_rope_head_dim,
            rotary_dim,
            max_pos,
            16,  # not used since cos_sin_cache is provided
            is_neox_style,
            1.0,
            torch.bfloat16,
            device="cpu",
        )
        rope.register_buffer("cos_sin_cache", cos_sin_cache)
        for dtype in [torch.bfloat16]:
            enable_autocast = True
            with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast):
                q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
                q_clone = q.clone()
                k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
                k_clone = k.clone()
                _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
                _, q_pe_clone = q_clone.split(
                    [qk_nope_head_dim, qk_rope_head_dim], dim=-1
                )
                k_pe = k[:, :, k_dim - qk_rope_head_dim :]
                k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]
                # ref kernel
                q_pe, k_pe = rope.forward_native(
                    query=q_pe,
                    key=k_pe,
                    positions=positions,
                )
                # fused rope kernel
                q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
                    positions,
                    q_pe_clone,
                    k_pe_clone,
                    rope.head_size,
                    cos_sin_cache,
                    False,
                )
                atol = rtol = precision[q_pe.dtype]
                torch.testing.assert_close(q_pe, q_pe_clone, atol=atol, rtol=rtol)
                # Fix: k_pe was previously asserted twice -- once with the
                # dtype-based tolerance and once more with the (stricter)
                # default tolerances, which made the tolerant check dead code
                # and treated q and k inconsistently. Keep a single check with
                # the same tolerance used for q_pe.
                torch.testing.assert_close(k_pe, k_pe_clone, atol=atol, rtol=rtol)

    def test_origin_rope(self):
        """Standard (neox / non-neox) RoPE vs the fused CPU kernel across shapes."""

        def single_test(
            head_size: int,
            rotary_dim: int,
            max_position_embeddings: int,
            base: int,
            is_neox_style: bool,
            dtype: torch.dtype,
            device: str,
            batch_size: int,
            seq_len: int,
            num_q_heads: int,
            num_kv_heads: int,
        ):
            torch.manual_seed(100)
            rope_ref = RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position_embeddings,
                base,
                is_neox_style,
                dtype,
            ).to(device)
            pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
            query = torch.randn(
                batch_size * seq_len,
                num_q_heads * head_size,
                dtype=dtype,
                device=device,
            )
            key = torch.randn(
                batch_size * seq_len,
                num_kv_heads * head_size,
                dtype=dtype,
                device=device,
            )
            query_ref, key_ref = query.clone(), key.clone()
            query_cpu, key_cpu = query.clone(), key.clone()
            query_ref_out, key_ref_out = rope_ref.forward_native(
                pos_ids, query_ref, key_ref
            )
            query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
                pos_ids,
                query_cpu,
                key_cpu,
                rope_ref.head_size,
                rope_ref.cos_sin_cache.to(query.dtype),
                rope_ref.is_neox_style,
            )
            torch.testing.assert_close(
                query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
            )
            torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)

        # (head_size, rotary_dim, max_pos, base, neox, dtype, device, bs, seq, qh, kvh)
        test_config = [
            (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
            (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
            (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
            (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
        ]
        for (
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
            device,
            batch_size,
            seq_len,
            num_q_heads,
            num_kv_heads,
        ) in test_config:
            single_test(
                head_size,
                rotary_dim,
                max_position_embeddings,
                base,
                is_neox_style,
                dtype,
                device,
                batch_size,
                seq_len,
                num_q_heads,
                num_kv_heads,
            )
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,223 @@
import itertools
import math
import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch
import torch.nn as nn
from utils import (
BLOCK_K,
BLOCK_N,
SiluAndMul,
factor_for_scale,
fp8_max,
fp8_min,
per_token_quant_int8,
precision,
scaled_weight,
torch_naive_moe,
torch_w8a8_per_column_moe,
)
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestSharedExpert(CustomTestCase):
    """Tests shared_expert_cpu (bf16 / int8 W8A8 / block-wise fp8) against
    eager torch references from utils.

    Cleanups vs the original: removed the unused `prepack = True` locals and
    the unused `hidden_states2` clone in the bf16 test (the int8/fp8 tests
    still clone because the fused kernel mutates its input).
    """

    M = [2, 121]
    N = [32, 32 * 4]
    K = [32, 32 * 2]
    routed_scaling_factor = [16]
    M_fp8 = [2, 12]
    N_fp8 = [512]
    K_fp8 = [256]

    def _bf16_shared_expert(self, m, n, k, routed_scaling_factor):
        """bf16 path: kernel output vs a float32 eager reference."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k
        # bfloat16
        ref = torch_naive_moe(
            hidden_states.float(),
            w1.float(),
            w2.float(),
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)
        res = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            w1,
            w2,
            fused_output,
            routed_scaling_factor,
            True,
            False,
            False,
            None,
            None,
            None,
            None,
            None,
            False,
        )
        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, res, atol=atol, rtol=rtol)

    def test_bf16_shared_expert(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                routed_scaling_factor=params[3],
            ):
                self._bf16_shared_expert(*params)

    def _int8_shared_expert(self, m, n, k, routed_scaling_factor):
        """int8 W8A8 path: kernel output vs the per-column-quantized reference."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k
        # fused moe mutates content in hs
        hidden_states2 = hidden_states.clone()
        w1_q, w1_s = per_token_quant_int8(w1)
        w2_q, w2_s = per_token_quant_int8(w2)
        ref2 = torch_w8a8_per_column_moe(
            hidden_states2.float(),
            w1_q,
            w2_q,
            w1_s,
            w2_s,
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)
        res2 = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states2,
            w1_q,
            w2_q,
            fused_output,
            routed_scaling_factor,
            True,
            True,
            False,
            w1_s,
            w2_s,
            None,
            None,
            None,
            False,
        )
        atol = rtol = precision[ref2.dtype]
        torch.testing.assert_close(ref2, res2, atol=atol, rtol=rtol)

    def test_int8_shared_expert(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                routed_scaling_factor=params[3],
            ):
                self._int8_shared_expert(*params)

    def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
        """fp8 path: kernel on packed fp8 weights vs reference on dequantized weights."""
        dtype = torch.bfloat16
        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
        w1_fp32 = torch.randn(1, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
        w2_fp32 = torch.randn(1, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
        w1s = torch.randn(1, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
        w2s = torch.randn(1, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
        w1_scaled = scaled_weight(w1, w1s).view(2 * N, K)
        w2_scaled = scaled_weight(w2, w2s).view(K, N)
        # change back to 2D
        w1, w2 = w1.squeeze(0), w2.squeeze(0)
        w1s, w2s = w1s.squeeze(0), w2s.squeeze(0)
        w1_scaled, w2_scaled = w1_scaled.squeeze(0), w2_scaled.squeeze(0)
        fused_out = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
        a2 = a.clone()
        # ref
        ic0 = torch.matmul(a.float(), w1_scaled.transpose(0, 1))
        ic1 = SiluAndMul(ic0)
        shared_out = torch.matmul(ic1, w2_scaled.transpose(0, 1))
        ref_out = shared_out + fused_out.float() * routed_scaling_factor
        ref_out = ref_out.to(dtype=dtype)
        w1 = torch.ops.sgl_kernel.convert_weight_packed(w1)  # [2N, K]
        w2 = torch.ops.sgl_kernel.convert_weight_packed(w2)  # [K, N]
        out = torch.ops.sgl_kernel.shared_expert_cpu(
            a2,
            w1,
            w2,
            fused_out,
            routed_scaling_factor,
            True,
            False,
            True,
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            None,
            None,
            True,
        )
        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_fp8_shared_expert(self):
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                routed_scaling_factor=params[3],
            ):
                self._fp8_shared_expert(*params)
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

199
test/srt/cpu/test_topk.py Normal file
View File

@@ -0,0 +1,199 @@
import itertools
import unittest
import sgl_kernel
import torch
from utils import precision
from sglang.srt.layers.moe.topk import (
biased_grouped_topk_impl as native_biased_grouped_topk,
)
from sglang.srt.layers.moe.topk import fused_topk_torch_native as native_fused_topk
from sglang.srt.layers.moe.topk import grouped_topk_gpu as native_grouped_topk
from sglang.srt.models.llama4 import Llama4MoE
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
# This is used by the Deepseek-V2 model
class TestGroupedTopK(CustomTestCase):
    """Checks grouped_topk_cpu against the native grouped-topk reference."""
    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        """Compare kernel vs reference by scattering the selected top-k weights
        into dense [M, E] maps (robust to differing tie-break orderings)."""
        torch.manual_seed(1234)
        # expand gating_output by M, otherwise bfloat16 values collapse to the
        # same value after truncation
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        ref_topk_weights, ref_topk_ids = native_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )
        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            G,
            topk_group,
            0,
            None,
            None,
        )
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)
    def test_grouped_topk(self):
        # (M, E, num_groups, topk, topk_group) combinations
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16)
# DeepSeek V2/V3/R1 uses biased_grouped_top
class TestBiasedGroupedTopK(CustomTestCase):
    """Checks biased_grouped_topk_cpu (expert-score correction bias) against
    the native biased grouped-topk reference."""
    def _run_single_test(
        self, M, E, G, topk, topk_group, renormalize, dtype, bias_dtype
    ):
        """Compare kernel vs reference via dense [M, E] scatter maps."""
        torch.manual_seed(1234)
        # expand gating_output by M, otherwise bfloat16 values collapse to the
        # same value after truncation
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        correction_bias = torch.randn(E, dtype=bias_dtype)
        ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            correction_bias.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )
        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            G,
            topk_group,
            0,
            None,
            None,
        )
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)
    def test_biased_grouped_topk(self):
        for renormalize in [True, False]:
            for bias_dtype in [torch.float32, torch.bfloat16]:
                self._run_single_test(
                    122, 256, 8, 8, 2, renormalize, torch.bfloat16, bias_dtype
                )
class TestTopK(CustomTestCase):
    """Checks topk_softmax_cpu against the native fused-topk reference."""

    def _run_single_test(self, M, E, topk, renormalize, dtype):
        """Compare kernel vs reference via dense [M, E] scatter maps."""
        torch.manual_seed(1998)
        # expand gating_output by M, otherwise bfloat16 values collapse to the
        # same value after truncation
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        ref_topk_weights, ref_topk_ids = native_fused_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )
        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
            hidden_states, gating_output, topk, renormalize
        )
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 3, renormalize, torch.bfloat16)
            # Fix: the (123, 32, 3) case was listed twice; run it once.
            self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 6, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 4, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 6, renormalize, torch.bfloat16)
class TestCustomTopK(CustomTestCase):
    """Checks custom routing functions (e.g. Llama4 sigmoid routing) against
    their fused CPU counterparts, pairwise."""
    def _run_single_test(
        self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
    ):
        """Compare a (native, fused) routing-function pair via dense scatter maps."""
        torch.manual_seed(16)
        # expand gating_output by M, otherwise bfloat16 values collapse to the
        # same value after truncation
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        ref_topk_weights, ref_topk_ids = native_custom_f(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )
        # fused version
        topk_weights, topk_ids = fused_custom_f(
            hidden_states, gating_output, topk, renormalize
        )
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)
    def test_custom_topk(self):
        # (native reference, fused kernel) pairs under test
        test_custom_functions = [
            (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
        ]
        for native_custom_f, fused_custom_f in test_custom_functions:
            self._run_single_test(
                123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
            )
            self._run_single_test(
                123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
            )
            self._run_single_test(
                123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
            )
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

274
test/srt/cpu/utils.py Normal file
View File

@@ -0,0 +1,274 @@
import math
import torch
import torch.nn.functional as F
# Default atol/rtol used with torch.testing.assert_close, keyed by dtype.
precision = {
    torch.bfloat16: 1e-2,
    torch.float16: 1e-3,
    torch.float32: 1e-5,
}
# Tile sizes (N x K) for block-wise weight scaling in the fp8 tests.
BLOCK_N, BLOCK_K = 64, 128
# Magnitude of the random per-block scales generated in tests.
factor_for_scale = 1e-3
# Clamp range used to fabricate fp8 test weights (within float8_e4m3fn's ±448 range).
fp8_max, fp8_min = 400, -400
def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
    """Gated activation: split x halfway along the last dim and
    return silu(first_half) * second_half."""
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]
    return F.silu(gate) * up
def GeluAndMul(x: torch.Tensor, approximate="tanh") -> torch.Tensor:
    """Gated activation: split x halfway along the last dim and
    return gelu(first_half, approximate) * second_half."""
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]
    return F.gelu(gate, approximate=approximate) * up
def per_token_quant_int8(x):
    """Symmetric per-token (per-row) int8 quantization.

    Returns (x_q, scale) where x_q is int8 with each row scaled so its
    absolute max maps to 127, and scale is float32 of shape [..., 1]
    such that x ≈ x_q * scale.
    """
    x_fp32 = x.float()
    # Clamp the row max away from zero to avoid dividing by zero.
    row_max = x_fp32.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scale = row_max / 127
    quantized = torch.round(x_fp32 * (127 / row_max)).to(torch.int8)
    return quantized, scale
def convert_weight(weight, scale_block_size, A_dtype):
    """Block-quantize a 2D weight to fp8 (float8_e4m3fn).

    Args:
        weight: [N, K] tensor to quantize.
        scale_block_size: (block_N, block_K) tile size for per-block scales.
        A_dtype: dtype of the dequantized reference weight.

    Returns:
        (q_fp8, scales, w_dq): q_fp8 is the [N, K] fp8 weight, scales is
        [ceil(N/bn), ceil(K/bk)], w_dq is the dequantized weight in A_dtype
        used to build eager references.
    """
    N, K = weight.size()
    fp8_max = 448.0  # max magnitude representable in float8_e4m3fn
    scale_block_size_N, scale_block_size_K = scale_block_size  # (128, 128)
    # Pad N/K up to a multiple of the block size so the blocks tile evenly.
    pad_N = (scale_block_size_N - (N % scale_block_size_N)) % scale_block_size_N
    pad_K = (scale_block_size_K - (K % scale_block_size_K)) % scale_block_size_K
    if pad_N > 0 or pad_K > 0:
        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
    weight_blocks = weight.view(
        math.ceil(N / scale_block_size_N),
        scale_block_size_N,
        math.ceil(K / scale_block_size_K),
        scale_block_size_K,
    )  # (8, 128, 8, 128)
    weight_blocks = weight_blocks.permute(0, 2, 1, 3).contiguous()  # (8, 8, 128, 128)
    # Step 2: compute per-block max abs values → scale
    abs_max = weight_blocks.abs().amax(dim=(-2, -1), keepdim=True)  # (8, 8, 1, 1)
    scales = abs_max / fp8_max
    scales = torch.where(
        scales == 0, torch.ones_like(scales), scales
    )  # avoid division by zero
    q_fp8 = (weight_blocks / scales).to(torch.float8_e4m3fn)
    # Undo the block permutation and crop the padding back off.
    q_fp8_reshape = q_fp8.permute(0, 2, 1, 3).contiguous()
    if pad_N > 0 or pad_K > 0:
        q_fp8_reshape = q_fp8_reshape.view(N + pad_N, K + pad_K)
        q_fp8_reshape = q_fp8_reshape[:N, :K].contiguous()
    else:
        q_fp8_reshape = q_fp8_reshape.view(N, K)
    # Dequantize for the eager reference path.
    dq_weight = q_fp8.float() * scales
    dq_weight = dq_weight.permute(0, 2, 1, 3).contiguous()  # (8, 128, 8, 128)
    if pad_N > 0 or pad_K > 0:
        w_dq = dq_weight.view(N + pad_N, K + pad_K).to(A_dtype)
        w_dq = w_dq[:N, :K].contiguous()
    else:
        w_dq = dq_weight.view(N, K).to(A_dtype)
    scales = scales.view(
        math.ceil(N / scale_block_size_N), math.ceil(K / scale_block_size_K)
    )
    return q_fp8_reshape, scales, w_dq
def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16):
    """Matrix multiplication function that supports per-token input quantization
    and per-column weight quantization.

    Computes (As * (A @ B.T) * Bs) + bias in float32, then casts to output_dtype.
    A is [..., in_features] with per-token scales As; B is [out_features,
    in_features] with per-output-channel scales Bs.
    """
    a_fp = A.to(torch.float32)
    b_fp = B.to(torch.float32)
    assert a_fp.shape[-1] == b_fp.shape[-1], "Dimension mismatch"
    assert b_fp.ndim == 2 and b_fp.is_contiguous(), "B must be a 2D contiguous tensor"

    in_features = a_fp.shape[-1]
    out_features = b_fp.shape[0]
    lead_shape = a_fp.shape[:-1]
    rows = a_fp.numel() // in_features

    # Accumulate in float32, then apply per-token and per-column scales.
    acc = torch.matmul(a_fp.reshape(rows, in_features), b_fp.t())
    acc = As * acc * Bs.view(1, -1)
    if bias is not None:
        acc.add_(bias.view(1, -1))
    return acc.reshape(lead_shape + (out_features,)).to(output_dtype)
def torch_naive_moe(a, w1, w2, b, routed_scaling_factor):
    """Eager reference for the shared expert: a silu-gated MLP
    (w1 gate/up, w2 down) plus the fused routed output b scaled by
    routed_scaling_factor."""
    gate_up = a @ w1.transpose(0, 1)
    hidden = SiluAndMul(gate_up)
    expert_out = hidden @ w2.transpose(0, 1)
    return expert_out + b * routed_scaling_factor
def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor):
    """Int8 W8A8 eager reference for the shared expert: quantized gated MLP
    plus the routed output b scaled by routed_scaling_factor."""
    # First projection: per-token activation quantization, per-row weight scales.
    x_q, x_s = per_token_quant_int8(a)
    gate_up = native_w8a8_per_token_matmul(
        x_q, w1_q, x_s, w1_s, bias=None, output_dtype=torch.float32
    )
    activated = SiluAndMul(gate_up)
    # Second projection: requantize the activation output per token.
    y_q, y_s = per_token_quant_int8(activated)
    down = native_w8a8_per_token_matmul(
        y_q, w2_q, y_s, w2_s, bias=None, output_dtype=torch.float32
    )
    return down + b * routed_scaling_factor
def scaled_weight(weight, scales):
    """Dequantize block-quantized expert weights.

    Args:
        weight: [E, N, K] quantized weights (e.g. fp8).
        scales: per-block scales, broadcastable to
            [E, ceil(N/BLOCK_N), ceil(K/BLOCK_K)].

    Returns:
        [E, N, K] float tensor with every (BLOCK_N x BLOCK_K) tile multiplied
        by its block scale. Pads up to the block grid, scales, then crops.
    """
    E, N, K = weight.shape
    pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
    pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K
    if pad_N > 0 or pad_K > 0:
        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
    # Regroup into [E, n_blocks_N, n_blocks_K, BLOCK_N, BLOCK_K] tiles.
    weight_block = (
        weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
        .permute(0, 1, 3, 2, 4)
        .float()
        .contiguous()
    )
    weight_scaled = (
        (
            weight_block
            * scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1)
        )
        .permute(0, 1, 3, 2, 4)
        .contiguous()
    )
    if pad_N > 0 or pad_K > 0:
        weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K)
        weight_scaled = weight_scaled[..., :N, :K].contiguous()
    else:
        weight_scaled = weight_scaled.view(E, N, K)
    return weight_scaled
def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
    """Eager reference for fused MoE with softmax top-k routing.

    Replicates each token topk times, dispatches replicas to their experts
    (silu-gated MLP per expert), and recombines with the routing weights.
    """
    num_tokens, dim = a.shape
    replicated = a.view(num_tokens, -1, dim).repeat(1, topk, 1).reshape(-1, dim)
    out = torch.zeros(num_tokens * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    # Softmax routing, optionally renormalizing the selected top-k weights.
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    routing_weights, routing_ids = torch.topk(probs, topk)
    if renormalize:
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.view(-1)
    routing_ids = routing_ids.view(-1)

    for expert_id in range(w1.shape[0]):
        selected = routing_ids == expert_id
        if selected.sum():
            gated = SiluAndMul(replicated[selected] @ w1[expert_id].transpose(0, 1))
            out[selected] = gated @ w2[expert_id].transpose(0, 1)

    weighted = out.view(num_tokens, -1, w2.shape[1]) * routing_weights.view(
        num_tokens, -1, 1
    ).to(out.dtype)
    return weighted.sum(dim=1)
def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk):
    """This function performs fused moe with per-column int8 quantization using native torch.

    Args:
        a: [B, D] activations.
        w1, w2: int8 expert weight stacks; w1_s, w2_s their per-row scales.
        topk_weight, topk_ids: routing weights and expert ids per token.
        topk: number of experts each token is routed to.

    Returns a [B, w2.shape[1]] tensor in a's dtype.
    """
    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]
    out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)
    # Calculate routing
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(
                a_q[mask],
                w1[i],
                a_s[mask],
                w1_s[i],
                bias=None,
                output_dtype=torch.float32,
            )
            # Activation function
            act_out = SiluAndMul(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = per_token_quant_int8(act_out)
            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q,
                w2[i],
                act_out_s,
                w2_s[i],
                bias=None,
                output_dtype=torch.float32,
            )
    # Apply routing weights and sum
    return (
        (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
        .sum(dim=1)
        .to(a.dtype)
    )
def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
    """Eager MoE reference over (already dequantized) float expert weights.

    Routing weights/ids are supplied by the caller; tokens are replicated topk
    times, sent through their experts' silu-gated MLPs, and recombined.
    Computes and returns in float32 (replicated tokens are cast to float).
    """
    num_tokens, dim = a.shape
    tokens = a.view(num_tokens, -1, dim).repeat(1, topk, 1).reshape(-1, dim).float()
    out = torch.zeros(
        num_tokens * topk, w2.shape[1], dtype=torch.float32, device=a.device
    )

    flat_weights = topk_weight.view(-1)
    flat_ids = topk_ids.view(-1)

    for expert in range(w1.shape[0]):
        chosen = flat_ids == expert
        if chosen.sum():
            up = torch.matmul(tokens[chosen], w1[expert].transpose(0, 1))
            out[chosen] = torch.matmul(SiluAndMul(up), w2[expert].transpose(0, 1))

    combined = out.view(num_tokens, -1, w2.shape[1]) * flat_weights.view(
        num_tokens, -1, 1
    ).to(out.dtype)
    # Original code cast to the (floated) replicated tokens' dtype, i.e. float32.
    return combined.sum(dim=1).to(tokens.dtype)
def make_non_contiguous(x: torch.Tensor) -> torch.Tensor:
    """
    Make a tensor non-contiguous by slicing it via last dimension.

    Contiguous inputs are narrowed to the first half of the last dimension
    (a view, hence non-contiguous for multi-dim tensors); non-contiguous
    inputs are returned unchanged.
    """
    if not x.is_contiguous():
        return x
    return x.narrow(-1, 0, x.shape[-1] // 2)