Add dsv3 router gemm kernel (#7627)
This commit is contained in:
@@ -34,6 +34,7 @@ from sgl_kernel.gemm import (
|
||||
bmm_fp8,
|
||||
cutlass_scaled_fp4_mm,
|
||||
dsv3_fused_a_gemm,
|
||||
dsv3_router_gemm,
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
int8_scaled_mm,
|
||||
|
||||
@@ -259,6 +259,24 @@ def qserve_w4a8_per_group_gemm(
|
||||
return out_feats
|
||||
|
||||
|
||||
def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
) -> torch.Tensor:
    """Compute DeepSeek-V3 router logits with the fused sgl_kernel GEMM op.

    Allocates a float32 buffer of shape
    ``(hidden_states.shape[0], router_weights.shape[0])`` on the same device
    as ``hidden_states`` and lets ``torch.ops.sgl_kernel.dsv3_router_gemm``
    fill it in place.

    Args:
        hidden_states: 2-D activation tensor; rows are tokens
            (assumed ``(num_tokens, hidden_dim)`` — TODO confirm against the
            kernel's expected layout).
        router_weights: router weight matrix whose leading dimension sets the
            number of output columns (presumably ``num_experts`` — verify
            against the caller).

    Returns:
        A float32 tensor of shape ``(num_tokens, num_experts)`` holding the
        router logits.
    """
    num_tokens = hidden_states.shape[0]
    num_experts = router_weights.shape[0]
    # The kernel writes into a pre-allocated output; dtype is always fp32
    # here regardless of the inputs' dtype.
    logits = torch.empty(
        (num_tokens, num_experts),
        dtype=torch.float32,
        device=hidden_states.device,
    )
    torch.ops.sgl_kernel.dsv3_router_gemm(logits, hidden_states, router_weights)
    return logits
|
||||
|
||||
|
||||
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
|
||||
output_tensor = torch.empty(
|
||||
output_tensor_shape,
|
||||
|
||||
Reference in New Issue
Block a user