99 lines
3.5 KiB
Python
99 lines
3.5 KiB
Python
import itertools
|
|
import unittest
|
|
|
|
import sgl_kernel
|
|
import torch
|
|
from utils import precision
|
|
|
|
from sglang.srt.layers.moe.topk import (
|
|
biased_grouped_topk_impl as native_biased_grouped_topk,
|
|
)
|
|
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
|
|
from sglang.test.test_utils import CustomTestCase
|
|
|
|
|
|
# This is used by the Deepseek-V2 model
|
|
class TestGroupedTopK(CustomTestCase):
    """Check the fused CPU ``grouped_topk_cpu`` op against the native reference."""

    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        """Run one configuration and compare fused and reference results.

        Args:
            M: Number of rows in the random ``hidden_states`` and
                ``gating_output`` tensors.
            E: Number of columns in ``gating_output`` (presumably the number
                of experts -- confirm against the kernel).
            G: Forwarded positionally after ``renormalize`` (presumably the
                number of expert groups -- confirm against the kernel).
            topk: Forwarded to both implementations.
            topk_group: Forwarded to both implementations.
            renormalize: Forwarded to both implementations.
            dtype: dtype of the random inputs given to the fused op; the
                reference receives the same data cast to float.
        """
        torch.manual_seed(1234)

        # Scale gating_output by M; otherwise bfloat16 values collapse to the
        # same value after truncation.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        def _to_dense(ids, weights):
            # Place every weight in the column given by its id, so the
            # comparison is independent of the order of ids within a row.
            dense = torch.zeros(M, E, dtype=torch.float)
            return dense.scatter_(1, ids.long(), weights)

        # Native reference, evaluated on float copies of the inputs.
        expected_weights, expected_ids = native_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # Fused kernel, evaluated on the inputs in their original dtype.
        fused_weights, fused_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
            hidden_states, gating_output, topk, renormalize, G, topk_group
        )

        res = _to_dense(fused_ids, fused_weights)
        ref = _to_dense(expected_ids, expected_weights)
        torch.testing.assert_close(res, ref)

    def test_grouped_topk(self):
        """Exercise a fixed set of shapes with and without renormalization."""
        # Each entry is (M, E, G, topk, topk_group).
        shapes = (
            (123, 8, 2, 2, 1),
            (123, 16, 4, 3, 2),
            (123, 32, 4, 3, 2),
            (1123, 32, 4, 3, 2),
            (123, 64, 1, 6, 1),
            (123, 256, 8, 4, 8),
            (123, 160, 8, 6, 2),
        )
        for renormalize in (True, False):
            for M, E, G, topk, topk_group in shapes:
                self._run_single_test(
                    M, E, G, topk, topk_group, renormalize, torch.bfloat16
                )
|
|
|
|
|
|
# DeepSeek V2/V3/R1 use biased_grouped_topk
|
|
class TestBiasedGroupedTopK(CustomTestCase):
    """Check the fused CPU ``biased_grouped_topk_cpu`` op against the native reference."""

    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        """Run one configuration and compare fused and reference results.

        Args:
            M: Number of rows in the random ``hidden_states`` and
                ``gating_output`` tensors.
            E: Number of columns in ``gating_output`` and length of the random
                ``correction_bias`` (presumably the number of experts --
                confirm against the kernel).
            G: Forwarded positionally after ``renormalize`` (presumably the
                number of expert groups -- confirm against the kernel).
            topk: Forwarded to both implementations.
            topk_group: Forwarded to both implementations.
            renormalize: Forwarded to both implementations.
            dtype: dtype of the random inputs given to the fused op; the
                reference receives the same data cast to float.
        """
        torch.manual_seed(1234)

        # Scale gating_output by M; otherwise bfloat16 values collapse to the
        # same value after truncation.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        correction_bias = torch.randn(E, dtype=dtype)

        def _to_dense(ids, weights):
            # Place every weight in the column given by its id, so the
            # comparison is independent of the order of ids within a row.
            dense = torch.zeros(M, E, dtype=torch.float)
            return dense.scatter_(1, ids.long(), weights)

        # Native reference, evaluated on float copies of the inputs.
        expected_weights, expected_ids = native_biased_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            correction_bias.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # Fused kernel, evaluated on the inputs in their original dtype.
        fused_weights, fused_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            G,
            topk_group,
        )

        res = _to_dense(fused_ids, fused_weights)
        ref = _to_dense(expected_ids, expected_weights)
        torch.testing.assert_close(res, ref)

    def test_biased_grouped_topk(self):
        """Exercise one fixed shape with and without renormalization."""
        for renormalize in (True, False):
            self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
|
|
|
|
|
|
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|