Add dsv3 router gemm kernel (#7627)

This commit is contained in:
Baizhou Zhang
2025-06-29 23:31:55 -07:00
committed by GitHub
parent 22352d47a9
commit 7248272ccc
8 changed files with 398 additions and 0 deletions

View File

@@ -259,6 +259,24 @@ def qserve_w4a8_per_group_gemm(
return out_feats
def dsv3_router_gemm(
hidden_states: torch.Tensor,
router_weights: torch.Tensor,
) -> torch.Tensor:
output = torch.empty(
hidden_states.shape[0],
router_weights.shape[0],
device=hidden_states.device,
dtype=torch.float32,
)
torch.ops.sgl_kernel.dsv3_router_gemm(
output,
hidden_states,
router_weights,
)
return output
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
output_tensor = torch.empty(
output_tensor_shape,