[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)
Co-authored-by: jianan-gu <jianan.gu@intel.com> Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
This commit is contained in:
@@ -66,13 +66,15 @@ class TestGroupedTopK(CustomTestCase):
|
||||
|
||||
# DeepSeek V2/V3/R1 uses biased_grouped_top
|
||||
class TestBiasedGroupedTopK(CustomTestCase):
|
||||
def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
|
||||
def _run_single_test(
|
||||
self, M, E, G, topk, topk_group, renormalize, dtype, bias_dtype
|
||||
):
|
||||
torch.manual_seed(1234)
|
||||
|
||||
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||
correction_bias = torch.randn(E, dtype=dtype)
|
||||
correction_bias = torch.randn(E, dtype=bias_dtype)
|
||||
|
||||
ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
|
||||
hidden_states.float(),
|
||||
@@ -106,7 +108,10 @@ class TestBiasedGroupedTopK(CustomTestCase):
|
||||
|
||||
def test_biased_grouped_topk(self):
|
||||
for renormalize in [True, False]:
|
||||
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
|
||||
for bias_dtype in [torch.float32, torch.bfloat16]:
|
||||
self._run_single_test(
|
||||
122, 256, 8, 8, 2, renormalize, torch.bfloat16, bias_dtype
|
||||
)
|
||||
|
||||
|
||||
class TestTopK(CustomTestCase):
|
||||
|
||||
Reference in New Issue
Block a user