Support new DeepGEMM format in per token group quant (part 2: srt) (#7155)
This commit is contained in:
@@ -49,7 +49,7 @@ runtime_common = [
|
|||||||
|
|
||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"sgl-kernel==0.1.7",
|
"sgl-kernel==0.1.8.post1",
|
||||||
"flashinfer_python==0.2.6.post1",
|
"flashinfer_python==0.2.6.post1",
|
||||||
"torch==2.7.1",
|
"torch==2.7.1",
|
||||||
"torchaudio==2.7.1",
|
"torchaudio==2.7.1",
|
||||||
|
|||||||
@@ -605,7 +605,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"sgl-kernel",
|
"sgl-kernel",
|
||||||
"0.1.7",
|
"0.1.8.post1",
|
||||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -280,6 +280,7 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
eps: float = 1e-10,
|
eps: float = 1e-10,
|
||||||
column_major_scales: bool = False,
|
column_major_scales: bool = False,
|
||||||
scale_tma_aligned: bool = False,
|
scale_tma_aligned: bool = False,
|
||||||
|
scale_ue8m0: bool = False,
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
x.shape[-1] % group_size == 0
|
x.shape[-1] % group_size == 0
|
||||||
@@ -287,8 +288,20 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
||||||
if column_major_scales:
|
if scale_ue8m0:
|
||||||
|
assert column_major_scales and scale_tma_aligned
|
||||||
|
x_q_mn, x_q_k = x.shape
|
||||||
|
x_s_mn, x_s_k = x_q_mn, x_q_k // 128
|
||||||
|
aligned_mn = align(x_s_mn, 4)
|
||||||
|
aligned_k = align(x_s_k, 4)
|
||||||
|
x_s = torch.empty(
|
||||||
|
(aligned_k // 4, aligned_mn),
|
||||||
|
device=x.device,
|
||||||
|
dtype=torch.int,
|
||||||
|
).permute(-1, -2)[:x_s_mn, :]
|
||||||
|
elif column_major_scales:
|
||||||
if scale_tma_aligned:
|
if scale_tma_aligned:
|
||||||
|
# TODO extract "align" function
|
||||||
# aligned to 4 * sizeof(float)
|
# aligned to 4 * sizeof(float)
|
||||||
aligned_size = (x.shape[-2] + 3) // 4 * 4
|
aligned_size = (x.shape[-2] + 3) // 4 * 4
|
||||||
x_s = torch.empty(
|
x_s = torch.empty(
|
||||||
@@ -309,7 +322,9 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
if x.shape[0] > 0:
|
if x.shape[0] > 0:
|
||||||
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
sgl_per_token_group_quant_fp8(
|
||||||
|
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||||
|
)
|
||||||
|
|
||||||
return x_q, x_s
|
return x_q, x_s
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user