Apply sgl w8a8 fp8 kernel (#3148)

@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
+    static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
 
@@ -63,7 +64,7 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
         out, scale = per_token_group_quant_fp8(x, group_size)
 
         self.assertTrue(
-            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20)
         )
         self.assertTrue(torch.allclose(scale, ref_scale))
 
@@ -85,6 +86,71 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
                 self._per_token_group_quant_fp8(*params)
 
 
+# For test
+def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
+    """Function to perform static quantization on an input tensor `x` using native torch.
+
+    It converts the tensor values into float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+    """
+    assert x.is_contiguous(), "`x` is not contiguous"
+    assert x_s.numel() == 1, "only supports per-tensor scale"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1])
+    x_s_inv = 1.0 / x_s
+    x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype)
+    x_q = x_q.reshape(x.shape)
+
+    return x_q, x_s
+
+
+class TestStaticQuantFP8(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _static_quant_fp8(self, num_tokens, d, dtype, seed):
+        torch.manual_seed(seed)
+
+        x = torch.rand(num_tokens, d, dtype=dtype)
+        fp8_max = torch.finfo(torch.float8_e4m3fn).max
+        x_s = x.max() / fp8_max
+
+        with torch.inference_mode():
+            ref_out, _ = native_static_quant_fp8(x, x_s)
+            out, _ = static_quant_fp8(x, x_s, repeat_scale=True)
+
+        self.assertTrue(
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
+        )
+
+    def test_static_quant_fp8(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._static_quant_fp8(*params)
+
+
 # For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.
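
A minimal sketch (not part of the commit) of the round-trip behavior that TestStaticQuantFP8 relies on: with a single per-tensor scale, the FP8 output of static_quant_fp8 can be dequantized back by multiplying with that scale. The snippet assumes a CUDA device and uses the static_quant_fp8(x, x_s, repeat_scale=True) call shape shown in the tests above.

# Sketch only: per-tensor static FP8 quantization round trip (assumes CUDA).
import torch
from sglang.srt.layers.quantization.fp8_kernel import static_quant_fp8

x = torch.rand(16, 512, dtype=torch.half, device="cuda")
# One precomputed ("static") scale for the whole tensor, as in the tests above.
x_s = x.max() / torch.finfo(torch.float8_e4m3fn).max

x_q, _ = static_quant_fp8(x, x_s, repeat_scale=True)

# Dequantize: x is approximately x_q * x_s, up to FP8 rounding error.
x_dq = x_q.to(torch.float32) * x_s
print((x_dq - x.to(torch.float32)).abs().max().item())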