[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)

Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
This commit is contained in:
Chunyuan WU
2025-08-05 09:28:31 +08:00
committed by GitHub
parent d4bf5a8524
commit 08f8f49016
4 changed files with 94 additions and 28 deletions

View File

@@ -66,13 +66,15 @@ class TestGroupedTopK(CustomTestCase):
# DeepSeek V2/V3/R1 uses biased_grouped_topk
class TestBiasedGroupedTopK(CustomTestCase):
def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
def _run_single_test(
self, M, E, G, topk, topk_group, renormalize, dtype, bias_dtype
):
torch.manual_seed(1234)
# expand gating_output by M, otherwise bfloat16 values collapse to the same value after truncation
hidden_states = torch.randn(M, 100, dtype=dtype)
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
correction_bias = torch.randn(E, dtype=dtype)
correction_bias = torch.randn(E, dtype=bias_dtype)
ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
hidden_states.float(),
@@ -106,7 +108,10 @@ class TestBiasedGroupedTopK(CustomTestCase):
def test_biased_grouped_topk(self):
for renormalize in [True, False]:
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
for bias_dtype in [torch.float32, torch.bfloat16]:
self._run_single_test(
122, 256, 8, 8, 2, renormalize, torch.bfloat16, bias_dtype
)
class TestTopK(CustomTestCase):