Add bf16 output option for dsv3_router_gemm kernel (#7999)

This commit is contained in:
Baizhou Zhang
2025-07-19 18:49:37 -07:00
committed by GitHub
parent 4540a4666a
commit 282eb59ff3
7 changed files with 465 additions and 104 deletions

View File

@@ -262,12 +262,13 @@ def qserve_w4a8_per_group_gemm(
 def dsv3_router_gemm(
     hidden_states: torch.Tensor,
     router_weights: torch.Tensor,
+    out_dtype: torch.dtype = torch.bfloat16,
 ) -> torch.Tensor:
     output = torch.empty(
         hidden_states.shape[0],
         router_weights.shape[0],
         device=hidden_states.device,
-        dtype=torch.float32,
+        dtype=out_dtype,
     )
     torch.ops.sgl_kernel.dsv3_router_gemm(
         output,