Fix dtype error in CI (#8197)

This commit is contained in:
Ke Bao
2025-07-21 00:27:55 +08:00
committed by GitHub
parent 750838adc4
commit 465968b2e3

View File

@@ -524,7 +524,7 @@ def biased_grouped_topk_gpu(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
aiter_biased_grouped_topk(
gating_output,
gating_output.to(dtype=torch.float32),
correction_bias,
topk_weights,
topk_ids,