sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
55
test/srt/cpu/test_activation.py
Normal file
55
test/srt/cpu/test_activation.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from utils import GeluAndMul, SiluAndMul, precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestActivation(CustomTestCase):
    """Compare the sgl_kernel CPU activation-and-mul ops against native references."""

    # problem sizes: rows (including non-power-of-two), columns, and dtypes
    M = [128, 129, 257]
    N = [22016, 22018]
    dtype = [torch.float16, torch.bfloat16]

    def _assert_matches(self, ref, out):
        # tolerance is dtype-dependent (looser for half-precision types)
        tol = precision[ref.dtype]
        torch.testing.assert_close(ref, out, atol=tol, rtol=tol)

    def _silu_and_mul_test(self, m, n, dtype):
        inp = torch.randn([m, n], dtype=dtype)
        self._assert_matches(
            SiluAndMul(inp), torch.ops.sgl_kernel.silu_and_mul_cpu(inp)
        )

    def _gelu_and_mul_test(self, m, n, dtype):
        inp = torch.randn([m, n], dtype=dtype)
        self._assert_matches(
            GeluAndMul(inp, approximate="none"),
            torch.ops.sgl_kernel.gelu_and_mul_cpu(inp),
        )

    def _gelu_tanh_and_mul_test(self, m, n, dtype):
        inp = torch.randn([m, n], dtype=dtype)
        self._assert_matches(
            GeluAndMul(inp, approximate="tanh"),
            torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(inp),
        )

    def test_activation(self):
        for m, n, dt in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=m, n=n, dtype=dt):
                self._silu_and_mul_test(m, n, dt)
                self._gelu_and_mul_test(m, n, dt)
                self._gelu_tanh_and_mul_test(m, n, dt)


if __name__ == "__main__":
    unittest.main()
|
||||
28
test/srt/cpu/test_binding.py
Normal file
28
test/srt/cpu/test_binding.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import re
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
|
||||
kernel = torch.ops.sgl_kernel
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestGemm(CustomTestCase):
    """Verify the OpenMP thread-to-core binding performed by init_cpu_threads_env."""

    def test_binding(self):
        first_core = 1
        num_threads = 6

        wanted = [str(c) for c in range(first_core, first_core + num_threads)]
        report = kernel.init_cpu_threads_env(",".join(wanted))

        # the kernel reports one "OMP tid: X, core Y" line per bound thread
        bound = re.findall(r"OMP tid: \d+, core (\d+)", report)
        self.assertEqual(len(bound), num_threads)

        self.assertEqual(bound, wanted)


if __name__ == "__main__":
    unittest.main()
|
||||
170
test/srt/cpu/test_decode.py
Normal file
170
test/srt/cpu/test_decode.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestDecodeAttention(CustomTestCase):
    """Compare decode_attention_cpu against a per-request SDPA reference."""

    def _run_sdpa_forward_decode(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Reference decode attention: one new query token per request."""
        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        q_off, kv_off = 0, 0
        for idx in range(seq_lens.shape[0]):
            kv_len = seq_lens[idx]
            q_end = q_off + 1  # decode: a single query token per sequence
            kv_end = kv_off + kv_len

            req_q = query[:, q_off:q_end, :]

            # Gather this request's K/V rows from the cache via its
            # per-request token-index table.
            pool_idx = req_pool_indices[idx]
            token_ids = req_to_token[pool_idx, :kv_len]
            req_k = k_cache[token_ids].movedim(0, query.dim() - 2)
            req_v = v_cache[token_ids].movedim(0, query.dim() - 2)

            attn = scaled_dot_product_attention(
                req_q.unsqueeze(0),
                req_k.unsqueeze(0),
                req_v.unsqueeze(0),
                enable_gqa=enable_gqa,
                scale=scaling,
                is_causal=causal,
            )
            output[q_off:q_end, :, :] = attn.squeeze(0).movedim(query.dim() - 2, 0)
            q_off, kv_off = q_end, kv_end

        return output

    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
        dtype = torch.bfloat16
        # number of tokens already present per sequence
        seq_len = 1024
        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0
        num_kv_splits = 8
        enable_gqa = H_Q != H_KV

        # q is the one freshly generated token per batch entry
        q = torch.randn(B, H_Q, D, dtype=dtype, device=device)

        # caches holding all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
        v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device)

        key = torch.randn(B, H_KV, D, dtype=dtype)
        value = torch.randn(B, H_KV, D_V, dtype=dtype)
        loc = torch.randint(0, 10, (B,)).to(torch.int64)

        # write the new K/V into the caches for the reference path
        k_buffer[loc] = key
        v_buffer[loc] = value

        # outputs have the same shape as q (with the value head dim)
        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)

        req_to_token = (
            torch.arange(total_tokens, device=device)
            .reshape(B, seq_len)
            .to(torch.int32)
        )
        b_req_idx = torch.arange(B, device=device).to(torch.int64)
        b_seq_len = torch.full((B,), seq_len, device=device).to(torch.int64)

        attn_logits = torch.empty(
            (B, H_Q, num_kv_splits, D_V + 1),
            dtype=torch.float32,
            device=device,
        )

        # the kernel accepts non-contiguous caches/q/k/v; exercise that path
        k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        q = q.transpose(0, 1).contiguous().transpose(0, 1)
        key = key.transpose(0, 1).contiguous().transpose(0, 1)
        value = value.transpose(0, 1).contiguous().transpose(0, 1)
        torch.ops.sgl_kernel.decode_attention_cpu(
            q,
            k_buffer,
            v_buffer,
            o,
            key,
            value,
            loc,
            attn_logits,
            req_to_token,
            b_req_idx,
            b_seq_len,
            sm_scale,
            logit_cap,
        )

        self._run_sdpa_forward_decode(
            q,
            o_grouped,
            k_buffer,
            v_buffer,
            req_to_token,
            b_req_idx,
            b_seq_len,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
        )

        cos_sim = torch.nn.functional.cosine_similarity(
            o.flatten(), o_grouped.flatten(), dim=0
        )
        self.assertGreater(cos_sim.item(), 0.99)
        torch.testing.assert_close(o, o_grouped, atol=3e-2, rtol=1e-6)

    def _test_grouped_decode_attention(self, device="cuda"):
        # (B, H_Q, H_KV, D, D_V) combinations covering MHA, GQA and MQA shapes
        configs = [
            (2, 16, 16, 64, 64),
            (2, 16, 1, 16, 16),
            (2, 32, 8, 33, 55),
            (2, 16, 1, 64, 64),
            (2, 64, 1, 13, 13),
            (2, 128, 1, 80, 80),
            (2, 128, 2, 512, 512),
            (1, 16, 1, 576, 512),
            (1, 16, 16, 576, 512),
            (1, 22, 1, 576, 512),
            (1, 40, 8, 128, 128),
        ]

        for B, H_Q, H_KV, D, D_V in configs:
            self._test_grouped_decode_attention_once(
                B, H_Q, H_KV, D, D_V, device=device
            )

    def test_grouped_decode_attention(self):
        self._test_grouped_decode_attention("cpu")


if __name__ == "__main__":
    unittest.main()
|
||||
190
test/srt/cpu/test_extend.py
Normal file
190
test/srt/cpu/test_extend.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestExtendAttention(CustomTestCase):
    """Compare extend_attention_cpu against a per-request SDPA reference."""

    def _run_sdpa_forward_extend(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        extend_prefix_lens: torch.Tensor,
        extend_seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Reference extend (prefill-with-prefix) attention, one request at a time."""
        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
        assert seq_lens.shape[0] == extend_seq_lens.shape[0]

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        q_off, kv_off = 0, 0
        for idx in range(seq_lens.shape[0]):
            new_len = extend_seq_lens[idx]
            prefix_len = extend_prefix_lens[idx]

            kv_len = seq_lens[idx]
            q_end = q_off + new_len
            kv_end = kv_off + kv_len

            req_q = query[:, q_off:q_end, :]
            # Left-pad the query with `prefix_len` placeholder rows so the
            # causal mask over the full sequence lines up with the KV cache;
            # the padded rows' outputs are discarded below.
            req_q_padded = torch.empty(
                (req_q.shape[0], kv_len, req_q.shape[2]),
                dtype=req_q.dtype,
                device=req_q.device,
            )
            req_q_padded[:, prefix_len:, :] = req_q

            # Gather this request's K/V rows from the cache via its
            # per-request token-index table.
            pool_idx = req_pool_indices[idx]
            token_ids = req_to_token[pool_idx, :kv_len]
            req_k = k_cache[token_ids].movedim(0, query.dim() - 2)
            req_v = v_cache[token_ids].movedim(0, query.dim() - 2)

            attn = (
                scaled_dot_product_attention(
                    req_q_padded.unsqueeze(0),
                    req_k.unsqueeze(0),
                    req_v.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            # keep only the rows belonging to the newly extended tokens
            output[q_off:q_end, :, :] = attn[prefix_len:, :, :]
            q_off, kv_off = q_end, kv_end
        return output

    def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D, DV, mla=False):
        dtype = torch.bfloat16

        b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
        if mla:
            # the MLA path runs with no cached prefix (pure prefill)
            b_seq_len_prefix.zero_()
        b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
        b_seq_len = b_seq_len_prefix + b_seq_len_extend
        max_len_in_batch = torch.max(b_seq_len, 0)[0].item()

        b_req_idx = torch.arange(B, dtype=torch.int32)
        req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32)
        b_start_loc = torch.zeros((B,), dtype=torch.int32)
        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32)
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

        for i in range(B):
            req_to_tokens[i, : b_seq_len[i]] = torch.arange(
                b_start_loc[i], b_start_loc[i] + b_seq_len[i]
            )

        total_token_num = torch.sum(b_seq_len).item()
        extend_token_num = torch.sum(b_seq_len_extend).item()

        # MLA uses a single shared KV buffer head
        H_BUF = 1 if mla else H_KV
        k_buffer = torch.randn((total_token_num, H_BUF, D), dtype=dtype)
        v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype)

        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype)
        v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype)
        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype)

        for i in range(B):
            buf_lo = b_start_loc[i] + b_seq_len_prefix[i]
            buf_hi = b_start_loc[i] + b_seq_len[i]
            ext_lo = b_start_loc_extend[i]
            ext_hi = b_start_loc_extend[i] + b_seq_len_extend[i]
            k_extend[ext_lo:ext_hi] = k_buffer[buf_lo:buf_hi]
            v_extend[ext_lo:ext_hi] = v_buffer[buf_lo:buf_hi]
            q_extend[ext_lo:ext_hi] = torch.randn(
                (b_seq_len_extend[i], H_Q, D), dtype=dtype
            )

        # the kernel accepts non-contiguous q/k/v and caches; exercise that path
        q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1)
        k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
        v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
        k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)

        b_seq_len_extend = b_seq_len - b_seq_len_prefix
        b_start_loc_extend = torch.zeros_like(b_seq_len)
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()

        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0

        # the kernel expects 64-bit index tensors
        b_req_idx = b_req_idx.to(torch.int64)
        b_seq_len = b_seq_len.to(torch.int64)

        enable_gqa = H_Q != H_KV
        o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
        self._run_sdpa_forward_extend(
            q_extend,
            o_ref,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_seq_len_prefix,
            b_seq_len_extend,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
            causal=True,
        )

        o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
        torch.ops.sgl_kernel.extend_attention_cpu(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_seq_len_extend,
            b_start_loc_extend,
            max_len_extend,
            sm_scale,
            logit_cap,
        )

        torch.testing.assert_close(o_ref, o_extend, atol=1e-2, rtol=1e-2)

    def test_extend_attention(self):
        for is_mla in [True, False]:
            self._test_extend_attention_once(1, 123, 1, 1, 128, 96, is_mla)
            self._test_extend_attention_once(1, 123, 16, 1, 128, 96, is_mla)
            self._test_extend_attention_once(4, 1230, 16, 4, 128, 96, is_mla)


if __name__ == "__main__":
    unittest.main()
|
||||
189
test/srt/cpu/test_gemm.py
Normal file
189
test/srt/cpu/test_gemm.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from utils import (
|
||||
convert_weight,
|
||||
native_w8a8_per_token_matmul,
|
||||
per_token_quant_int8,
|
||||
precision,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class Mod(nn.Module):
    """Thin module wrapping a single ``Linear`` layer (optionally biased)."""

    def __init__(self, input_channel, output_channel, has_bias):
        super().__init__()
        self.linear = nn.Linear(input_channel, output_channel, has_bias)

    def forward(self, x):
        return self.linear(x)
|
||||
|
||||
|
||||
class TestGemm(CustomTestCase):
    """Compare the sgl_kernel CPU GEMM paths (bf16 / int8 / fp8) with references."""

    M = [1, 101]
    N = [16, 32 * 13]
    K = [32 * 16]
    has_bias = [False, True]

    M_int8 = [2, 128]
    N_int8 = [32 * 12]
    K_int8 = [32 * 17]

    M_fp8 = [1, 11]
    N_fp8 = [128, 224]
    K_fp8 = [512, 576]

    def _bf16_gemm(self, M, N, K, has_bias):
        a = torch.randn(M, K, dtype=torch.bfloat16)
        b = torch.randn(N, K, dtype=torch.bfloat16)

        # fp32 reference matmul; bias is added in bf16 to mirror the kernel
        ref = torch.matmul(a.float(), b.float().t())
        if has_bias:
            bias = torch.randn(N, dtype=torch.float32)
            ref.add_(bias.bfloat16())

        ref = ref.bfloat16()

        out = torch.ops.sgl_kernel.weight_packed_linear(
            a, b, bias if has_bias else None, False
        )

        # same kernel through the pre-packed weight path
        packed = torch.ops.sgl_kernel.convert_weight_packed(b)
        out2 = torch.ops.sgl_kernel.weight_packed_linear(
            a, packed, bias if has_bias else None, True
        )

        tol = precision[ref.dtype]
        torch.testing.assert_close(ref, out, atol=tol, rtol=tol)
        torch.testing.assert_close(ref, out2, atol=tol, rtol=tol)

    def test_bf16_gemm(self):
        for m, n, k, bias in itertools.product(
            self.M, self.N, self.K, self.has_bias
        ):
            with self.subTest(M=m, N=n, K=k, has_bias=bias):
                self._bf16_gemm(m, n, k, bias)

    def _int8_gemm(self, M, N, K, has_bias):
        dtype = torch.bfloat16
        A = torch.randn((M, K), dtype=dtype) / 10
        Aq, As = per_token_quant_int8(A)

        factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2
        Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
        Bs = torch.rand(N) * factor_for_scale

        bias = torch.randn(N) if has_bias else None
        ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype)

        tol = precision[ref_out.dtype]

        Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
        out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
            Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
        )
        torch.testing.assert_close(ref_out, out, atol=tol, rtol=tol)

        # fused quantize-then-matmul variant
        fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
            A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
        )
        torch.testing.assert_close(ref_out, fused_out, atol=tol, rtol=tol)

    def test_int8_gemm(self):
        for m, n, k, bias in itertools.product(
            self.M_int8, self.N_int8, self.K_int8, self.has_bias
        ):
            with self.subTest(M=m, N=n, K=k, has_bias=bias):
                self._int8_gemm(m, n, k, bias)

    def _fp8_gemm(self, M, N, K, has_bias):
        prepack = True
        chunk = False
        scale_block_size_N = 64
        scale_block_size_K = 128
        assert scale_block_size_N <= N
        assert scale_block_size_K <= K
        A_dtype = torch.bfloat16

        model = Mod(K, N, has_bias).eval()
        if chunk:
            # non-contiguous activation: carve K columns out of a wider buffer
            data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K)
        else:
            data = torch.randn(M, K, dtype=A_dtype)

        weight = model.linear.weight  # (N, K)

        if has_bias:
            bias = model.linear.bias

        fp8_weight, scales, dq_weight = convert_weight(
            weight, [scale_block_size_N, scale_block_size_K], A_dtype
        )

        if has_bias:
            ref = torch.matmul(data.to(A_dtype), dq_weight.T) + bias.to(A_dtype)
        else:
            ref = torch.matmul(data.to(A_dtype), dq_weight.T)

        if prepack:
            fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)

        opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
            data,
            fp8_weight,
            scales,
            [scale_block_size_N, scale_block_size_K],
            bias if has_bias else None,
            data.dtype,
            prepack,
        )
        tol = precision[ref.dtype]
        torch.testing.assert_close(ref, opt, atol=tol, rtol=tol)

    def test_fp8_gemm(self):
        for m, n, k, bias in itertools.product(
            self.M_fp8, self.N_fp8, self.K_fp8, self.has_bias
        ):
            with self.subTest(M=m, N=n, K=k, has_bias=bias):
                self._fp8_gemm(m, n, k, bias)


if __name__ == "__main__":
    unittest.main()
|
||||
157
test/srt/cpu/test_mla.py
Normal file
157
test/srt/cpu/test_mla.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from utils import precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestMLA(CustomTestCase):
    """Decode attention on MLA-shaped caches, where V is a narrow view of K."""

    def _run_sdpa_forward_decode(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        key: torch.Tensor,
        loc: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Reference decode attention that also scatters `key` into the cache."""
        # store the freshly generated K rows (V aliases K, so it updates too)
        k_cache[loc] = key

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        q_off, kv_off = 0, 0
        for idx in range(seq_lens.shape[0]):
            kv_len = seq_lens[idx]
            q_end = q_off + 1  # decode: one query token per sequence
            kv_end = kv_off + kv_len

            req_q = query[:, q_off:q_end, :]

            # Gather this request's K/V rows from the cache via its
            # per-request token-index table.
            pool_idx = req_pool_indices[idx]
            token_ids = req_to_token[pool_idx, :kv_len]
            req_k = k_cache[token_ids].movedim(0, query.dim() - 2)
            req_v = v_cache[token_ids].movedim(0, query.dim() - 2)

            attn = scaled_dot_product_attention(
                req_q.unsqueeze(0),
                req_k.unsqueeze(0),
                req_v.unsqueeze(0),
                enable_gqa=enable_gqa,
                scale=scaling,
                is_causal=causal,
            )
            output[q_off:q_end, :, :] = attn.squeeze(0).movedim(query.dim() - 2, 0)
            q_off, kv_off = q_end, kv_end

        return output

    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len):
        dtype = torch.bfloat16

        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0
        num_kv_splits = 8
        enable_gqa = H_Q != H_KV

        # q is the one freshly generated token per batch entry
        q = torch.randn(B, H_Q, D, dtype=dtype)

        # MLA layout: the value cache is the leading-D_V slice of the key cache
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype)
        v_buffer = k_buffer.narrow(2, 0, D_V)

        key = torch.randn(B, H_KV, D, dtype=dtype)
        value = key.narrow(2, 0, D_V)
        # distinct slots so the cache scatter has no collisions
        loc = torch.randperm(total_tokens)[:B].to(torch.int64)

        k_buffer2 = k_buffer.clone()
        v_buffer2 = k_buffer2.narrow(2, 0, D_V)

        # outputs have the same batch/head shape as q with the value head dim
        o = torch.zeros(B, H_Q, D_V, dtype=dtype)
        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype)

        req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32)
        b_req_idx = torch.arange(B).to(torch.int64)
        b_seq_len = torch.full((B,), seq_len).to(torch.int64)

        attn_logits = torch.empty(
            (B, H_Q, num_kv_splits, D_V + 1),
            dtype=torch.float32,
        )

        torch.ops.sgl_kernel.decode_attention_cpu(
            q,
            k_buffer2,
            v_buffer2,
            o,
            key,
            value,
            loc,
            attn_logits,
            req_to_token,
            b_req_idx,
            b_seq_len,
            sm_scale,
            logit_cap,
        )

        self._run_sdpa_forward_decode(
            q,
            o_grouped,
            k_buffer,
            v_buffer,
            key,
            loc,
            req_to_token,
            b_req_idx,
            b_seq_len,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
        )

        cos_sim = torch.nn.functional.cosine_similarity(
            o.flatten(), o_grouped.flatten(), dim=0
        )
        tol = precision[q.dtype]
        self.assertGreater(cos_sim.item(), 0.99)
        torch.testing.assert_close(o, o_grouped, atol=tol, rtol=tol)
        # both implementations must leave identical caches behind
        torch.testing.assert_close(k_buffer, k_buffer2, atol=tol, rtol=tol)
        torch.testing.assert_close(v_buffer, v_buffer2, atol=tol, rtol=tol)

    def test_grouped_decode_attention(self):
        configs = [
            (1, 22, 1, 576, 512, 8 * 111),
            (4, 22, 1, 576, 512, 8 * 128),
            (40, 22, 1, 576, 512, 8 * 133),
        ]

        for B, H_Q, H_KV, D, D_V, seqlen in configs:
            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen)


if __name__ == "__main__":
    unittest.main()
|
||||
265
test/srt/cpu/test_moe.py
Normal file
265
test/srt/cpu/test_moe.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import itertools
|
||||
import math
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
|
||||
kernel = torch.ops.sgl_kernel
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
from utils import (
|
||||
BLOCK_K,
|
||||
BLOCK_N,
|
||||
factor_for_scale,
|
||||
fp8_max,
|
||||
fp8_min,
|
||||
native_fp8_fused_moe,
|
||||
precision,
|
||||
scaled_weight,
|
||||
torch_naive_fused_moe,
|
||||
torch_w8a8_per_column_fused_moe,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
    """Run the bf16 fused_experts_cpu kernel with grouped_topk_cpu routing.

    Args:
        a: activations, shape (M, K).
        w1: gate/up expert weights, shape (E, 2*N, K).
        w2: down expert weights, shape (E, K, N).
        score: router logits, shape (M, E).
        topk: number of experts selected per token.
        renormalize: whether routing weights are renormalized after top-k.
        prepack: pre-pack the weights via convert_weight_packed before the call.

    Returns:
        The kernel output (same shape as ``a``; computed in place).
    """
    # single expert group with top-1 group selection == plain top-k routing
    G = 1
    topk_group = 1

    # grouped_topk_cpu returns both tensors; no need to pre-allocate them
    # (the original code allocated throwaway torch.empty buffers here).
    topk_weights, topk_ids = kernel.grouped_topk_cpu(
        a, score, topk, renormalize, G, topk_group, 0, None, None
    )

    packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
    packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2

    inplace = True
    return kernel.fused_experts_cpu(
        a,
        packed_w1,
        packed_w2,
        topk_weights,
        topk_ids,
        inplace,
        False,  # not the int8 path
        False,  # not the fp8 path
        None,
        None,
        None,
        None,
        None,
        prepack,
    )
|
||||
|
||||
|
||||
class TestFusedExperts(CustomTestCase):
    """Compare fused_experts_cpu (bf16 / int8 / fp8 paths) against naive MoE."""

    M = [2, 114]
    N = [32]
    K = [32]
    E = [4]
    topk = [2]
    renormalize = [False, True]

    M_int8 = [1, 39]
    N_int8 = [128]
    K_int8 = [256]
    E_int8 = [8]
    topk_int8 = [3]

    M_fp8 = [2, 121]
    N_fp8 = [352, 512]
    K_fp8 = [256, 320]
    E_fp8 = [8]
    topk_fp8 = [4]

    def _bf16_moe(self, m, n, k, e, topk, renormalize):
        dtype = torch.bfloat16
        prepack = True

        a = torch.randn((m, k), device="cpu", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10
        score = torch.randn((m, e), device="cpu", dtype=dtype)

        torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize)
        fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)

        tol = precision[torch_output.dtype]
        torch.testing.assert_close(torch_output, fused_output, atol=tol, rtol=tol)

    def test_bf16_moe(self):
        for m, n, k, e, topk, renorm in itertools.product(
            self.M, self.N, self.K, self.E, self.topk, self.renormalize
        ):
            with self.subTest(m=m, n=n, k=k, e=e, topk=topk, renormalize=renorm):
                self._bf16_moe(m, n, k, e, topk, renorm)

    def _int8_moe(self, M, N, K, E, topk):
        dtype = torch.bfloat16
        prepack = True

        # int8 quantization parameters
        int8_factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # activations, M x K
        a = torch.randn((M, K), dtype=dtype) / math.sqrt(K)

        # random int8 weights for both projections
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
        w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
        w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        # per-column quantization scales
        w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale
        w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale

        # routing via softmax + top-k
        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        ref_out = torch_w8a8_per_column_fused_moe(
            a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk
        )

        inplace = True
        packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
        packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
        out = kernel.fused_experts_cpu(
            a,
            packed_w1,
            packed_w2,
            topk_weight,
            topk_ids.to(torch.int32),
            inplace,
            True,  # int8 path
            False,
            w1_s,
            w2_s,
            None,
            None,
            None,
            prepack,
        )

        tol = precision[ref_out.dtype]
        # looser tolerance for the larger input shape
        if M > 35:
            tol = 0.02
        torch.testing.assert_close(ref_out, out, atol=tol, rtol=tol)

    def test_int8_moe(self):
        for m, n, k, e, topk in itertools.product(
            self.M_int8, self.N_int8, self.K_int8, self.E_int8, self.topk_int8
        ):
            with self.subTest(M=m, N=n, K=k, E=e, topk=topk):
                self._int8_moe(m, n, k, e, topk)

    def _fp8_moe(self, M, N, K, E, topk):
        dtype = torch.bfloat16

        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        w1_fp32 = torch.randn(E, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w2_fp32 = torch.randn(E, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # block-wise scales for both projections
        w1s = (
            torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
            * factor_for_scale
        )
        w2s = (
            torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
            * factor_for_scale
        )

        w1_scaled = scaled_weight(w1, w1s)
        w2_scaled = scaled_weight(w2, w2s)

        # routing via softmax + top-k
        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        w1 = kernel.convert_weight_packed(w1)
        w2 = kernel.convert_weight_packed(w2)

        ref_out = native_fp8_fused_moe(
            a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk
        )
        out = kernel.fused_experts_cpu(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids.to(torch.int32),
            False,
            False,
            True,  # fp8 path
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            None,
            None,
            True,
        )

        tol = precision[dtype]
        torch.testing.assert_close(ref_out.bfloat16(), out, atol=tol, rtol=tol)

    def test_fp8_moe(self):
        for m, n, k, e, topk in itertools.product(
            self.M_fp8, self.N_fp8, self.K_fp8, self.E_fp8, self.topk_fp8
        ):
            with self.subTest(M=m, N=n, K=k, E=e, topk=topk):
                self._fp8_moe(m, n, k, e, topk)


if __name__ == "__main__":
    unittest.main()
|
||||
90
test/srt/cpu/test_norm.py
Normal file
90
test/srt/cpu/test_norm.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import itertools
|
||||
import unittest
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import make_non_contiguous, precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestNorm(CustomTestCase):
    """CPU RMSNorm / fused add-RMSNorm / L2-norm kernels vs. a fp32 reference."""

    # Test grid: token counts, hidden sizes (one non-multiple-of-16 case to
    # hit the kernels' tail handling), and reduced-precision dtypes.
    M = [4096, 1024]
    N = [4096, 4096 + 13]
    dtype = [torch.float16, torch.bfloat16]

    def _forward_native(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float = 1e-6,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Reference RMSNorm computed in float32.

        If *residual* is given, it is added to *x* first and the pre-norm sum
        (cast back to the input dtype) is returned alongside the normalized
        output, matching the fused kernel's contract.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Cast back to the input dtype *before* applying the weight, to match
        # the kernel's rounding behavior.
        x = x.to(orig_dtype) * weight
        if residual is None:
            return x
        else:
            return x, residual

    def _norm_test(self, m, n, dtype):
        """Check rmsnorm_cpu and the in-place fused_add_rmsnorm_cpu kernels."""

        x = torch.randn([m, n], dtype=dtype)
        # Non-contiguous input exercises the kernel's strided path.
        x = make_non_contiguous(x)
        hidden_size = x.size(-1)
        weight = torch.randn(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

        out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon)
        ref_out = self._forward_native(x, weight, variance_epsilon)

        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

        ref_x = x.clone()
        residual = torch.randn([m, hidden_size], dtype=dtype)
        ref_residual = residual.clone()

        # The fused kernel mutates x and residual in place.
        torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
            x, residual, weight, variance_epsilon
        )

        ref_x, ref_residual = self._forward_native(
            ref_x, weight, variance_epsilon, ref_residual
        )

        torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
        torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)

    def _l2norm_test(self, m, n, dtype):
        """Check l2norm_cpu against RMSNorm with an all-ones weight."""

        x = torch.randn([m, n], dtype=dtype)
        hidden_size = x.size(-1)
        # l2norm is rmsnorm without a learned scale; emulate via a ones weight.
        fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

        out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
        ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)

        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_norm(self):
        for params in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=params[0], n=params[1], dtype=params[2]):
                self._norm_test(*params)
                self._l2norm_test(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
432
test/srt/cpu/test_qkv_proj_with_rope.py
Normal file
432
test/srt/cpu/test_qkv_proj_with_rope.py
Normal file
@@ -0,0 +1,432 @@
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import (
|
||||
convert_weight,
|
||||
native_w8a8_per_token_matmul,
|
||||
per_token_quant_int8,
|
||||
precision,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import _apply_rotary_emb
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
|
||||
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
|
||||
qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
|
||||
torch.manual_seed(1234)
|
||||
# constants
|
||||
kv_lora_rank = 512
|
||||
qk_head_dim = 192
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
rotary_dim = qk_rope_head_dim
|
||||
num_heads = 22
|
||||
q_lora_rank = 1536
|
||||
hidden_size = 7168
|
||||
B = 1
|
||||
eps = 1e-6
|
||||
|
||||
|
||||
def layernorm(x, weight, variance_epsilon=1e-6, residual=None):
    """RMS-normalize *x* over its last dimension in fp32, then scale by *weight*.

    NOTE(review): the ``residual`` parameter is accepted but never used in
    this reference helper; callers in this file do not pass it.
    """
    input_dtype = x.dtype
    xf = x.to(torch.float32)
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon)
    return (xf * inv_rms * weight).to(input_dtype)
|
||||
|
||||
|
||||
def rotary_emb(q_pe, k_pe, pos, cos_sin_cache):
    """Reference rotary embedding applied to the rope slices of q and k.

    Computed in float32 and cast back to the input dtype. ``cos_sin_cache``
    holds cos in the first half and sin in the second half of its last dim;
    ``pos`` indexes rows of the cache. Uses the module-level ``rotary_dim``.
    """
    orig_dtype = q_pe.dtype
    q_pe = q_pe.float()
    k_pe = k_pe.float()
    cos_sin_cache = cos_sin_cache.float()

    query_rot = q_pe[..., :rotary_dim]
    key_rot = k_pe[..., :rotary_dim]
    cos_sin = cos_sin_cache[pos]
    cos, sin = cos_sin.chunk(2, dim=-1)
    # Final False => is_neox_style=False in sglang's _apply_rotary_emb.
    query_rot = _apply_rotary_emb(query_rot, cos, sin, False)
    key_rot = _apply_rotary_emb(key_rot, cos, sin, False)
    return query_rot.to(orig_dtype), key_rot.to(orig_dtype)
|
||||
|
||||
|
||||
def native_torch(
    q_input,
    hidden_states,
    q_a_proj_weight,
    norm_weight1,
    q_b_proj_weight,
    w_kc,
    kv_a_proj_weight,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    """Reference bf16 MLA q/k/v projection + rope path in plain torch.

    q goes through the low-rank A/B projections with an RMSNorm in between,
    and its nope part is absorbed into the latent space via w_kc; the kv
    latent is RMS-normalized; rope is applied to the rope slices of q and k.
    Mutates *q_input* in place and returns (q_input, k_input, v_input).
    Relies on module-level dims (num_heads, qk_head_dim, kv_lora_rank, ...).
    """

    # q: hidden -> q_lora_rank -> rmsnorm -> per-head (nope + rope) dims.
    q = torch.matmul(hidden_states, q_a_proj_weight.t())
    q = layernorm(q, norm_weight1)
    q = torch.matmul(q, q_b_proj_weight.t()).view(-1, num_heads, qk_head_dim)

    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    # Absorb the nope part into the latent space: bmm over heads with w_kc.
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)

    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
    latent_cache = torch.matmul(hidden_states, kv_a_proj_weight.t())
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
    k_input = latent_cache.unsqueeze(1)
    # k shares the normalized latent with v; only its rope tail differs.
    k_input[..., :kv_lora_rank] = v_input
    k_pe = k_input[..., kv_lora_rank:]

    q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
    q_input[..., kv_lora_rank:] = q_pe
    k_input[..., kv_lora_rank:] = k_pe

    return q_input, k_input, v_input
|
||||
|
||||
|
||||
def native_torch_int8(
    q_input,
    hidden_states,
    w1_q,
    w1_s,
    norm_weight1,
    w2_q,
    w2_s,
    w_kc,
    w3_q,
    w3_s,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    """Int8 (w8a8, per-token) variant of :func:`native_torch`.

    Each matmul is replaced by dynamic per-token activation quantization
    followed by a quantized matmul dequantized with the per-row weight
    scales (w*_s). Mutates *q_input* in place and returns
    (q_input, k_input, v_input).
    """

    # q_a projection: quantize activations, int8 matmul, dequantize to bf16.
    a_q, a_s = per_token_quant_int8(hidden_states)
    q = native_w8a8_per_token_matmul(a_q, w1_q, a_s, w1_s, None, torch.bfloat16)
    q = layernorm(q, norm_weight1)

    # q_b projection, reshaped to per-head (nope + rope) dims.
    a_q, a_s = per_token_quant_int8(q)
    q = native_w8a8_per_token_matmul(a_q, w2_q, a_s, w2_s, None, torch.bfloat16).view(
        -1, num_heads, qk_head_dim
    )

    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    # Absorb the nope part into the latent space (w_kc stays bf16).
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)

    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
    # kv_a projection on freshly re-quantized hidden states.
    a_q, a_s = per_token_quant_int8(hidden_states)
    latent_cache = native_w8a8_per_token_matmul(
        a_q, w3_q, a_s, w3_s, None, torch.bfloat16
    )
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
    k_input = latent_cache.unsqueeze(1)
    # k shares the normalized latent with v; only its rope tail differs.
    k_input[..., :kv_lora_rank] = v_input
    k_pe = k_input[..., kv_lora_rank:]

    q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
    q_input[..., kv_lora_rank:] = q_pe
    k_input[..., kv_lora_rank:] = k_pe

    return q_input, k_input, v_input
|
||||
|
||||
|
||||
class TestQKVProjWithROPE(CustomTestCase):
    """Fused CPU qkv_proj_with_rope kernels (bf16 / int8 / fp8) vs. the
    native torch references, including the fused-weight variants."""

    def test_bf16_qkv_proj_with_rope(self):
        """bf16 path: separate-weight and fused-weight kernels must match the
        native reference (and each other exactly)."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        # q_input is written in place by the native reference.
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        # 0.1 scaling keeps activations small through the matmul chain.
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        # Fused variant stacks q_a and kv_a projections into one weight.
        fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight,
            norm_weight1,
            q_b_proj_weight,
            w_kc.transpose(1, 2),
            kv_a_proj_weight,
            norm_weight2,
            pos,
            cos_sin_cache,
        )
        qa_packed = convert_weight_packed(q_a_proj_weight)
        qb_packed = convert_weight_packed(q_b_proj_weight)
        kva_packed = convert_weight_packed(kv_a_proj_weight)
        wkc_packed = convert_weight_packed(w_kc)
        fused_weight_packed = convert_weight_packed(fused_weight)

        # NOTE(review): trailing positional args are quantization flags and
        # optional scales (here disabled) plus the packed-layout flag —
        # confirm against sgl_kernel's qkv_proj_with_rope signature.
        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            qa_packed,
            qb_packed,
            kva_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            False,
            None,
            None,
            None,
            True,
            None,
        )
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            fused_weight_packed,
            qb_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            False,
            None,
            None,
            True,
            None,
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        # Fused-weight kernel must match the separate-weight kernel exactly.
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)

    def test_int8_qkv_proj_with_rope(self):
        """int8 w8a8 path: quantized kernels vs. the int8 native reference."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)

        # Per-row int8 quantization of the three projection weights.
        w1_q, w1_s = per_token_quant_int8(q_a_proj_weight)
        w2_q, w2_s = per_token_quant_int8(q_b_proj_weight)
        w3_q, w3_s = per_token_quant_int8(kv_a_proj_weight)
        q_ref, k_ref, v_ref = native_torch_int8(
            q_input,
            hidden_states,
            w1_q,
            w1_s,
            norm_weight1,
            w2_q,
            w2_s,
            w_kc.transpose(1, 2),
            w3_q,
            w3_s,
            norm_weight2,
            pos,
            cos_sin_cache,
        )
        w1_q_packed = convert_weight_packed(w1_q)
        w2_q_packed = convert_weight_packed(w2_q)
        w3_q_packed = convert_weight_packed(w3_q)
        wkc_packed = convert_weight_packed(w_kc)
        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            w1_q_packed,
            w2_q_packed,
            w3_q_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            True,
            False,
            w1_s,
            w2_s,
            w3_s,
            True,
            None,
        )
        # Fused variant: q_a and kv_a weights (and their scales) stacked.
        fused_weight = torch.cat([w1_q, w3_q], dim=0)
        fused_weight_s = torch.cat([w1_s, w3_s], dim=0)
        w_fused_q_packed = convert_weight_packed(fused_weight)
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            w_fused_q_packed,
            w2_q_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            True,
            False,
            fused_weight_s,
            w2_s,
            True,
            None,
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)

    def test_fp8_qkv_proj_with_rope(self):
        """fp8 block-quantized path: the native reference runs on dequantized
        weights; the kernel consumes fp8 weights plus block scales."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)

        # convert_weight returns (fp8 weight, block scale_inv, dequantized copy).
        scale_block_size_N = 128
        scale_block_size_K = 128
        fp8_q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_proj_weight_dq = (
            convert_weight(
                q_a_proj_weight,
                [scale_block_size_N, scale_block_size_K],
                torch.bfloat16,
            )
        )
        fp8_q_b_proj_weight, q_b_proj_weight_scale_inv, q_b_proj_weight_dq = (
            convert_weight(
                q_b_proj_weight,
                [scale_block_size_N, scale_block_size_K],
                torch.bfloat16,
            )
        )
        (
            fp8_kv_a_proj_with_mqa_weight,
            kv_a_proj_with_mqa_weight_scale_inv,
            kv_a_proj_with_mqa_weight_dq,
        ) = convert_weight(
            kv_a_proj_weight, [scale_block_size_N, scale_block_size_K], torch.bfloat16
        )
        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight_dq,
            norm_weight1,
            q_b_proj_weight_dq,
            w_kc.transpose(1, 2),
            kv_a_proj_with_mqa_weight_dq,
            norm_weight2,
            pos,
            cos_sin_cache,
        )
        fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight)
        fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight)
        fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed(
            fp8_kv_a_proj_with_mqa_weight
        )
        w_kc = convert_weight_packed(w_kc)
        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            fp8_q_a_proj_weight_packed,
            fp8_q_b_proj_weight_packed,
            fp8_kv_a_proj_with_mqa_weight_packed,
            w_kc,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            True,
            q_a_proj_weight_scale_inv.float(),
            q_b_proj_weight_scale_inv.float(),
            kv_a_proj_with_mqa_weight_scale_inv.float(),
            True,
            [scale_block_size_N, scale_block_size_K],
        )

        fused_weight = torch.cat(
            [fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0
        )
        fused_weight_s = torch.cat(
            [q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0
        )
        fused_weight_packed = convert_weight_packed(fused_weight)
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            fused_weight_packed,
            fp8_q_b_proj_weight_packed,
            w_kc,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            True,
            fused_weight_s.float(),
            q_b_proj_weight_scale_inv.float(),
            True,
            [scale_block_size_N, scale_block_size_K],
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        # Due to the change in multiplication order, the error is amplified.
        # In the model, with fewer layers, this doesn't cause issues, but in
        # tests with more layers, we need to enlarge the tolerance to pass the tests.
        torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
178
test/srt/cpu/test_rope.py
Normal file
178
test/srt/cpu/test_rope.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import precision
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding,
|
||||
RotaryEmbedding,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestROPE(CustomTestCase):
    """CPU fused rotary-embedding kernel vs. the native sglang implementations."""

    def test_deepseek_v2_rope(self):
        """Partial-rope (MLA-style) path: rope applies only to the trailing
        qk_rope_head_dim slice of q and k."""
        num_head = 16
        seq_len = 1024
        q_head_dim = 192
        qk_nope_head_dim = 128
        qk_rope_head_dim = 64
        max_pos = 256
        k_dim = 576
        rotary_dim = 64
        is_neox_style = False

        # Create cos_sin_cache; 0.7 keeps values away from +/-1 so bf16
        # rounding differences remain visible.
        freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
        cos = freqs.cos() * 0.7
        sin = freqs.sin() * 0.7
        cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
        positions = torch.randint(0, max_pos, (seq_len,))

        rope = DeepseekScalingRotaryEmbedding(
            qk_rope_head_dim,
            rotary_dim,
            max_pos,
            16,  # not used since cos_sin_cache is provided
            is_neox_style,
            1.0,
            torch.bfloat16,
            device="cpu",
        )
        # Replace the module's own cache with the handcrafted one above.
        rope.register_buffer("cos_sin_cache", cos_sin_cache)

        for dtype in [torch.bfloat16]:
            enable_autocast = True

            with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast):
                q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
                q_clone = q.clone()
                k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
                k_clone = k.clone()
                # Only the rope tails of q/k are rotated.
                _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
                _, q_pe_clone = q_clone.split(
                    [qk_nope_head_dim, qk_rope_head_dim], dim=-1
                )
                k_pe = k[:, :, k_dim - qk_rope_head_dim :]
                k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]

                # ref kernel
                q_pe, k_pe = rope.forward_native(
                    query=q_pe,
                    key=k_pe,
                    positions=positions,
                )

                # fused rope kernel
                q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
                    positions,
                    q_pe_clone,
                    k_pe_clone,
                    rope.head_size,
                    cos_sin_cache,
                    False,
                )

                atol = rtol = precision[q_pe.dtype]
                torch.testing.assert_close(q_pe, q_pe_clone, atol=atol, rtol=rtol)
                # Fix: a second, tolerance-free assert_close(k_pe, k_pe_clone)
                # here was stricter than (and made pointless) the tolerance
                # check below; it could spuriously fail on rounding noise.
                torch.testing.assert_close(k_pe, k_pe_clone, atol=atol, rtol=rtol)

    def test_origin_rope(self):
        """Full-head rope path against RotaryEmbedding.forward_native across a
        grid of head sizes, neox styles, and batch shapes."""

        def single_test(
            head_size: int,
            rotary_dim: int,
            max_position_embeddings: int,
            base: int,
            is_neox_style: bool,
            dtype: torch.dtype,
            device: str,
            batch_size: int,
            seq_len: int,
            num_q_heads: int,
            num_kv_heads: int,
        ):
            torch.manual_seed(100)
            rope_ref = RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position_embeddings,
                base,
                is_neox_style,
                dtype,
            ).to(device)
            pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
            query = torch.randn(
                batch_size * seq_len,
                num_q_heads * head_size,
                dtype=dtype,
                device=device,
            )
            key = torch.randn(
                batch_size * seq_len,
                num_kv_heads * head_size,
                dtype=dtype,
                device=device,
            )

            query_ref, key_ref = query.clone(), key.clone()
            query_cpu, key_cpu = query.clone(), key.clone()

            query_ref_out, key_ref_out = rope_ref.forward_native(
                pos_ids, query_ref, key_ref
            )
            query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
                pos_ids,
                query_cpu,
                key_cpu,
                rope_ref.head_size,
                rope_ref.cos_sin_cache.to(query.dtype),
                rope_ref.is_neox_style,
            )
            torch.testing.assert_close(
                query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
            )
            torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)

        # (head_size, rotary_dim, max_pos, base, neox, dtype, device,
        #  batch, seq, q_heads, kv_heads)
        test_config = [
            (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
            (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
            (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
            (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
        ]

        for (
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
            device,
            batch_size,
            seq_len,
            num_q_heads,
            num_kv_heads,
        ) in test_config:
            single_test(
                head_size,
                rotary_dim,
                max_position_embeddings,
                base,
                is_neox_style,
                dtype,
                device,
                batch_size,
                seq_len,
                num_q_heads,
                num_kv_heads,
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
223
test/srt/cpu/test_shared_expert.py
Normal file
223
test/srt/cpu/test_shared_expert.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import itertools
|
||||
import math
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from utils import (
|
||||
BLOCK_K,
|
||||
BLOCK_N,
|
||||
SiluAndMul,
|
||||
factor_for_scale,
|
||||
fp8_max,
|
||||
fp8_min,
|
||||
per_token_quant_int8,
|
||||
precision,
|
||||
scaled_weight,
|
||||
torch_naive_moe,
|
||||
torch_w8a8_per_column_moe,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestSharedExpert(CustomTestCase):
    """CPU shared_expert_cpu kernel (bf16 / int8 / fp8) vs. torch references.

    The shared-expert op fuses: gate/up matmul (w1), SiluAndMul, down matmul
    (w2), and addition of the routed-experts output scaled by
    routed_scaling_factor.
    """

    # Test grids for the bf16/int8 paths and the fp8 path.
    M = [2, 121]
    N = [32, 32 * 4]
    K = [32, 32 * 2]
    routed_scaling_factor = [16]

    M_fp8 = [2, 12]
    N_fp8 = [512]
    K_fp8 = [256]

    def _bf16_shared_expert(self, m, n, k, routed_scaling_factor):
        """bf16 path against torch_naive_moe (computed in fp32)."""
        dtype = torch.bfloat16
        # Fix: removed unused locals (`prepack` and an unused clone of
        # hidden_states) left over from an earlier version of this test.

        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k

        # bfloat16
        ref = torch_naive_moe(
            hidden_states.float(),
            w1.float(),
            w2.float(),
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)
        # NOTE(review): positional flags/None slots follow sgl_kernel's
        # shared_expert_cpu signature — confirm against its declaration.
        res = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            w1,
            w2,
            fused_output,
            routed_scaling_factor,
            True,
            False,
            False,
            None,
            None,
            None,
            None,
            None,
            False,
        )

        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, res, atol=atol, rtol=rtol)

    def test_bf16_shared_expert(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                routed_scaling_factor=params[3],
            ):
                self._bf16_shared_expert(*params)

    def _int8_shared_expert(self, m, n, k, routed_scaling_factor):
        """int8 w8a8 path against torch_w8a8_per_column_moe."""
        dtype = torch.bfloat16

        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k

        # The fused kernel mutates its hidden-states argument, so the
        # reference gets its own copy.
        hidden_states2 = hidden_states.clone()

        w1_q, w1_s = per_token_quant_int8(w1)
        w2_q, w2_s = per_token_quant_int8(w2)
        ref2 = torch_w8a8_per_column_moe(
            hidden_states2.float(),
            w1_q,
            w2_q,
            w1_s,
            w2_s,
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)
        res2 = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states2,
            w1_q,
            w2_q,
            fused_output,
            routed_scaling_factor,
            True,
            True,
            False,
            w1_s,
            w2_s,
            None,
            None,
            None,
            False,
        )

        atol = rtol = precision[ref2.dtype]
        torch.testing.assert_close(ref2, res2, atol=atol, rtol=rtol)

    def test_int8_shared_expert(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                routed_scaling_factor=params[3],
            ):
                self._int8_shared_expert(*params)

    def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
        """fp8 block-quantized path against an inline fp32 reference."""
        dtype = torch.bfloat16

        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        # Weights carry a leading expert dim of 1 for scaled_weight, then
        # are squeezed back to 2D for the kernel.
        w1_fp32 = torch.randn(1, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w2_fp32 = torch.randn(1, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w1s = torch.randn(1, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
        w2s = torch.randn(1, K // BLOCK_N, N // BLOCK_K) * factor_for_scale

        w1_scaled = scaled_weight(w1, w1s).view(2 * N, K)
        w2_scaled = scaled_weight(w2, w2s).view(K, N)

        # change back to 2D
        w1, w2 = w1.squeeze(0), w2.squeeze(0)
        w1s, w2s = w1s.squeeze(0), w2s.squeeze(0)
        w1_scaled, w2_scaled = w1_scaled.squeeze(0), w2_scaled.squeeze(0)

        fused_out = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
        a2 = a.clone()

        # ref: gate/up matmul -> SiluAndMul -> down matmul -> add routed output
        ic0 = torch.matmul(a.float(), w1_scaled.transpose(0, 1))
        ic1 = SiluAndMul(ic0)
        shared_out = torch.matmul(ic1, w2_scaled.transpose(0, 1))
        ref_out = shared_out + fused_out.float() * routed_scaling_factor
        ref_out = ref_out.to(dtype=dtype)

        w1 = torch.ops.sgl_kernel.convert_weight_packed(w1)  # [2N, K]
        w2 = torch.ops.sgl_kernel.convert_weight_packed(w2)  # [K, N]
        out = torch.ops.sgl_kernel.shared_expert_cpu(
            a2,
            w1,
            w2,
            fused_out,
            routed_scaling_factor,
            True,
            False,
            True,
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            None,
            None,
            True,
        )

        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_fp8_shared_expert(self):
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                routed_scaling_factor=params[3],
            ):
                self._fp8_shared_expert(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
199
test/srt/cpu/test_topk.py
Normal file
199
test/srt/cpu/test_topk.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import precision
|
||||
|
||||
from sglang.srt.layers.moe.topk import (
|
||||
biased_grouped_topk_impl as native_biased_grouped_topk,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import fused_topk_torch_native as native_fused_topk
|
||||
from sglang.srt.layers.moe.topk import grouped_topk_gpu as native_grouped_topk
|
||||
from sglang.srt.models.llama4 import Llama4MoE
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
# This is used by the Deepseek-V2 model
|
||||
class TestGroupedTopK(CustomTestCase):
    """CPU grouped_topk_cpu kernel vs. sglang's native grouped_topk."""

    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        """M tokens, E experts in G groups; pick topk experts from the
        topk_group best groups, optionally renormalizing the weights."""
        torch.manual_seed(1234)

        # expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_topk_weights, ref_topk_ids = native_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            G,
            topk_group,
            0,
            None,
            None,
        )

        # Scatter weights into dense [M, E] maps so the comparison is
        # insensitive to the ordering of tied top-k entries.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_grouped_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16)
|
||||
|
||||
|
||||
# DeepSeek V2/V3/R1 uses biased_grouped_top
|
||||
class TestBiasedGroupedTopK(CustomTestCase):
    """CPU biased_grouped_topk_cpu kernel vs. sglang's native implementation
    (the correction-bias variant used by DeepSeek V2/V3/R1)."""

    def _run_single_test(
        self, M, E, G, topk, topk_group, renormalize, dtype, bias_dtype
    ):
        torch.manual_seed(1234)

        # expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        # Per-expert score correction bias (e.g. for load balancing).
        correction_bias = torch.randn(E, dtype=bias_dtype)

        ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            correction_bias.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            G,
            topk_group,
            0,
            None,
            None,
        )

        # Scatter into dense [M, E] maps so tied top-k ordering doesn't matter.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_biased_grouped_topk(self):
        for renormalize in [True, False]:
            for bias_dtype in [torch.float32, torch.bfloat16]:
                self._run_single_test(
                    122, 256, 8, 8, 2, renormalize, torch.bfloat16, bias_dtype
                )
|
||||
|
||||
|
||||
class TestTopK(CustomTestCase):
    def _run_single_test(self, M, E, topk, renormalize, dtype):
        """Compare topk_softmax_cpu against the native softmax top-k reference."""
        torch.manual_seed(1998)

        # Scale the logits by M so that bfloat16 truncation keeps scores distinct.
        hidden = torch.randn(M, 100, dtype=dtype)
        logits = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_w, ref_ids = native_fused_topk(
            hidden.float(),
            logits.float(),
            topk,
            renormalize,
        )

        # Fused CPU kernel under test.
        fused_w, fused_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
            hidden, logits, topk, renormalize
        )

        # Dense [M, E] maps make the comparison order-insensitive.
        dense_fused = torch.zeros(M, E, dtype=torch.float)
        dense_ref = torch.zeros(M, E, dtype=torch.float)
        dense_fused.scatter_(1, fused_ids.long(), fused_w)
        dense_ref.scatter_(1, ref_ids.long(), ref_w)
        torch.testing.assert_close(dense_fused, dense_ref)

    def test_topk(self):
        # (M, E, topk); the duplicate (123, 32, 3) entry mirrors the original suite.
        cases = [
            (123, 8, 2),
            (123, 16, 3),
            (123, 32, 3),
            (123, 32, 3),
            (123, 64, 6),
            (123, 256, 4),
            (123, 160, 6),
        ]
        for renormalize in (True, False):
            for M, E, topk in cases:
                self._run_single_test(M, E, topk, renormalize, torch.bfloat16)
|
||||
|
||||
|
||||
class TestCustomTopK(CustomTestCase):
    def _run_single_test(
        self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
    ):
        """Compare a fused custom-routing kernel against its native counterpart."""
        torch.manual_seed(16)

        # Scale the logits by M so that bfloat16 truncation keeps scores distinct.
        hidden = torch.randn(M, 100, dtype=dtype)
        logits = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_w, ref_ids = native_custom_f(
            hidden.float(),
            logits.float(),
            topk,
            renormalize,
        )

        # Fused CPU kernel under test.
        fused_w, fused_ids = fused_custom_f(hidden, logits, topk, renormalize)

        # Dense [M, E] maps make the comparison order-insensitive.
        dense_fused = torch.zeros(M, E, dtype=torch.float)
        dense_ref = torch.zeros(M, E, dtype=torch.float)
        dense_fused.scatter_(1, fused_ids.long(), fused_w)
        dense_ref.scatter_(1, ref_ids.long(), ref_w)
        torch.testing.assert_close(dense_fused, dense_ref)

    def test_custom_topk(self):
        # (native reference, fused CPU kernel) pairs under test.
        pairs = [
            (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
        ]
        for native_custom_f, fused_custom_f in pairs:
            for E in (8, 16, 32):
                self._run_single_test(
                    123, E, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
                )
|
||||
|
||||
|
||||
# Allow running this test file directly (outside the pytest runner).
if __name__ == "__main__":
    unittest.main()
|
||||
274
test/srt/cpu/utils.py
Normal file
274
test/srt/cpu/utils.py
Normal file
@@ -0,0 +1,274 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Default absolute/relative tolerance used by the CPU kernel tests,
# keyed by dtype (looser tolerance for lower-precision types).
precision = {
    torch.bfloat16: 1e-2,
    torch.float16: 1e-3,
    torch.float32: 1e-5,
}
|
||||
|
||||
|
||||
# Tile sizes (along N and K) for block-wise quantized weights; consumed by
# scaled_weight() below.
BLOCK_N, BLOCK_K = 64, 128
# NOTE(review): presumably used by fp8 test fixtures elsewhere in this file
# to scale randomly generated per-block scales — confirm against callers.
factor_for_scale = 1e-3
# Synthetic fp8 value clamp range used by the tests (narrower than the real
# e4m3 maximum of +/-448, which convert_weight() uses).
fp8_max, fp8_min = 400, -400
|
||||
|
||||
|
||||
def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
    """Reference SiLU-and-mul used by the gated-MLP tests.

    Splits the last dimension in half, applies SiLU to the first half and
    multiplies elementwise by the second half.
    """
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]
    return F.silu(gate) * up
|
||||
|
||||
|
||||
def GeluAndMul(x: torch.Tensor, approximate="tanh") -> torch.Tensor:
    """Reference GELU-and-mul used by the gated-MLP tests.

    Splits the last dimension in half; `approximate` is forwarded to F.gelu
    ("none" for exact GELU, "tanh" for the tanh approximation).
    """
    half = x.shape[-1] // 2
    gate = F.gelu(x[..., :half], approximate=approximate)
    return gate * x[..., half:]
|
||||
|
||||
|
||||
def per_token_quant_int8(x):
    """Symmetric per-token (per-row) int8 quantization.

    Returns (x_q, scale) where x_q is int8 and scale has shape [..., 1],
    such that x ≈ x_q.float() * scale.
    """
    x = x.float()
    # Clamp the row maximum away from zero so all-zero rows do not divide by 0.
    row_max = x.abs().max(dim=-1).values.clamp_min(1e-10).unsqueeze(-1)
    scale = row_max / 127
    q = torch.round(x.mul(127 / row_max)).to(torch.int8)
    return q, scale
|
||||
|
||||
|
||||
def convert_weight(weight, scale_block_size, A_dtype):
    """Block-quantize a 2D weight to fp8 (e4m3) with per-block scales.

    Args:
        weight: [N, K] float tensor.
        scale_block_size: (block_N, block_K) tile size for the scales.
        A_dtype: dtype for the returned dequantized reference weight.

    Returns:
        (q_fp8, scales, w_dq): fp8 weight [N, K], per-block scales of shape
        [ceil(N/block_N), ceil(K/block_K)], and the dequantized weight in
        A_dtype to serve as the float reference.
    """
    N, K = weight.size()
    # True e4m3 max; deliberately shadows the narrower module-level constant.
    fp8_max = 448.0
    scale_block_size_N, scale_block_size_K = scale_block_size  # (128, 128)

    # Pad N/K up to block-size multiples so the block view below is exact.
    pad_N = (scale_block_size_N - (N % scale_block_size_N)) % scale_block_size_N
    pad_K = (scale_block_size_K - (K % scale_block_size_K)) % scale_block_size_K

    if pad_N > 0 or pad_K > 0:
        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))

    # [nN, block_N, nK, block_K] -> [nN, nK, block_N, block_K]
    weight_blocks = weight.view(
        math.ceil(N / scale_block_size_N),
        scale_block_size_N,
        math.ceil(K / scale_block_size_K),
        scale_block_size_K,
    )  # (8, 128, 8, 128)
    weight_blocks = weight_blocks.permute(0, 2, 1, 3).contiguous()  # (8, 8, 128, 128)

    # Step 2: compute per-block max abs values → scale
    abs_max = weight_blocks.abs().amax(dim=(-2, -1), keepdim=True)  # (8, 8, 1, 1)
    scales = abs_max / fp8_max
    scales = torch.where(
        scales == 0, torch.ones_like(scales), scales
    )  # avoid division by zero

    q_fp8 = (weight_blocks / scales).to(torch.float8_e4m3fn)
    # Restore [N(+pad), K(+pad)] layout, then strip the padding.
    q_fp8_reshape = q_fp8.permute(0, 2, 1, 3).contiguous()

    if pad_N > 0 or pad_K > 0:
        q_fp8_reshape = q_fp8_reshape.view(N + pad_N, K + pad_K)
        q_fp8_reshape = q_fp8_reshape[:N, :K].contiguous()
    else:
        q_fp8_reshape = q_fp8_reshape.view(N, K)

    # Dequantized copy used as the float reference for accuracy checks.
    dq_weight = q_fp8.float() * scales
    dq_weight = dq_weight.permute(0, 2, 1, 3).contiguous()  # (8, 128, 8, 128)

    if pad_N > 0 or pad_K > 0:
        w_dq = dq_weight.view(N + pad_N, K + pad_K).to(A_dtype)
        w_dq = w_dq[:N, :K].contiguous()
    else:
        w_dq = dq_weight.view(N, K).to(A_dtype)

    # Drop the broadcast dims: one scale per (block_N, block_K) tile.
    scales = scales.view(
        math.ceil(N / scale_block_size_N), math.ceil(K / scale_block_size_K)
    )

    return q_fp8_reshape, scales, w_dq
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16):
    """Reference matmul with per-token activation and per-column weight scales.

    Args:
        A: [..., in_features] quantized activations (any int/float dtype).
        B: [out_features, in_features] quantized weight, 2D contiguous.
        As: per-token scales, broadcastable to [num_tokens, 1].
        Bs: per-output-column scales, broadcastable to [out_features].
        bias: optional [out_features] bias added after dequantization.
        output_dtype: dtype of the returned tensor.

    Returns:
        [..., out_features] dequantized result in `output_dtype`.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"

    # Flatten all leading dims of A into a single token dimension.
    in_features = A.shape[-1]
    num_tokens = A.numel() // in_features
    out_features = B.shape[0]
    out_shape = A.shape[:-1] + (out_features,)
    A = A.reshape(num_tokens, in_features)

    # Dequantize: per-token scale on rows, per-column scale on outputs.
    C = torch.matmul(A, B.t())  # [num_tokens, out_features]
    C = As * C * Bs.view(1, -1)

    if bias is not None:
        C.add_(bias.view(1, -1))

    return C.reshape(out_shape).to(output_dtype)
|
||||
|
||||
|
||||
def torch_naive_moe(a, w1, w2, b, routed_scaling_factor):
    """Dense float reference for a single gated expert MLP.

    Computes w2 @ SiluAndMul(w1 @ a) and adds the shared output `b` scaled
    by `routed_scaling_factor`.
    """
    gate_up = torch.matmul(a, w1.transpose(0, 1))
    activated = SiluAndMul(gate_up)
    expert_out = torch.matmul(activated, w2.transpose(0, 1))
    return expert_out + b * routed_scaling_factor
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor):
    """int8 reference MLP (per-token activations, per-column weights).

    Quantizes the activations on the fly, runs two quantized matmuls with a
    SiluAndMul in between, and adds the shared output `b` scaled by
    `routed_scaling_factor`.
    """
    # Quantize the input activations per token.
    a_q, a_s = per_token_quant_int8(a)

    hidden = native_w8a8_per_token_matmul(
        a_q, w1_q, a_s, w1_s, bias=None, output_dtype=torch.float32
    )
    hidden = SiluAndMul(hidden)

    # Re-quantize the intermediate activations before the second matmul.
    h_q, h_s = per_token_quant_int8(hidden)
    out = native_w8a8_per_token_matmul(
        h_q, w2_q, h_s, w2_s, bias=None, output_dtype=torch.float32
    )

    return out + b * routed_scaling_factor
|
||||
|
||||
|
||||
def scaled_weight(weight, scales):
    """Dequantize block-quantized expert weights.

    Args:
        weight: [E, N, K] quantized expert weights.
        scales: one scale per (BLOCK_N, BLOCK_K) tile, viewable as
            [E, ceil(N/BLOCK_N), ceil(K/BLOCK_K)].

    Returns:
        [E, N, K] float tensor equal to weight * its per-block scale.
    """
    E, N, K = weight.shape
    # Pad N/K up to multiples of the tile size so the block view is exact.
    pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
    pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K

    if pad_N > 0 or pad_K > 0:
        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))

    # [E, nN, BLOCK_N, nK, BLOCK_K] -> [E, nN, nK, BLOCK_N, BLOCK_K]
    weight_block = (
        weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
        .permute(0, 1, 3, 2, 4)
        .float()
        .contiguous()
    )

    # Broadcast one scale over each tile, then restore the [E, N, K] layout.
    weight_scaled = (
        (
            weight_block
            * scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1)
        )
        .permute(0, 1, 3, 2, 4)
        .contiguous()
    )
    if pad_N > 0 or pad_K > 0:
        # Strip the padding added above.
        weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K)
        weight_scaled = weight_scaled[..., :N, :K].contiguous()
    else:
        weight_scaled = weight_scaled.view(E, N, K)
    return weight_scaled
|
||||
|
||||
|
||||
def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
    """Pure-torch fused-MoE reference: softmax routing plus per-expert MLPs."""
    num_tokens, dim = a.shape
    # Replicate each token once per selected expert.
    a_rep = a.view(num_tokens, -1, dim).repeat(1, topk, 1).reshape(-1, dim)
    out = torch.zeros(
        num_tokens * topk, w2.shape[1], dtype=a_rep.dtype, device=a_rep.device
    )

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(probs, topk)

    if renormalize:
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

    flat_w = topk_weight.view(-1)
    flat_ids = topk_ids.view(-1)
    for expert in range(w1.shape[0]):
        sel = flat_ids == expert
        if sel.sum():
            gate_up = a_rep[sel] @ w1[expert].transpose(0, 1)
            out[sel] = SiluAndMul(gate_up) @ w2[expert].transpose(0, 1)

    # Apply routing weights per (token, expert) slot and sum over experts.
    weighted = out.view(num_tokens, -1, w2.shape[1]) * flat_w.view(
        num_tokens, -1, 1
    ).to(out.dtype)
    return weighted.sum(dim=1)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk):
    """This function performs fused moe with per-column int8 quantization using native torch.

    Args:
        a: [B, D] input activations.
        w1, w2: [E, ...] quantized expert weights for the two MLP layers.
        w1_s, w2_s: matching per-column weight scales, indexed by expert.
        topk_weight, topk_ids: precomputed routing weights and expert ids,
            viewable as [B, topk].
        topk: number of experts selected per token.

    Returns:
        [B, out_features] combined expert outputs in a's dtype.
    """
    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)

    # Calculate routing
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(
                a_q[mask],
                w1[i],
                a_s[mask],
                w1_s[i],
                bias=None,
                output_dtype=torch.float32,
            )
            # Activation function
            act_out = SiluAndMul(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = per_token_quant_int8(act_out)
            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q,
                w2[i],
                act_out_s,
                w2_s[i],
                bias=None,
                output_dtype=torch.float32,
            )
    # Apply routing weights and sum
    return (
        (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
        .sum(dim=1)
        .to(a.dtype)
    )
|
||||
|
||||
|
||||
def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
    """Float reference for the fp8 fused MoE, using precomputed routing.

    Expert weights are expected already dequantized to float; the result is
    returned in float32.
    """
    num_tokens, dim = a.shape
    # Replicate each token once per selected expert and promote to float32.
    a_rep = a.view(num_tokens, -1, dim).repeat(1, topk, 1).reshape(-1, dim).float()
    out = torch.zeros(
        num_tokens * topk, w2.shape[1], dtype=torch.float32, device=a_rep.device
    )

    # Flatten routing to one entry per (token, expert) slot.
    flat_w = topk_weight.view(-1)
    flat_ids = topk_ids.view(-1)

    for expert in range(w1.shape[0]):
        sel = flat_ids == expert
        if sel.sum():
            gate_up = torch.matmul(a_rep[sel], w1[expert].transpose(0, 1))
            out[sel] = torch.matmul(SiluAndMul(gate_up), w2[expert].transpose(0, 1))

    # Apply routing weights and sum over the selected experts.
    weighted = out.view(num_tokens, -1, w2.shape[1]) * flat_w.view(
        num_tokens, -1, 1
    ).to(out.dtype)
    return weighted.sum(dim=1).to(a_rep.dtype)
|
||||
|
||||
|
||||
def make_non_contiguous(x: torch.Tensor) -> torch.Tensor:
    """Return a view of x that is non-contiguous, by halving its last dimension.

    A tensor that is already non-contiguous is returned unchanged.
    """
    if not x.is_contiguous():
        return x
    return x[..., : x.shape[-1] // 2]
|
||||
Reference in New Issue
Block a user