[perf] experimental enhance fp8 per-tensor quant (#5370)

This commit is contained in:
JieXin Liang
2025-04-15 03:35:43 +08:00
committed by GitHub
parent e9fc2ac7b6
commit bdde237562
4 changed files with 178 additions and 13 deletions

View File

@@ -839,3 +839,103 @@ def w8a8_block_fp8_matmul(
)
return C
@triton.jit
def _per_tensor_quant_mla_fp8_stage1(
    x_ptr,
    x_s_ptr,
    head_size,
    x_stride_h,
    x_stride_s,
    eps,
    fp8_max,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of tensor-wise FP8 quantization for the MLA absorbed case.

    Each program reduces one (seq, head) row of `x` to its absolute maximum
    (floored at `eps`) and atomically folds the implied scale
    (absmax / fp8_max) into the single-element buffer at `x_s_ptr`.
    """
    # Launch grid is (num_seq, num_head): one program per row.
    sid = tl.program_id(0)
    hid = tl.program_id(1)

    cols = tl.arange(0, BLOCK_SIZE)
    in_row = cols < head_size

    row_ptr = x_ptr + hid * x_stride_h + sid * x_stride_s
    row = tl.load(row_ptr + cols, mask=in_row, other=0.0).to(tl.float32)

    # Floor at eps so an all-zero input still produces a nonzero scale.
    row_absmax = tl.maximum(tl.max(tl.abs(row)), eps)
    tl.atomic_max(x_s_ptr, row_absmax / fp8_max)
@triton.jit
def _per_tensor_quant_mla_fp8_stage2(
    x_ptr,
    x_s_ptr,
    x_q_ptr,
    num_seq,
    head_size,
    x_stride_h,
    x_stride_s,
    fp8_min,
    fp8_max,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 2 of tensor-wise FP8 quantization for the MLA absorbed case.

    Each program rescales one (seq, head) row by the global scale computed in
    stage 1, clamps it to [fp8_min, fp8_max], and stores the FP8 result into
    the contiguous output at `x_q_ptr`.
    """
    sid = tl.program_id(0)
    hid = tl.program_id(1)

    cols = tl.arange(0, BLOCK_SIZE)
    in_row = cols < head_size

    # Single tensor-wide scale produced by stage 1; divide once, multiply per element.
    inv_scale = 1.0 / tl.load(x_s_ptr)

    src = x_ptr + hid * x_stride_h + sid * x_stride_s
    # Output is addressed as a contiguous (num_head, num_seq, head_size) tensor.
    dst = x_q_ptr + (hid * num_seq + sid) * head_size

    row = tl.load(src + cols, mask=in_row, other=0.0).to(tl.float32)
    quantized = tl.clamp(row * inv_scale, fp8_min, fp8_max).to(
        x_q_ptr.dtype.element_ty
    )
    tl.store(dst + cols, quantized, mask=in_row)
def per_tensor_quant_mla_fp8(
    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    This function quantizes input values to float8 values with tensor-wise quantization
    and specialized for mla absorbed case.

    Args:
        x: 3-D input tensor, laid out as (num_head, num_seq, head_size).
        eps: lower bound applied to the absolute maximum so the scale is never zero.
        dtype: target FP8 dtype; on HIP it is forced to float8_e4m3fnuz.

    Returns:
        A tuple (x_q, x_s) of the quantized tensor and the one-element
        float32 scale tensor.
    """
    assert x.dim() == 3, "`x` is not a 3d-tensor"

    fp8_max = torch.finfo(dtype).max
    if _is_hip:
        # ROCm uses the fnuz variant, whose largest finite value is 224.
        dtype = torch.float8_e4m3fnuz
        fp8_max = 224.0

    num_head, num_seq, head_size = x.shape
    x_q = x.new_empty(x.size(), dtype=dtype)
    # Scale starts at zero; stage 1 raises it via atomic_max.
    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)

    BLOCK_SIZE = triton.next_power_of_2(head_size)
    grid = (num_seq, num_head)

    # Pass 1: reduce |x| to a single tensor-wide scale.
    _per_tensor_quant_mla_fp8_stage1[grid](
        x,
        x_s,
        head_size,
        x.stride(0),
        x.stride(1),
        eps,
        fp8_max,
        BLOCK_SIZE,
    )
    # Pass 2: rescale, clamp to the symmetric FP8 range, and cast.
    _per_tensor_quant_mla_fp8_stage2[grid](
        x,
        x_s,
        x_q,
        num_seq,
        head_size,
        x.stride(0),
        x.stride(1),
        -fp8_max,
        fp8_max,
        BLOCK_SIZE,
    )
    return x_q, x_s

View File

@@ -168,13 +168,13 @@ def input_to_float8(
"""This function quantizes input values to float8 values with tensor-wise quantization."""
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
scale = fp8_max / amax
x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
@@ -213,7 +213,11 @@ def block_quant_to_tensor_quant(
for j in range(n_tiles):
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
x_q_tensor, scale = (
sgl_scaled_fp8_quant(x_dq_block)
if _is_cuda
else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
)
return x_q_tensor, scale
@@ -222,7 +226,11 @@ def channel_quant_to_tensor_quant(
x_s: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
x_dq_channel = x_q_channel.to(torch.float32) * x_s
x_q_tensor, scale = input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
x_q_tensor, scale = (
sgl_scaled_fp8_quant(x_dq_channel)
if _is_cuda
else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
)
return x_q_tensor, scale