diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 1b5514920..eca92b151 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -755,6 +755,9 @@ def invoke_fused_moe_kernel(
         from sglang.srt.layers.quantization.fp8_kernel import (
             sglang_per_token_group_quant_fp8,
         )
+        from sglang.srt.layers.quantization.int8_kernel import (
+            sglang_per_token_group_quant_int8,
+        )
     else:
         from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 
@@ -794,7 +797,10 @@ def invoke_fused_moe_kernel(
             # activation block-wise int8 quantization
             assert len(block_shape) == 2
             block_n, block_k = block_shape[0], block_shape[1]
-            A, A_scale = per_token_group_quant_int8(A, block_k)
+            if _is_cuda:
+                A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
+            else:
+                A, A_scale = per_token_group_quant_int8(A, block_k)
             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py
index 79b03f64a..22f1c5069 100644
--- a/python/sglang/srt/layers/quantization/int8_kernel.py
+++ b/python/sglang/srt/layers/quantization/int8_kernel.py
@@ -8,7 +8,11 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name
+from sglang.srt.utils import get_device_name, is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import sgl_per_token_group_quant_int8
 
 logger = logging.getLogger(__name__)
 
@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
     return x_q, x_s
 
 
+def sglang_per_token_group_quant_int8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.int8,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` must be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    iinfo = torch.iinfo(dtype)
+    int8_max = iinfo.max
+    int8_min = iinfo.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_int8_matmul(
     # Pointers to inputs and output
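
Below is a minimal sketch (not part of the patch) of how the new CUDA path could be sanity-checked against a pure-PyTorch reference. `ref_per_token_group_quant_int8` is a hypothetical helper written for this note; only `sglang_per_token_group_quant_int8` and the Triton-based `per_token_group_quant_int8` come from the patch. The reference assumes symmetric per-group quantization with scale = absmax / 127, consistent with the [int8_min, int8_max] bounds the wrapper passes to `sgl_per_token_group_quant_int8`.

    import torch

    from sglang.srt.layers.quantization.int8_kernel import (
        sglang_per_token_group_quant_int8,
    )


    def ref_per_token_group_quant_int8(x, group_size, eps=1e-10):
        # Hypothetical pure-PyTorch reference: each contiguous group of
        # `group_size` elements along the last dim shares one fp32 scale.
        g = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
        absmax = g.abs().amax(dim=-1, keepdim=True).to(torch.float32).clamp_min(eps)
        scale = absmax / 127.0
        q = (g.to(torch.float32) / scale).round().clamp(-128, 127).to(torch.int8)
        return q.reshape_as(x), scale.squeeze(-1)


    x = torch.randn(4, 256, device="cuda", dtype=torch.float16)
    x_q, x_s = sglang_per_token_group_quant_int8(x, group_size=128)
    x_q_ref, x_s_ref = ref_per_token_group_quant_int8(x, group_size=128)

    torch.testing.assert_close(x_s, x_s_ref, rtol=1e-3, atol=1e-5)
    # Allow 1 ULP of disagreement on the quantized values; the CUDA
    # kernel's rounding mode may differ from torch.round at exact .5
    # boundaries.
    assert (x_q.to(torch.int16) - x_q_ref.to(torch.int16)).abs().max() <= 1

Note that `x_s` has shape `x.shape[:-1] + (x.shape[-1] // group_size,)`, matching the allocation in the new wrapper, so the reference's squeezed scale tensor compares directly against it.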