Fix dtype error in CI (#8197)
This commit is contained in:
@@ -524,7 +524,7 @@ def biased_grouped_topk_gpu(
|
|||||||
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
|
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
|
||||||
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
|
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||||
aiter_biased_grouped_topk(
|
aiter_biased_grouped_topk(
|
||||||
gating_output,
|
gating_output.to(dtype=torch.float32),
|
||||||
correction_bias,
|
correction_bias,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user