use sgl_per_token_group_quant_fp8 kernel (#3493)
This commit is contained in:
@@ -27,6 +27,10 @@ from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
|
||||
is_hip_ = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
|
||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
if _is_cuda:
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -135,6 +139,36 @@ def per_token_group_quant_fp8(
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def sglang_per_token_group_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype = fp8_type_,
|
||||
):
|
||||
assert (
|
||||
x.shape[-1] % group_size == 0
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_max = finfo.max
|
||||
|
||||
fp8_min = -fp8_max
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
x_s = torch.empty(
|
||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _w8a8_block_fp8_matmul(
|
||||
# Pointers to inputs and output
|
||||
|
||||
Reference in New Issue
Block a user