[perf] experimental enhance fp8 per-tensor quant (#5370)
This commit is contained in:
@@ -839,3 +839,103 @@ def w8a8_block_fp8_matmul(
|
||||
)
|
||||
|
||||
return C
|
||||
|
||||
|
||||
@triton.jit
def _per_tensor_quant_mla_fp8_stage1(
    x_ptr,
    x_s_ptr,
    head_size,
    x_stride_h,
    x_stride_s,
    eps,
    fp8_max,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of per-tensor MLA fp8 quantization: derive the global scale.

    Each program reduces one (seq, head) row of `x` to its absolute maximum,
    floors it with `eps`, and folds `absmax / fp8_max` into the one-element
    scale buffer `x_s_ptr` via an atomic max.
    """
    sid = tl.program_id(0)
    hid = tl.program_id(1)
    lane = tl.arange(0, BLOCK_SIZE)
    in_bounds = lane < head_size

    row_base = x_ptr + hid * x_stride_h + sid * x_stride_s
    row = tl.load(row_base + lane, mask=in_bounds, other=0.0).to(tl.float32)
    # eps floors the abs-max so the derived scale can never be zero.
    row_absmax = tl.maximum(tl.max(tl.abs(row)), eps)

    tl.atomic_max(x_s_ptr, row_absmax / fp8_max)
||||
@triton.jit
def _per_tensor_quant_mla_fp8_stage2(
    x_ptr,
    x_s_ptr,
    x_q_ptr,
    num_seq,
    head_size,
    x_stride_h,
    x_stride_s,
    fp8_min,
    fp8_max,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 2 of per-tensor MLA fp8 quantization: apply the scale.

    Reads the single scale produced by stage 1, divides each (seq, head) row
    of `x` by it, clamps to [fp8_min, fp8_max], and stores the result into the
    densely packed output `x_q_ptr` (head-major, `head * num_seq * head_size`
    plus `seq * head_size`).
    """
    sid = tl.program_id(0)
    hid = tl.program_id(1)
    lane = tl.arange(0, BLOCK_SIZE)
    in_bounds = lane < head_size

    scale = tl.load(x_s_ptr)
    # Multiply by the reciprocal instead of dividing per element.
    inv_scale = 1.0 / scale

    src = x_ptr + hid * x_stride_h + sid * x_stride_s
    dst = x_q_ptr + hid * num_seq * head_size + sid * head_size

    row = tl.load(src + lane, mask=in_bounds, other=0.0).to(tl.float32)
    q = tl.clamp(row * inv_scale, fp8_min, fp8_max).to(x_q_ptr.dtype.element_ty)
    tl.store(dst + lane, q, mask=in_bounds)
||||
def per_tensor_quant_mla_fp8(
    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize `x` to fp8 with a single tensor-wide scale, specialized for the
    MLA absorbed case.

    Args:
        x: 3-D input tensor of shape (num_head, num_seq, head_size).
        eps: lower bound applied to the abs-max so the scale is never zero.
        dtype: target fp8 dtype (overridden to float8_e4m3fnuz on HIP).

    Returns:
        (x_q, x_s): the fp8-quantized tensor and a one-element float32 scale
        tensor on the same device as `x`.
    """
    assert x.dim() == 3, "`x` is not a 3d-tensor"

    fp8_max = torch.finfo(dtype).max
    if _is_hip:
        # ROCm uses the fnuz fp8 variant; its practical max magnitude is 224.
        dtype = torch.float8_e4m3fnuz
        fp8_max = 224.0

    num_head, num_seq, head_size = x.shape
    quantized = x.new_empty(x.size(), dtype=dtype)
    # Zero-initialized so stage 1 can accumulate the scale with atomic max.
    scale = torch.zeros((1,), dtype=torch.float32, device=x.device)

    BLOCK_SIZE = triton.next_power_of_2(head_size)
    launch_grid = (num_seq, num_head)

    _per_tensor_quant_mla_fp8_stage1[launch_grid](
        x,
        scale,
        head_size,
        x.stride(0),
        x.stride(1),
        eps,
        fp8_max,
        BLOCK_SIZE,
    )
    _per_tensor_quant_mla_fp8_stage2[launch_grid](
        x,
        scale,
        quantized,
        num_seq,
        head_size,
        x.stride(0),
        x.stride(1),
        -fp8_max,
        fp8_max,
        BLOCK_SIZE,
    )

    return quantized, scale
||||
@@ -168,13 +168,13 @@ def input_to_float8(
|
||||
"""This function quantizes input values to float8 values with tensor-wise quantization."""
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
|
||||
fp8_max = finfo.max
|
||||
if _is_hip:
|
||||
dtype = torch.float8_e4m3fnuz
|
||||
fp8_max = 224.0
|
||||
scale = fp8_max / amax
|
||||
x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
|
||||
x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
|
||||
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
||||
|
||||
|
||||
@@ -213,7 +213,11 @@ def block_quant_to_tensor_quant(
|
||||
for j in range(n_tiles):
|
||||
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
|
||||
|
||||
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
||||
x_q_tensor, scale = (
|
||||
sgl_scaled_fp8_quant(x_dq_block)
|
||||
if _is_cuda
|
||||
else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
||||
)
|
||||
return x_q_tensor, scale
|
||||
|
||||
|
||||
@@ -222,7 +226,11 @@ def channel_quant_to_tensor_quant(
|
||||
x_s: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x_dq_channel = x_q_channel.to(torch.float32) * x_s
|
||||
x_q_tensor, scale = input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
|
||||
x_q_tensor, scale = (
|
||||
sgl_scaled_fp8_quant(x_dq_channel)
|
||||
if _is_cuda
|
||||
else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
|
||||
)
|
||||
return x_q_tensor, scale
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user