Support ue8m0 for triton quant kernel (#7603)
This commit is contained in:
@@ -173,6 +173,7 @@ def _per_token_group_quant_fp8_colmajor(
|
|||||||
fp8_max,
|
fp8_max,
|
||||||
# Meta-parameters
|
# Meta-parameters
|
||||||
BLOCK: tl.constexpr,
|
BLOCK: tl.constexpr,
|
||||||
|
SCALE_UE8M0: tl.constexpr,
|
||||||
):
|
):
|
||||||
"""A Triton-accelerated function to perform per-token-group
|
"""A Triton-accelerated function to perform per-token-group
|
||||||
quantization on a tensor.
|
quantization on a tensor.
|
||||||
@@ -197,6 +198,8 @@ def _per_token_group_quant_fp8_colmajor(
|
|||||||
# Quant
|
# Quant
|
||||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
||||||
y_s = _absmax / fp8_max
|
y_s = _absmax / fp8_max
|
||||||
|
if SCALE_UE8M0:
|
||||||
|
y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
|
||||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||||
|
|
||||||
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
||||||
@@ -209,6 +212,7 @@ def per_token_group_quant_fp8(
|
|||||||
eps: float = 1e-10,
|
eps: float = 1e-10,
|
||||||
column_major_scales: bool = False,
|
column_major_scales: bool = False,
|
||||||
scale_tma_aligned: bool = False,
|
scale_tma_aligned: bool = False,
|
||||||
|
scale_ue8m0: bool = False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||||
|
|
||||||
@@ -229,29 +233,17 @@ def per_token_group_quant_fp8(
|
|||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
||||||
|
x_s = create_per_token_group_quant_fp8_output_scale(
|
||||||
|
x_shape=x.shape,
|
||||||
|
device=x.device,
|
||||||
|
group_size=group_size,
|
||||||
|
column_major_scales=column_major_scales,
|
||||||
|
scale_tma_aligned=scale_tma_aligned,
|
||||||
|
scale_ue8m0=False,
|
||||||
|
)
|
||||||
|
|
||||||
M = x.numel() // group_size
|
M = x.numel() // group_size
|
||||||
N = group_size
|
N = group_size
|
||||||
if column_major_scales:
|
|
||||||
if scale_tma_aligned:
|
|
||||||
# aligned to 4 * sizeof(float)
|
|
||||||
aligned_size = (x.shape[-2] + 3) // 4 * 4
|
|
||||||
x_s = torch.empty(
|
|
||||||
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
).permute(-1, -2)[: x.shape[-2], :]
|
|
||||||
else:
|
|
||||||
x_s = torch.empty(
|
|
||||||
(x.shape[-1] // group_size,) + x.shape[:-1],
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
).permute(-1, -2)
|
|
||||||
else:
|
|
||||||
x_s = torch.empty(
|
|
||||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
BLOCK = triton.next_power_of_2(N)
|
BLOCK = triton.next_power_of_2(N)
|
||||||
# heuristics for number of warps
|
# heuristics for number of warps
|
||||||
@@ -271,8 +263,10 @@ def per_token_group_quant_fp8(
|
|||||||
BLOCK=BLOCK,
|
BLOCK=BLOCK,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
|
SCALE_UE8M0=scale_ue8m0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
assert not scale_ue8m0
|
||||||
_per_token_group_quant_fp8[(M,)](
|
_per_token_group_quant_fp8[(M,)](
|
||||||
x,
|
x,
|
||||||
x_q,
|
x_q,
|
||||||
@@ -287,9 +281,66 @@ def per_token_group_quant_fp8(
|
|||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if scale_ue8m0:
|
||||||
|
from deep_gemm.utils.layout import transform_sf_into_required_layout
|
||||||
|
|
||||||
|
assert group_size == 128
|
||||||
|
x_s = transform_sf_into_required_layout(
|
||||||
|
x_s,
|
||||||
|
num_groups=None,
|
||||||
|
mn=x_q.shape[0],
|
||||||
|
k=x_q.shape[1],
|
||||||
|
recipe=(1, group_size, group_size),
|
||||||
|
is_sfa=True,
|
||||||
|
)
|
||||||
|
|
||||||
return x_q, x_s
|
return x_q, x_s
|
||||||
|
|
||||||
|
|
||||||
|
def create_per_token_group_quant_fp8_output_scale(
|
||||||
|
x_shape,
|
||||||
|
device,
|
||||||
|
group_size,
|
||||||
|
column_major_scales: bool,
|
||||||
|
scale_tma_aligned: bool,
|
||||||
|
scale_ue8m0: bool,
|
||||||
|
):
|
||||||
|
if scale_ue8m0:
|
||||||
|
assert column_major_scales and scale_tma_aligned
|
||||||
|
x_q_mn, x_q_k = x_shape
|
||||||
|
x_s_mn, x_s_k = x_q_mn, x_q_k // 128
|
||||||
|
aligned_mn = align(x_s_mn, 4)
|
||||||
|
aligned_k = align(x_s_k, 4)
|
||||||
|
# TODO(FIXME): Fix cuda kernel and recover here to empty.
|
||||||
|
return torch.zeros(
|
||||||
|
(aligned_k // 4, aligned_mn),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int,
|
||||||
|
).transpose(0, 1)[:x_s_mn, :]
|
||||||
|
elif column_major_scales:
|
||||||
|
if scale_tma_aligned:
|
||||||
|
# TODO extract "align" function
|
||||||
|
# aligned to 4 * sizeof(float)
|
||||||
|
aligned_size = (x_shape[-2] + 3) // 4 * 4
|
||||||
|
return torch.empty(
|
||||||
|
x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
).permute(-1, -2)[: x_shape[-2], :]
|
||||||
|
else:
|
||||||
|
return torch.empty(
|
||||||
|
(x_shape[-1] // group_size,) + x_shape[:-1],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
).permute(-1, -2)
|
||||||
|
else:
|
||||||
|
return torch.empty(
|
||||||
|
x_shape[:-1] + (x_shape[-1] // group_size,),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sglang_per_token_group_quant_fp8(
|
def sglang_per_token_group_quant_fp8(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
@@ -303,41 +354,20 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
|
||||||
if scale_ue8m0:
|
if scale_ue8m0:
|
||||||
assert column_major_scales and scale_tma_aligned
|
# TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
|
||||||
x_q_mn, x_q_k = x.shape
|
assert x.shape[-1] % (group_size * 4) == 0
|
||||||
x_s_mn, x_s_k = x_q_mn, x_q_k // 128
|
|
||||||
aligned_mn = align(x_s_mn, 4)
|
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
||||||
aligned_k = align(x_s_k, 4)
|
x_s = create_per_token_group_quant_fp8_output_scale(
|
||||||
# TODO(FIXME): Fix cuda kernel and recover here to empty.
|
x_shape=x.shape,
|
||||||
x_s = torch.zeros(
|
device=x.device,
|
||||||
(aligned_k // 4, aligned_mn),
|
group_size=group_size,
|
||||||
device=x.device,
|
column_major_scales=column_major_scales,
|
||||||
dtype=torch.int,
|
scale_tma_aligned=scale_tma_aligned,
|
||||||
).transpose(0, 1)[:x_s_mn, :]
|
scale_ue8m0=scale_ue8m0,
|
||||||
elif column_major_scales:
|
)
|
||||||
if scale_tma_aligned:
|
|
||||||
# TODO extract "align" function
|
|
||||||
# aligned to 4 * sizeof(float)
|
|
||||||
aligned_size = (x.shape[-2] + 3) // 4 * 4
|
|
||||||
x_s = torch.empty(
|
|
||||||
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
).permute(-1, -2)[: x.shape[-2], :]
|
|
||||||
else:
|
|
||||||
x_s = torch.empty(
|
|
||||||
(x.shape[-1] // group_size,) + x.shape[:-1],
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
).permute(-1, -2)
|
|
||||||
else:
|
|
||||||
x_s = torch.empty(
|
|
||||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
if x.shape[0] > 0:
|
if x.shape[0] > 0:
|
||||||
sgl_per_token_group_quant_fp8(
|
sgl_per_token_group_quant_fp8(
|
||||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||||
|
|||||||
Reference in New Issue
Block a user