Add dsv3 router gemm kernel (#7627)

This commit is contained in:
Baizhou Zhang
2025-06-29 23:31:55 -07:00
committed by GitHub
parent 22352d47a9
commit 7248272ccc
8 changed files with 398 additions and 0 deletions

View File

@@ -34,6 +34,7 @@ from sgl_kernel.gemm import (
bmm_fp8,
cutlass_scaled_fp4_mm,
dsv3_fused_a_gemm,
dsv3_router_gemm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
int8_scaled_mm,

View File

@@ -259,6 +259,24 @@ def qserve_w4a8_per_group_gemm(
return out_feats
def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
) -> torch.Tensor:
    """Run the ``sgl_kernel`` DSv3 router GEMM op and return its result.

    A float32 buffer of shape
    ``(hidden_states.shape[0], router_weights.shape[0])`` is allocated on
    ``hidden_states.device`` and handed to
    ``torch.ops.sgl_kernel.dsv3_router_gemm`` together with both inputs.

    Args:
        hidden_states: Input tensor; its leading dimension becomes the
            first dimension of the result.
        router_weights: Weight tensor; its leading dimension becomes the
            second dimension of the result.

    Returns:
        The float32 buffer after the custom op has run.

    NOTE(review): the buffer is created uninitialized, so its contents are
    whatever the kernel writes. Presumably that is the product of
    ``hidden_states`` with the transpose of ``router_weights``; any
    shape/dtype constraints are enforced by the kernel, not here --
    confirm against the op's implementation.
    """
    rows = hidden_states.shape[0]
    cols = router_weights.shape[0]

    # Uninitialized on purpose: the custom op receives this as its
    # destination argument.
    result = torch.empty(
        (rows, cols),
        dtype=torch.float32,
        device=hidden_states.device,
    )
    torch.ops.sgl_kernel.dsv3_router_gemm(result, hidden_states, router_weights)
    return result
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
output_tensor = torch.empty(
output_tensor_shape,