Use sgl-kernel sgl_per_token_group_quant_int8 (#4971)
@@ -755,6 +755,9 @@ def invoke_fused_moe_kernel(
         from sglang.srt.layers.quantization.fp8_kernel import (
             sglang_per_token_group_quant_fp8,
         )
+        from sglang.srt.layers.quantization.int8_kernel import (
+            sglang_per_token_group_quant_int8,
+        )
     else:
         from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

@@ -794,7 +797,10 @@ def invoke_fused_moe_kernel(
             # activation block-wise int8 quantization
             assert len(block_shape) == 2
             block_n, block_k = block_shape[0], block_shape[1]
-            A, A_scale = per_token_group_quant_int8(A, block_k)
+            if _is_cuda:
+                A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
+            else:
+                A, A_scale = per_token_group_quant_int8(A, block_k)
             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
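For readers unfamiliar with the block-wise activation quantization dispatched above, here is a rough eager-PyTorch sketch of what per-token-group int8 quantization computes: one int8 value per element and one float32 scale per group of block_k contiguous channels along the last dimension. It is only an illustration, not the code in this commit; the Triton kernel and the sgl_kernel CUDA op are the actual implementations, and the helper name per_token_group_quant_int8_ref is made up here.

import torch

def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # One int8 value per element, one float32 scale per `group_size`-wide group
    # along the last dimension (the layout the fused MoE kernel expects).
    assert x.shape[-1] % group_size == 0
    int8_max = torch.iinfo(torch.int8).max  # 127
    x_grouped = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    # Scale = per-group absmax / 127, clamped so all-zero groups do not divide by zero.
    scale = x_grouped.abs().amax(dim=-1, keepdim=True).float().clamp_min(eps) / int8_max
    x_q = (x_grouped / scale).round().clamp(-128, 127).to(torch.int8)
    return x_q.reshape(x.shape), scale.squeeze(-1)

# Example: quantizing the activation A with block_k = 128, as invoke_fused_moe_kernel does.
A = torch.randn(4, 512)
A_q, A_scale = per_token_group_quant_int8_ref(A, group_size=128)
print(A_q.shape, A_q.dtype, A_scale.shape)  # torch.Size([4, 512]) torch.int8 torch.Size([4, 4])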
@@ -8,7 +8,11 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name
+from sglang.srt.utils import get_device_name, is_cuda
 
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import sgl_per_token_group_quant_int8
+
 logger = logging.getLogger(__name__)
 
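As a quick sanity check, something like the following throwaway sketch (not part of this commit) can confirm whether the sgl_kernel fast path guarded by this import is actually reachable on the current machine:

# Throwaway diagnostic: report which int8 group-quantization path sglang would take.
from sglang.srt.utils import is_cuda

if is_cuda():
    try:
        from sgl_kernel import sgl_per_token_group_quant_int8  # noqa: F401
        print("CUDA platform: sgl_kernel sgl_per_token_group_quant_int8 is available")
    except ImportError:
        print("CUDA platform, but sgl_kernel is missing; install or upgrade sgl-kernel")
else:
    print("non-CUDA platform: the Triton per_token_group_quant_int8 fallback is used")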
@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
     return x_q, x_s
 
 
+def sglang_per_token_group_quant_int8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.int8,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    iinfo = torch.iinfo(dtype)
+    int8_max = iinfo.max
+    int8_min = iinfo.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_int8_matmul(
     # Pointers to inputs and output
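A hedged usage sketch of the new wrapper follows. It assumes a CUDA GPU and an sgl-kernel build that ships sgl_per_token_group_quant_int8; the shapes are illustrative, and the loose cross-check against the Triton fallback is an addition here, not part of this commit.

import torch

from sglang.srt.layers.quantization.int8_kernel import (
    per_token_group_quant_int8,         # existing Triton fallback
    sglang_per_token_group_quant_int8,  # new sgl_kernel-backed wrapper
)

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Both paths return an int8 tensor plus one float32 scale per 128-wide group.
x_q_cuda, x_s_cuda = sglang_per_token_group_quant_int8(x, 128)
x_q_triton, x_s_triton = per_token_group_quant_int8(x, 128)

assert x_q_cuda.dtype == torch.int8
assert x_s_cuda.shape == (8, 4096 // 128) and x_s_cuda.dtype == torch.float32

# Loose agreement check between the two backends; rounding near group boundaries may
# differ slightly, so compare dequantized values rather than raw int8 codes.
deq_cuda = x_q_cuda.float().view(8, -1, 128) * x_s_cuda.unsqueeze(-1)
deq_triton = x_q_triton.float().view(8, -1, 128) * x_s_triton.unsqueeze(-1)
print((deq_cuda - deq_triton).abs().max())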