[perf] experimental: enhance fp8 per-tensor quant (#5370)
@@ -839,3 +839,103 @@ def w8a8_block_fp8_matmul(
     )

     return C
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage1(
+    x_ptr,
+    x_s_ptr,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    eps,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
+
+    tl.atomic_max(x_s_ptr, _absmax / fp8_max)
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage2(
+    x_ptr,
+    x_s_ptr,
+    x_q_ptr,
+    num_seq,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    fp8_min,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_s = tl.load(x_s_ptr)
+    x_s_inv = 1.0 / x_s
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x_q_ptr += head_id * num_seq * head_size + seq_id * head_size
+
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    x_q = tl.clamp(x * x_s_inv, fp8_min, fp8_max).to(x_q_ptr.dtype.element_ty)
+    tl.store(x_q_ptr + offset, x_q, mask=mask)
+
+
+def per_tensor_quant_mla_fp8(
+    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with tensor-wise quantization
+    and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    x_q = x.new_empty(x.size(), dtype=dtype)
+    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
+
+    num_head, num_seq, head_size = x.shape
+    BLOCK_SIZE = triton.next_power_of_2(head_size)
+    grid = (num_seq, num_head)
+
+    _per_tensor_quant_mla_fp8_stage1[grid](
+        x,
+        x_s,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        eps,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+    _per_tensor_quant_mla_fp8_stage2[grid](
+        x,
+        x_s,
+        x_q,
+        num_seq,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        -fp8_max,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s
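
For orientation, the two launches above amount to the following eager PyTorch sketch (illustrative only, not code from this patch): stage 1 folds the global absolute maximum into a single scale via tl.atomic_max, and stage 2 re-reads that scale, saturates, and casts.

import torch

def per_tensor_quant_mla_fp8_reference(x, eps=1e-12, fp8_dtype=torch.float8_e4m3fn):
    # One scale for the whole (num_head, num_seq, head_size) tensor,
    # mirroring stage 1's atomic max over all (seq, head) rows.
    fp8_max = torch.finfo(fp8_dtype).max
    x_s = (x.float().abs().amax().clamp(min=eps) / fp8_max).reshape(1)
    # Mirror stage 2: scale, saturate to the representable range, cast.
    # The Triton wrapper also materializes a contiguous output even when
    # x is a transposed view, hence the explicit .contiguous().
    x_q = (x.float() / x_s).clamp(-fp8_max, fp8_max).to(fp8_dtype).contiguous()
    return x_q, x_s
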
@@ -168,13 +168,13 @@ def input_to_float8(
     """This function quantizes input values to float8 values with tensor-wise quantization."""
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
     fp8_max = finfo.max
     if _is_hip:
         dtype = torch.float8_e4m3fnuz
         fp8_max = 224.0
     scale = fp8_max / amax
-    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
+    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


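
A minimal round trip through input_to_float8 (illustrative, not from this patch), relying only on behaviour visible in this hunk: the second return value is the dequantization scale amax / fp8_max, so multiplying it back approximately reconstructs the input. Assumes a CUDA device with fp8 support.

import torch
from sglang.srt.layers.quantization.fp8_utils import input_to_float8

x = torch.randn(16, 512, dtype=torch.bfloat16, device="cuda")
x_fp8, scale = input_to_float8(x)           # fp8 tensor plus a float32 dequant scale
x_dq = x_fp8.to(torch.float32) * scale      # approximate reconstruction of x
print((x_dq - x.float()).abs().max())       # residual error from fp8 rounding
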
@@ -213,7 +213,11 @@ def block_quant_to_tensor_quant(
         for j in range(n_tiles):
             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

-    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    x_q_tensor, scale = (
+        sgl_scaled_fp8_quant(x_dq_block)
+        if _is_cuda
+        else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    )
     return x_q_tensor, scale


@@ -222,7 +226,11 @@ def channel_quant_to_tensor_quant(
     x_s: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     x_dq_channel = x_q_channel.to(torch.float32) * x_s
-    x_q_tensor, scale = input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    x_q_tensor, scale = (
+        sgl_scaled_fp8_quant(x_dq_channel)
+        if _is_cuda
+        else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    )
     return x_q_tensor, scale


@@ -53,10 +53,10 @@ from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
-    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.int8_utils import (
@@ -817,8 +817,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = input_to_float8(
-                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -848,8 +848,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = input_to_float8(
-                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -895,8 +895,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = input_to_float8(
-                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -991,8 +991,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = input_to_float8(
-                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
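
All four model hunks above share one call pattern; the sketch below (a hypothetical helper, not part of the patch) spells it out, assuming bmm_fp8 is in scope as it is in deepseek_v2.py.

import torch
from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8

def absorbed_bmm_fp8(activation, weight_fp8, weight_scale):
    # activation: (num_seq, num_head, head_dim); transpose so bmm batches over heads.
    act_fp8, act_scale = per_tensor_quant_mla_fp8(
        activation.transpose(0, 1), dtype=torch.float8_e4m3fn
    )
    # Both operands are fp8; their per-tensor scales let bmm_fp8 emit bfloat16 output.
    return bmm_fp8(act_fp8, weight_fp8, act_scale, weight_scale, torch.bfloat16)
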
@@ -7,10 +7,12 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
+    per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.layers.quantization.fp8_utils import input_to_float8
 from sglang.test.test_utils import CustomTestCase

 _is_cuda = torch.cuda.is_available() and torch.version.cuda
@@ -155,6 +157,61 @@ class TestStaticQuantFP8(CustomTestCase):
                 self._static_quant_fp8(*params)


+class TestPerTensorQuantMlaFP8(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    LAST_D_EXT = [1024, 0]
+    LAST_D = [512]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _per_tensor_quant_mla_fp8(self, num_tokens, d, last_d_ext, last_d, dtype, seed):
+        torch.manual_seed(seed)
+
+        x = torch.rand(
+            (num_tokens, d // last_d, last_d + last_d_ext),
+            dtype=dtype,
+        )
+        x_sub, _ = x.split([last_d, last_d_ext], dim=-1)
+
+        with torch.inference_mode():
+            ref_out, ref_s = input_to_float8(x_sub.transpose(0, 1))
+            out, out_s = per_tensor_quant_mla_fp8(x_sub.transpose(0, 1))
+
+        self.assertTrue(out.is_contiguous())
+        self.assertTrue(
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
+        )
+        self.assertTrue(
+            torch.allclose(out_s.to(torch.float32), ref_s.to(torch.float32))
+        )
+
+    def test_per_tensor_quant_mla_fp8(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.LAST_D_EXT,
+            self.LAST_D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                last_d_ext=params[2],
+                last_d=params[3],
+                dtype=params[4],
+                seed=params[5],
+            ):
+                self._per_tensor_quant_mla_fp8(*params)
+
+
 # For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.