Use sgl-kernel sgl_per_token_group_quant_int8 (#4971)
This commit is contained in:
@@ -755,6 +755,9 @@ def invoke_fused_moe_kernel(
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import (
|
||||
sglang_per_token_group_quant_int8,
|
||||
)
|
||||
else:
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
|
||||
@@ -794,7 +797,10 @@ def invoke_fused_moe_kernel(
|
||||
# activation block-wise int8 quantization
|
||||
assert len(block_shape) == 2
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_int8(A, block_k)
|
||||
if _is_cuda:
|
||||
A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
|
||||
else:
|
||||
A, A_scale = per_token_group_quant_int8(A, block_k)
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
|
||||
@@ -8,7 +8,11 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.utils import get_device_name
|
||||
from sglang.srt.utils import get_device_name, is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import sgl_per_token_group_quant_int8
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def sglang_per_token_group_quant_int8(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype = torch.int8,
|
||||
):
|
||||
assert (
|
||||
x.shape[-1] % group_size == 0
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
iinfo = torch.iinfo(dtype)
|
||||
int8_max = iinfo.max
|
||||
int8_min = iinfo.min
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
x_s = torch.empty(
|
||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _w8a8_block_int8_matmul(
|
||||
# Pointers to inputs and output
|
||||
|
||||
Reference in New Issue
Block a user