Partially unify triton per token group quant kernels (#9485)
This commit is contained in:
@@ -113,7 +113,7 @@ if supports_custom_op():
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _per_token_group_quant_fp8(
|
||||
def _per_token_group_quant_8bit(
|
||||
# Pointers to inputs and output
|
||||
y_ptr,
|
||||
y_q_ptr,
|
||||
@@ -125,8 +125,8 @@ def _per_token_group_quant_fp8(
|
||||
# Avoid to divide zero
|
||||
eps,
|
||||
# Information for float8
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
bit8_min,
|
||||
bit8_max,
|
||||
# Meta-parameters
|
||||
BLOCK: tl.constexpr,
|
||||
):
|
||||
@@ -147,16 +147,16 @@ def _per_token_group_quant_fp8(
|
||||
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
# Quant
|
||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
||||
y_s = _absmax / fp8_max
|
||||
y_s = _absmax / bit8_max
|
||||
y_s_inv = 1.0 / y_s
|
||||
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
y_q = tl.clamp(y * y_s_inv, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)
|
||||
|
||||
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
||||
tl.store(y_s_ptr, y_s)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _per_token_group_quant_fp8_colmajor(
|
||||
def _per_token_group_quant_8bit_colmajor(
|
||||
# Pointers to inputs and output
|
||||
y_ptr,
|
||||
y_q_ptr,
|
||||
@@ -169,8 +169,8 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
# Avoid to divide zero
|
||||
eps,
|
||||
# Information for float8
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
bit8_min,
|
||||
bit8_max,
|
||||
# Meta-parameters
|
||||
BLOCK: tl.constexpr,
|
||||
SCALE_UE8M0: tl.constexpr,
|
||||
@@ -197,19 +197,20 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
# Quant
|
||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
||||
y_s = _absmax / fp8_max
|
||||
y_s = _absmax / bit8_max
|
||||
if SCALE_UE8M0:
|
||||
y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
|
||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
y_q = tl.clamp(y / y_s, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)
|
||||
|
||||
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
||||
tl.store(y_s_ptr, y_s)
|
||||
|
||||
|
||||
def per_token_group_quant_fp8(
|
||||
def _per_token_group_quant_8bit_raw(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype = fp8_dtype,
|
||||
column_major_scales: bool = False,
|
||||
scale_tma_aligned: bool = False,
|
||||
scale_ue8m0: bool = False,
|
||||
@@ -223,6 +224,7 @@ def per_token_group_quant_fp8(
|
||||
x: The input tenosr with ndim >= 2.
|
||||
group_size: The group size used for quantization.
|
||||
eps: The minimum to avoid dividing zero.
|
||||
dtype: The dype of output tensor.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
|
||||
@@ -232,7 +234,21 @@ def per_token_group_quant_fp8(
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
||||
if _is_hip:
|
||||
if dtype == torch.int8:
|
||||
bit8_max = 127.0
|
||||
else:
|
||||
bit8_max = 224.0
|
||||
bit8_min = -bit8_max # TODO incorrect for int8
|
||||
else:
|
||||
if dtype == torch.int8:
|
||||
info = torch.iinfo(dtype)
|
||||
else:
|
||||
info = torch.finfo(dtype)
|
||||
bit8_max = info.max
|
||||
bit8_min = info.min
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
x_s = create_per_token_group_quant_fp8_output_scale(
|
||||
x_shape=x.shape,
|
||||
device=x.device,
|
||||
@@ -250,7 +266,7 @@ def per_token_group_quant_fp8(
|
||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||
num_stages = 1
|
||||
if column_major_scales:
|
||||
_per_token_group_quant_fp8_colmajor[(M,)](
|
||||
_per_token_group_quant_8bit_colmajor[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
@@ -258,8 +274,8 @@ def per_token_group_quant_fp8(
|
||||
x.shape[1],
|
||||
x_s.stride(1),
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
bit8_min=bit8_min,
|
||||
bit8_max=bit8_max,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
@@ -267,15 +283,15 @@ def per_token_group_quant_fp8(
|
||||
)
|
||||
else:
|
||||
assert not scale_ue8m0
|
||||
_per_token_group_quant_fp8[(M,)](
|
||||
_per_token_group_quant_8bit[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
N,
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
bit8_min=bit8_min,
|
||||
bit8_max=bit8_max,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
@@ -297,6 +313,117 @@ def per_token_group_quant_fp8(
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
# backward compatibility
|
||||
per_token_group_quant_fp8 = _per_token_group_quant_8bit_raw
|
||||
|
||||
|
||||
def _per_token_group_quant_8bit_fuse_silu_and_mul(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
dst_dtype: torch.dtype,
|
||||
column_major_scales: bool,
|
||||
scale_tma_aligned: bool,
|
||||
scale_ue8m0: bool,
|
||||
masked_m: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Another way to implement (can be used in e.g. comparison tests)
|
||||
# from sgl_kernel import silu_and_mul
|
||||
# x_after_silu_and_mul = silu_and_mul(x)
|
||||
# return per_token_group_quant_fp8(
|
||||
# x_after_silu_and_mul,
|
||||
# group_size=group_size,
|
||||
# eps=eps,
|
||||
# column_major_scales=column_major_scales,
|
||||
# scale_tma_aligned=scale_tma_aligned,
|
||||
# scale_ue8m0=scale_ue8m0,
|
||||
# )
|
||||
|
||||
from deep_gemm.utils.layout import transform_sf_into_required_layout
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
|
||||
|
||||
assert column_major_scales
|
||||
assert scale_tma_aligned
|
||||
assert scale_ue8m0
|
||||
|
||||
needs_unsqueeze = x.dim() == 2
|
||||
if needs_unsqueeze:
|
||||
num_tokens, _ = x.shape
|
||||
x = x.unsqueeze(0)
|
||||
assert masked_m is None
|
||||
masked_m = torch.tensor([num_tokens], device=x.device, dtype=torch.int32)
|
||||
|
||||
# Use `zeros` for easier testing
|
||||
output = torch.zeros(
|
||||
(*x.shape[:-1], x.shape[-1] // 2),
|
||||
device=x.device,
|
||||
dtype=dst_dtype,
|
||||
)
|
||||
# Use `zeros` for easier testing
|
||||
output_scale_for_kernel = torch.zeros(
|
||||
(*x.shape[:-1], x.shape[-1] // 2 // group_size),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
silu_and_mul_masked_post_quant_fwd(
|
||||
input=x,
|
||||
output=output,
|
||||
output_scale=output_scale_for_kernel,
|
||||
quant_group_size=group_size,
|
||||
masked_m=masked_m,
|
||||
scale_ue8m0=scale_ue8m0,
|
||||
)
|
||||
|
||||
assert group_size == 128
|
||||
output_scale = transform_sf_into_required_layout(
|
||||
output_scale_for_kernel,
|
||||
num_groups=output.shape[0],
|
||||
mn=output.shape[-2],
|
||||
k=output.shape[-1],
|
||||
recipe=(1, group_size, group_size),
|
||||
is_sfa=True,
|
||||
)
|
||||
|
||||
if needs_unsqueeze:
|
||||
output = output.squeeze(0)
|
||||
output_scale = output_scale.squeeze(0)
|
||||
|
||||
return output, output_scale
|
||||
|
||||
|
||||
def per_token_group_quant_8bit(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
dst_dtype: torch.dtype,
|
||||
eps: float = 1e-10,
|
||||
column_major_scales: bool = False,
|
||||
scale_tma_aligned: bool = False,
|
||||
scale_ue8m0: bool = False,
|
||||
fuse_silu_and_mul: bool = False,
|
||||
masked_m: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if fuse_silu_and_mul:
|
||||
return _per_token_group_quant_8bit_fuse_silu_and_mul(
|
||||
x=x,
|
||||
group_size=group_size,
|
||||
dst_dtype=dst_dtype,
|
||||
column_major_scales=column_major_scales,
|
||||
scale_tma_aligned=scale_tma_aligned,
|
||||
scale_ue8m0=scale_ue8m0,
|
||||
masked_m=masked_m,
|
||||
)
|
||||
else:
|
||||
return _per_token_group_quant_8bit_raw(
|
||||
x=x,
|
||||
group_size=group_size,
|
||||
eps=eps,
|
||||
column_major_scales=column_major_scales,
|
||||
scale_tma_aligned=scale_tma_aligned,
|
||||
scale_ue8m0=scale_ue8m0,
|
||||
dtype=dst_dtype,
|
||||
)
|
||||
|
||||
|
||||
def create_per_token_group_quant_fp8_output_scale(
|
||||
x_shape,
|
||||
device,
|
||||
@@ -307,16 +434,16 @@ def create_per_token_group_quant_fp8_output_scale(
|
||||
):
|
||||
if scale_ue8m0:
|
||||
assert column_major_scales and scale_tma_aligned
|
||||
x_q_mn, x_q_k = x_shape
|
||||
*x_batch, x_q_mn, x_q_k = x_shape
|
||||
x_s_mn, x_s_k = x_q_mn, x_q_k // 128
|
||||
aligned_mn = align(x_s_mn, 4)
|
||||
aligned_k = align(x_s_k, 4)
|
||||
# TODO(FIXME): Fix cuda kernel and recover here to empty.
|
||||
return torch.zeros(
|
||||
(aligned_k // 4, aligned_mn),
|
||||
return torch.empty(
|
||||
(*x_batch, aligned_k // 4, aligned_mn),
|
||||
device=device,
|
||||
dtype=torch.int,
|
||||
).transpose(0, 1)[:x_s_mn, :]
|
||||
).transpose(-1, -2)[..., :x_s_mn, :]
|
||||
elif column_major_scales:
|
||||
if scale_tma_aligned:
|
||||
# TODO extract "align" function
|
||||
@@ -341,39 +468,6 @@ def create_per_token_group_quant_fp8_output_scale(
|
||||
)
|
||||
|
||||
|
||||
# TODO maybe unify int8 and fp8 code later
|
||||
def per_token_group_quant_8bit(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
dst_dtype: torch.dtype,
|
||||
eps: float = 1e-10,
|
||||
column_major_scales: bool = False,
|
||||
scale_tma_aligned: bool = False,
|
||||
scale_ue8m0: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
|
||||
|
||||
if dst_dtype == torch.int8:
|
||||
assert not column_major_scales
|
||||
assert not scale_tma_aligned
|
||||
assert not scale_ue8m0
|
||||
return per_token_group_quant_int8(
|
||||
x=x,
|
||||
group_size=group_size,
|
||||
eps=eps,
|
||||
dtype=dst_dtype,
|
||||
)
|
||||
|
||||
return per_token_group_quant_fp8(
|
||||
x=x,
|
||||
group_size=group_size,
|
||||
eps=eps,
|
||||
column_major_scales=column_major_scales,
|
||||
scale_tma_aligned=scale_tma_aligned,
|
||||
scale_ue8m0=scale_ue8m0,
|
||||
)
|
||||
|
||||
|
||||
def sglang_per_token_group_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
@@ -381,15 +475,19 @@ def sglang_per_token_group_quant_fp8(
|
||||
column_major_scales: bool = False,
|
||||
scale_tma_aligned: bool = False,
|
||||
scale_ue8m0: bool = False,
|
||||
fuse_silu_and_mul: bool = False,
|
||||
masked_m: Optional[torch.Tensor] = None,
|
||||
):
|
||||
assert (
|
||||
x.shape[-1] % group_size == 0
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
||||
out_shape = (*x.shape[:-1], x.shape[-1] // (2 if fuse_silu_and_mul else 1))
|
||||
|
||||
x_q = torch.empty(out_shape, device=x.device, dtype=fp8_dtype)
|
||||
x_s = create_per_token_group_quant_fp8_output_scale(
|
||||
x_shape=x.shape,
|
||||
x_shape=out_shape,
|
||||
device=x.device,
|
||||
group_size=group_size,
|
||||
column_major_scales=column_major_scales,
|
||||
@@ -414,6 +512,8 @@ def sglang_per_token_group_quant_8bit(
|
||||
column_major_scales: bool = False,
|
||||
scale_tma_aligned: bool = False,
|
||||
scale_ue8m0: bool = False,
|
||||
fuse_silu_and_mul: bool = False,
|
||||
masked_m: Optional[torch.Tensor] = None,
|
||||
):
|
||||
from sglang.srt.layers.quantization.int8_kernel import (
|
||||
sglang_per_token_group_quant_int8,
|
||||
@@ -422,6 +522,8 @@ def sglang_per_token_group_quant_8bit(
|
||||
if dst_dtype == torch.int8:
|
||||
assert not column_major_scales
|
||||
assert not scale_tma_aligned
|
||||
assert not fuse_silu_and_mul
|
||||
assert masked_m is None
|
||||
return sglang_per_token_group_quant_int8(
|
||||
x=x,
|
||||
group_size=group_size,
|
||||
@@ -436,6 +538,8 @@ def sglang_per_token_group_quant_8bit(
|
||||
column_major_scales=column_major_scales,
|
||||
scale_tma_aligned=scale_tma_aligned,
|
||||
scale_ue8m0=scale_ue8m0,
|
||||
fuse_silu_and_mul=fuse_silu_and_mul,
|
||||
masked_m=masked_m,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user