use sglang_per_token_group_quant_fp8 from sgl-kernel instead of the Triton kernel (#5473)
Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
@@ -275,6 +275,8 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
eps: float = 1e-10,
|
eps: float = 1e-10,
|
||||||
|
column_major_scales: bool = False,
|
||||||
|
scale_tma_aligned: bool = False,
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
x.shape[-1] % group_size == 0
|
x.shape[-1] % group_size == 0
|
||||||
@@ -282,11 +284,28 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
||||||
x_s = torch.empty(
|
if column_major_scales:
|
||||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
if scale_tma_aligned:
|
||||||
device=x.device,
|
# aligned to 4 * sizeof(float)
|
||||||
dtype=torch.float32,
|
aligned_size = (x.shape[-2] + 3) // 4 * 4
|
||||||
)
|
x_s = torch.empty(
|
||||||
|
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
|
||||||
|
device=x.device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
).permute(-1, -2)[: x.shape[-2], :]
|
||||||
|
else:
|
||||||
|
x_s = torch.empty(
|
||||||
|
(x.shape[-1] // group_size,) + x.shape[:-1],
|
||||||
|
device=x.device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
).permute(-1, -2)
|
||||||
|
else:
|
||||||
|
x_s = torch.empty(
|
||||||
|
x.shape[:-1] + (x.shape[-1] // group_size,),
|
||||||
|
device=x.device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
||||||
|
|
||||||
return x_q, x_s
|
return x_q, x_s
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
|
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
|
||||||
else:
|
else:
|
||||||
if _enable_jit_deepgemm:
|
if _enable_jit_deepgemm:
|
||||||
q_input, x_scale = per_token_group_quant_fp8(
|
q_input, x_scale = sglang_per_token_group_quant_fp8(
|
||||||
input_2d,
|
input_2d,
|
||||||
block_size[1],
|
block_size[1],
|
||||||
column_major_scales=True,
|
column_major_scales=True,
|
||||||
|
|||||||
Reference in New Issue
Block a user