[sgl-kernel][1/N]Support Expert Specialization Grouped GEMM (#11432)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com> Co-authored-by: PGFLMG <1106310035@qq.com> Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
This commit is contained in:
@@ -244,6 +244,7 @@ from sgl_kernel.elementwise import (
|
||||
rmsnorm,
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm
|
||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||
from sgl_kernel.gemm import (
|
||||
awq_dequantize,
|
||||
|
||||
27
sgl-kernel/python/sgl_kernel/expert_specilization.py
Normal file
27
sgl-kernel/python/sgl_kernel/expert_specilization.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
|
||||
|
||||
def es_fp8_blockwise_scaled_grouped_mm(
|
||||
output,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_d,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
):
|
||||
torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default(
|
||||
output,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_d,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
)
|
||||
Reference in New Issue
Block a user