diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py index d4be7ae05..33ac80f4f 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py @@ -13,8 +13,7 @@ from triton_kernels.matmul_ogs import ( PrecisionConfig, matmul_ogs, ) -from triton_kernels.numerics import InFlexData, MicroscalingCtx -from triton_kernels.quantization import downcast_to_mxfp +from triton_kernels.numerics import InFlexData from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx from triton_kernels.swiglu import swiglu_fn @@ -26,30 +25,6 @@ if TYPE_CHECKING: def quantize(w, dtype, dev, **opt): if dtype == "bf16": return w.to(torch.bfloat16), InFlexData() - elif dtype == "fp8": - wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2) - return ( - wq, - InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), - MicroscalingCtx(), - ) - else: - assert dtype == "mx4", f"{dtype=}" - swizzle_mx_scale = opt["swizzle_mx_scale"] - swizzle_axis = 2 if swizzle_mx_scale else None - w = w.to(torch.bfloat16) - w, mx_scales, weight_scale_shape = downcast_to_mxfp( - w, torch.uint8, axis=1, swizzle_axis=swizzle_axis - ) - return ( - w, - InFlexData(), - MicroscalingCtx( - weight_scale=mx_scales, - swizzle_mx=swizzle_mx_scale, - actual_weight_scale_shape=weight_scale_shape, - ), - ) def triton_kernel_moe_forward(