Add bf16 output option for dsv3_router_gemm kernel (#7999)
This commit is contained in:
@@ -262,12 +262,13 @@ def qserve_w4a8_per_group_gemm(
|
||||
def dsv3_router_gemm(
|
||||
hidden_states: torch.Tensor,
|
||||
router_weights: torch.Tensor,
|
||||
out_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
router_weights.shape[0],
|
||||
device=hidden_states.device,
|
||||
dtype=torch.float32,
|
||||
dtype=out_dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.dsv3_router_gemm(
|
||||
output,
|
||||
|
||||
Reference in New Issue
Block a user