Support ue8m0 for triton quant kernel (#7603)
@@ -173,6 +173,7 @@ def _per_token_group_quant_fp8_colmajor(
     fp8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
+    SCALE_UE8M0: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group
     quantization on a tensor.
@@ -197,6 +198,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
     y_s = _absmax / fp8_max
+    if SCALE_UE8M0:
+        y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
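The SCALE_UE8M0 branch rounds each group's scale up to the next power of two, so only an 8-bit exponent (UE8M0) needs to be stored later; because the scale can only grow, y / y_s still fits inside [fp8_min, fp8_max]. A minimal PyTorch sketch of the same rounding rule (outside Triton, for illustration only):

import torch

def round_scale_ue8m0(y_s: torch.Tensor) -> torch.Tensor:
    # Round each scale up to the nearest power of two, 2 ** ceil(log2(s)),
    # mirroring `tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))` in the kernel.
    return torch.exp2(torch.ceil(torch.log2(y_s.abs())))

s = torch.tensor([0.0178, 0.25, 0.3])
print(round_scale_ue8m0(s))  # tensor([0.0312, 0.2500, 0.5000])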
@@ -209,6 +212,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
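A hypothetical call through the updated signature (the argument names are the ones shown in this diff; the bf16 input, shapes, and device are illustrative). UE8M0 scales are only wired into the column-major kernel path, and the deep_gemm re-layout further down asserts group_size == 128:

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(
    x,
    group_size=128,               # UE8M0 path asserts group_size == 128
    column_major_scales=True,     # the row-major branch asserts `not scale_ue8m0`
    scale_tma_aligned=True,
    scale_ue8m0=True,             # round per-group scales to powers of two
)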
@@ -229,29 +233,17 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=False,
+    )
+
     M = x.numel() // group_size
     N = group_size
-    if column_major_scales:
-        if scale_tma_aligned:
-            # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
-        else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)
-    else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
-            dtype=torch.float32,
-        )
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
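For orientation, a sketch of the fp32 scale shapes produced by the three branches that this hunk consolidates into create_per_token_group_quant_fp8_output_scale; the concrete M and K values are illustrative:

M, K, group_size = 7, 256, 128
num_groups = K // group_size            # 2 scale groups per token

# Row-major scales: one fp32 scale per (token, group).
row_major_shape = (M, num_groups)       # (7, 2)

# Column-major scales: allocate (num_groups, M) and view it transposed,
# so the scales of one group are contiguous across tokens.
col_major_storage = (num_groups, M)     # (2, 7), viewed as (7, 2)

# TMA-aligned variant: pad the token dimension up to a multiple of 4
# (4 * sizeof(float) = 16 bytes) before transposing, then slice back to M rows.
aligned_M = (M + 3) // 4 * 4            # 8
tma_storage = (num_groups, aligned_M)   # (2, 8), viewed and sliced to (7, 2)
print(row_major_shape, col_major_storage, tma_storage)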
@@ -271,8 +263,10 @@ def per_token_group_quant_fp8(
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
+            SCALE_UE8M0=scale_ue8m0,
         )
     else:
+        assert not scale_ue8m0
         _per_token_group_quant_fp8[(M,)](
             x,
             x_q,
@@ -287,9 +281,66 @@ def per_token_group_quant_fp8(
             num_stages=num_stages,
         )
 
+    if scale_ue8m0:
+        from deep_gemm.utils.layout import transform_sf_into_required_layout
+
+        assert group_size == 128
+        x_s = transform_sf_into_required_layout(
+            x_s,
+            num_groups=None,
+            mn=x_q.shape[0],
+            k=x_q.shape[1],
+            recipe=(1, group_size, group_size),
+            is_sfa=True,
+        )
+
     return x_q, x_s
 
 
+def create_per_token_group_quant_fp8_output_scale(
+    x_shape,
+    device,
+    group_size,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
+):
+    if scale_ue8m0:
+        assert column_major_scales and scale_tma_aligned
+        x_q_mn, x_q_k = x_shape
+        x_s_mn, x_s_k = x_q_mn, x_q_k // 128
+        aligned_mn = align(x_s_mn, 4)
+        aligned_k = align(x_s_k, 4)
+        # TODO(FIXME): Fix cuda kernel and recover here to empty.
+        return torch.zeros(
+            (aligned_k // 4, aligned_mn),
+            device=device,
+            dtype=torch.int,
+        ).transpose(0, 1)[:x_s_mn, :]
+    elif column_major_scales:
+        if scale_tma_aligned:
+            # TODO extract "align" function
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x_shape[-2] + 3) // 4 * 4
+            return torch.empty(
+                x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
+                device=device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x_shape[-2], :]
+        else:
+            return torch.empty(
+                (x_shape[-1] // group_size,) + x_shape[:-1],
+                device=device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        return torch.empty(
+            x_shape[:-1] + (x_shape[-1] // group_size,),
+            device=device,
+            dtype=torch.float32,
+        )
+
+
 def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
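In the scale_ue8m0 branch the new helper does not allocate fp32 scales at all: it returns an int32 buffer sized to hold one UE8M0 exponent byte per (token, group), with both dimensions padded to multiples of 4, presumably so four exponents pack into each int32 along k. A sketch of the resulting shapes, assuming align(v, m) rounds v up to the next multiple of m (align itself is defined elsewhere in this file):

def align(v: int, m: int) -> int:
    # Assumed behaviour of the `align` helper used above: round up to a multiple of m.
    return (v + m - 1) // m * m

M, K = 7, 384
x_s_mn, x_s_k = M, K // 128                   # logical scale shape: (7, 3)
aligned_mn = align(x_s_mn, 4)                 # 8 rows after padding
aligned_k = align(x_s_k, 4)                   # 4 exponent bytes -> 1 int32 per row
# Physical buffer: one int32 holds four packed UE8M0 exponents along k.
packed_shape = (aligned_k // 4, aligned_mn)   # (1, 8), then transposed and
                                              # sliced back to (7, 1) int32 rows
print(packed_shape)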
@@ -303,41 +354,20 @@ def sglang_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
-    if scale_ue8m0:
-        assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k = x.shape
-        x_s_mn, x_s_k = x_q_mn, x_q_k // 128
-        aligned_mn = align(x_s_mn, 4)
-        aligned_k = align(x_s_k, 4)
-        # TODO(FIXME): Fix cuda kernel and recover here to empty.
-        x_s = torch.zeros(
-            (aligned_k // 4, aligned_mn),
-            device=x.device,
-            dtype=torch.int,
-        ).transpose(0, 1)[:x_s_mn, :]
-    elif column_major_scales:
-        if scale_tma_aligned:
-            # TODO extract "align" function
-            # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
-        else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)
-    else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
-            dtype=torch.float32,
-        )
+    # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
+    assert x.shape[-1] % (group_size * 4) == 0
+
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=scale_ue8m0,
+    )
 
     if x.shape[0] > 0:
         sgl_per_token_group_quant_fp8(
             x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
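Finally, a hedged reference implementation that can serve as a ground truth when testing either quantization path; the function name, the e4m3 dtype choice, and the row-major fp32 scale output are assumptions for illustration, not the repository's actual test code:

import torch

def ref_per_token_group_quant_fp8(x, group_size=128, scale_ue8m0=False, eps=1e-10):
    # Assumed fp8 format; the kernels receive fp8_min/fp8_max as arguments instead.
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    fp8_min = -fp8_max
    M, K = x.shape
    g = x.float().view(M, K // group_size, group_size)
    # Per-group scale: absmax (clamped to eps) divided by the fp8 range.
    s = g.abs().amax(dim=-1).clamp_min(eps) / fp8_max
    if scale_ue8m0:
        s = torch.exp2(torch.ceil(torch.log2(s)))   # power-of-two (UE8M0) scales
    q = (g / s.unsqueeze(-1)).clamp(fp8_min, fp8_max).to(fp8)
    return q.view(M, K), s   # row-major scales of shape (M, K // group_size)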