Fix quant kernel test errors and benchmark wrong output speeds (#7604)
This commit is contained in:
@@ -341,6 +341,39 @@ def create_per_token_group_quant_fp8_output_scale(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO maybe unify int8 and fp8 code later
|
||||||
|
def per_token_group_quant_8bit(
|
||||||
|
x: torch.Tensor,
|
||||||
|
group_size: int,
|
||||||
|
dst_dtype: torch.dtype,
|
||||||
|
eps: float = 1e-10,
|
||||||
|
column_major_scales: bool = False,
|
||||||
|
scale_tma_aligned: bool = False,
|
||||||
|
scale_ue8m0: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
|
||||||
|
|
||||||
|
if dst_dtype == torch.int8:
|
||||||
|
assert not column_major_scales
|
||||||
|
assert not scale_tma_aligned
|
||||||
|
assert not scale_ue8m0
|
||||||
|
return per_token_group_quant_int8(
|
||||||
|
x=x,
|
||||||
|
group_size=group_size,
|
||||||
|
eps=eps,
|
||||||
|
dtype=dst_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
return per_token_group_quant_fp8(
|
||||||
|
x=x,
|
||||||
|
group_size=group_size,
|
||||||
|
eps=eps,
|
||||||
|
column_major_scales=column_major_scales,
|
||||||
|
scale_tma_aligned=scale_tma_aligned,
|
||||||
|
scale_ue8m0=scale_ue8m0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sglang_per_token_group_quant_fp8(
|
def sglang_per_token_group_quant_fp8(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
@@ -372,6 +405,40 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
return x_q, x_s
|
return x_q, x_s
|
||||||
|
|
||||||
|
|
||||||
|
# TODO maybe unify int8 and fp8 code later
|
||||||
|
def sglang_per_token_group_quant_8bit(
|
||||||
|
x: torch.Tensor,
|
||||||
|
group_size: int,
|
||||||
|
dst_dtype: torch.dtype,
|
||||||
|
eps: float = 1e-10,
|
||||||
|
column_major_scales: bool = False,
|
||||||
|
scale_tma_aligned: bool = False,
|
||||||
|
scale_ue8m0: bool = False,
|
||||||
|
):
|
||||||
|
from sglang.srt.layers.quantization.int8_kernel import (
|
||||||
|
sglang_per_token_group_quant_int8,
|
||||||
|
)
|
||||||
|
|
||||||
|
if dst_dtype == torch.int8:
|
||||||
|
assert not column_major_scales
|
||||||
|
assert not scale_tma_aligned
|
||||||
|
return sglang_per_token_group_quant_int8(
|
||||||
|
x=x,
|
||||||
|
group_size=group_size,
|
||||||
|
eps=eps,
|
||||||
|
dtype=dst_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
return sglang_per_token_group_quant_fp8(
|
||||||
|
x=x,
|
||||||
|
group_size=group_size,
|
||||||
|
eps=eps,
|
||||||
|
column_major_scales=column_major_scales,
|
||||||
|
scale_tma_aligned=scale_tma_aligned,
|
||||||
|
scale_ue8m0=scale_ue8m0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sglang_per_token_quant_fp8(
|
def sglang_per_token_quant_fp8(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
dtype: torch.dtype = fp8_dtype,
|
dtype: torch.dtype = fp8_dtype,
|
||||||
|
|||||||
@@ -176,6 +176,27 @@ def replace_parameter(
|
|||||||
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
|
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
|
||||||
|
|
||||||
|
|
||||||
|
def assert_fp8_all_close(a: torch.Tensor, b: torch.Tensor):
|
||||||
|
assert a.shape == b.shape
|
||||||
|
assert a.dtype == b.dtype == torch.float8_e4m3fn
|
||||||
|
|
||||||
|
a_u8 = a.view(torch.uint8)
|
||||||
|
b_u8 = b.view(torch.uint8)
|
||||||
|
diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
|
||||||
|
|
||||||
|
numel = a.numel()
|
||||||
|
|
||||||
|
count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
|
||||||
|
count_tiny_diff = (diff_u8 >= 1).sum().item()
|
||||||
|
count_large_diff = (diff_u8 >= 2).sum().item()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
(count_diff_sign == 0)
|
||||||
|
and (count_tiny_diff / numel < 0.005)
|
||||||
|
and (count_large_diff == 0)
|
||||||
|
), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=}"
|
||||||
|
|
||||||
|
|
||||||
# Match dynamic rules with module name (prefix) and override quantize
|
# Match dynamic rules with module name (prefix) and override quantize
|
||||||
# config if module (prefix) matches a rule
|
# config if module (prefix) matches a rule
|
||||||
def override_config(config: QuantizationConfig, prefix: str):
|
def override_config(config: QuantizationConfig, prefix: str):
|
||||||
|
|||||||
@@ -1,189 +1,68 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import Tuple
|
import time
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
|
||||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
|
|
||||||
|
|
||||||
|
from sglang.srt.bench_utils import bench_kineto
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
create_per_token_group_quant_fp8_output_scale,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
|
||||||
def _per_token_group_quant_8bit(
|
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
|
||||||
# Pointers to inputs and output
|
|
||||||
y_ptr,
|
|
||||||
y_q_ptr,
|
|
||||||
y_s_ptr,
|
|
||||||
# Stride of input
|
|
||||||
y_stride,
|
|
||||||
# Columns of input
|
|
||||||
N,
|
|
||||||
# Avoid to divide zero
|
|
||||||
eps,
|
|
||||||
# Information for 8bit data type (int8 or fp8_type_)
|
|
||||||
max_8bit,
|
|
||||||
min_8bit,
|
|
||||||
# Meta-parameters
|
|
||||||
BLOCK: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""A Triton-accelerated function to perform per-token-group quantization on a
|
|
||||||
tensor.
|
|
||||||
This function converts the tensor values into 8bit values.
|
|
||||||
"""
|
|
||||||
# Map the program id to the row of X and Y it should compute.
|
|
||||||
g_id = tl.program_id(0)
|
|
||||||
y_ptr += g_id * y_stride
|
|
||||||
y_q_ptr += g_id * y_stride
|
|
||||||
y_s_ptr += g_id
|
|
||||||
|
|
||||||
cols = tl.arange(0, BLOCK) # N <= BLOCK
|
|
||||||
mask = cols < N
|
|
||||||
|
|
||||||
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
|
||||||
# Quant
|
|
||||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
|
||||||
y_s = _absmax / max_8bit
|
|
||||||
y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
|
|
||||||
|
|
||||||
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
|
||||||
tl.store(y_s_ptr, y_s)
|
|
||||||
|
|
||||||
|
|
||||||
def triton_per_token_group_quant_8bit(
|
|
||||||
x: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
dst_dtype: torch.dtype,
|
|
||||||
eps: float = 1e-10,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
|
||||||
It converts the tensor values into signed float8 values and returns the
|
|
||||||
quantized tensor along with the scaling factor used for quantization.
|
|
||||||
Args:
|
|
||||||
x: The input tenosr with ndim >= 2.
|
|
||||||
group_size: The group size used for quantization.
|
|
||||||
eps: The minimum to avoid dividing zero.
|
|
||||||
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
|
|
||||||
Returns:
|
|
||||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
|
|
||||||
"""
|
|
||||||
assert (
|
|
||||||
x.shape[-1] % group_size == 0
|
|
||||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
|
||||||
|
|
||||||
if dst_dtype == torch.int8:
|
|
||||||
iinfo = torch.iinfo(dst_dtype)
|
|
||||||
max_8bit = iinfo.max
|
|
||||||
min_8bit = iinfo.min
|
|
||||||
else:
|
|
||||||
finfo = torch.finfo(dst_dtype)
|
|
||||||
max_8bit = finfo.max
|
|
||||||
min_8bit = finfo.min
|
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
|
|
||||||
M = x.numel() // group_size
|
|
||||||
N = group_size
|
|
||||||
x_s = torch.empty(
|
|
||||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
BLOCK = triton.next_power_of_2(N)
|
|
||||||
# heuristics for number of warps
|
|
||||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
|
||||||
num_stages = 1
|
|
||||||
_per_token_group_quant_8bit[(M,)](
|
|
||||||
x,
|
|
||||||
x_q,
|
|
||||||
x_s,
|
|
||||||
group_size,
|
|
||||||
N,
|
|
||||||
eps,
|
|
||||||
max_8bit,
|
|
||||||
min_8bit,
|
|
||||||
BLOCK=BLOCK,
|
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=num_stages,
|
|
||||||
)
|
|
||||||
|
|
||||||
return x_q, x_s
|
|
||||||
|
|
||||||
|
|
||||||
def sglang_per_token_group_quant_8bit(
|
|
||||||
x: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
dst_dtype: torch.dtype,
|
|
||||||
eps: float = 1e-10,
|
|
||||||
):
|
|
||||||
assert (
|
|
||||||
x.shape[-1] % group_size == 0
|
|
||||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
|
|
||||||
x_s = torch.empty(
|
|
||||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
if dst_dtype == torch.int8:
|
|
||||||
iinfo = torch.iinfo(dst_dtype)
|
|
||||||
int8_max = iinfo.max
|
|
||||||
int8_min = iinfo.min
|
|
||||||
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
|
||||||
else:
|
|
||||||
f8_info = torch.finfo(dst_dtype)
|
|
||||||
fp8_max = f8_info.max
|
|
||||||
fp8_min = f8_info.min
|
|
||||||
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
|
||||||
|
|
||||||
return x_q, x_s
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
|
|
||||||
device = torch.device("cuda")
|
|
||||||
hidden_dim = 7168
|
|
||||||
|
|
||||||
x = torch.randn(
|
|
||||||
batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
|
|
||||||
)
|
|
||||||
|
|
||||||
x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
|
|
||||||
x.clone(), group_size, dst_dtype
|
|
||||||
)
|
|
||||||
x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
|
|
||||||
x.clone(), group_size, dst_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
if torch.allclose(
|
|
||||||
x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
|
|
||||||
) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5):
|
|
||||||
print(f"✅ {dst_dtype} implementations match")
|
|
||||||
else:
|
|
||||||
print("❌ Implementations differ")
|
|
||||||
|
|
||||||
|
|
||||||
batch_size_range = [1, 2, 4, 8, 16, 32, 64]
|
|
||||||
seq_len_range = [64, 128, 256, 512, 1024, 2048]
|
|
||||||
group_size_range = [128] # For DeepSeek V3/R1
|
group_size_range = [128] # For DeepSeek V3/R1
|
||||||
dst_dtype_range = [torch.int8, fp8_type_]
|
# TODO test int8
|
||||||
|
dst_dtype_range = [fp8_type_]
|
||||||
|
flags_range = [
|
||||||
|
dict(
|
||||||
|
column_major_scales=False,
|
||||||
|
scale_tma_aligned=False,
|
||||||
|
scale_ue8m0=False,
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
column_major_scales=True,
|
||||||
|
scale_tma_aligned=False,
|
||||||
|
scale_ue8m0=False,
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
column_major_scales=True,
|
||||||
|
scale_tma_aligned=True,
|
||||||
|
scale_ue8m0=False,
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
column_major_scales=True,
|
||||||
|
scale_tma_aligned=True,
|
||||||
|
scale_ue8m0=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
configs = list(
|
configs = list(
|
||||||
itertools.product(
|
itertools.product(
|
||||||
batch_size_range, seq_len_range, group_size_range, dst_dtype_range
|
num_tokens_range,
|
||||||
|
hidden_dim_range,
|
||||||
|
group_size_range,
|
||||||
|
dst_dtype_range,
|
||||||
|
flags_range,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(
|
@triton.testing.perf_report(
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=["batch_size", "seq_len", "group_size", "dst_dtype"],
|
x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
|
||||||
x_vals=configs,
|
x_vals=configs,
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["triton", "sglang"],
|
line_vals=["triton", "sglang"],
|
||||||
@@ -194,29 +73,26 @@ configs = list(
|
|||||||
args={},
|
args={},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
|
def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
|
||||||
|
if flags["scale_ue8m0"] and group_size != 128:
|
||||||
|
return
|
||||||
|
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
hidden_dim = 7168
|
|
||||||
|
|
||||||
x = torch.randn(
|
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
|
||||||
batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
|
|
||||||
)
|
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
fn, kernel_names = {
|
||||||
|
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
|
||||||
|
"sglang": (
|
||||||
|
sglang_per_token_group_quant_8bit,
|
||||||
|
"per_token_group_quant_8bit_kernel",
|
||||||
|
),
|
||||||
|
}[provider]
|
||||||
|
bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
|
||||||
|
|
||||||
if provider == "triton":
|
time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
|
||||||
fn = lambda: triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
|
return time_s * 1e6
|
||||||
elif provider == "sglang":
|
|
||||||
fn = lambda: sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)
|
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
|
||||||
|
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=torch.int8)
|
|
||||||
calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=fp8_type_)
|
|
||||||
|
|
||||||
benchmark.run(print_data=True)
|
benchmark.run(print_data=True)
|
||||||
|
|||||||
@@ -1,278 +1,51 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
|
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
||||||
|
from sglang.srt.layers.quantization.utils import assert_fp8_all_close
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _per_token_group_quant_fp8(
|
|
||||||
# Pointers to inputs and output
|
|
||||||
y_ptr,
|
|
||||||
y_q_ptr,
|
|
||||||
y_s_ptr,
|
|
||||||
# Stride of input
|
|
||||||
y_stride,
|
|
||||||
# Columns of input
|
|
||||||
N,
|
|
||||||
# Avoid to divide zero
|
|
||||||
eps,
|
|
||||||
# Information for float8
|
|
||||||
fp8_min,
|
|
||||||
fp8_max,
|
|
||||||
# Meta-parameters
|
|
||||||
BLOCK: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""A Triton-accelerated function to perform per-token-group quantization on a
|
|
||||||
tensor.
|
|
||||||
|
|
||||||
This function converts the tensor values into float8 values.
|
|
||||||
"""
|
|
||||||
# Map the program id to the row of X and Y it should compute.
|
|
||||||
g_id = tl.program_id(0)
|
|
||||||
y_ptr += g_id * y_stride
|
|
||||||
y_q_ptr += g_id * y_stride
|
|
||||||
y_s_ptr += g_id
|
|
||||||
|
|
||||||
cols = tl.arange(0, BLOCK) # N <= BLOCK
|
|
||||||
mask = cols < N
|
|
||||||
|
|
||||||
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
|
||||||
# Quant
|
|
||||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
|
||||||
y_s = _absmax / fp8_max
|
|
||||||
y_s_inv = 1.0 / y_s
|
|
||||||
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
|
||||||
|
|
||||||
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
|
||||||
tl.store(y_s_ptr, y_s)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _per_token_group_quant_fp8_colmajor(
|
|
||||||
# Pointers to inputs and output
|
|
||||||
y_ptr,
|
|
||||||
y_q_ptr,
|
|
||||||
y_s_ptr,
|
|
||||||
group_size,
|
|
||||||
# Num columns of y
|
|
||||||
y_num_columns,
|
|
||||||
# Stride from one column to the next of y_s
|
|
||||||
y_s_col_stride,
|
|
||||||
# Avoid to divide zero
|
|
||||||
eps,
|
|
||||||
# Information for float8
|
|
||||||
fp8_min,
|
|
||||||
fp8_max,
|
|
||||||
# Meta-parameters
|
|
||||||
BLOCK: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""A Triton-accelerated function to perform per-token-group
|
|
||||||
quantization on a tensor.
|
|
||||||
This function converts the tensor values into float8 values.
|
|
||||||
"""
|
|
||||||
# Map the program id to the row of X and Y it should compute.
|
|
||||||
g_id = tl.program_id(0)
|
|
||||||
y_ptr += g_id * group_size
|
|
||||||
y_q_ptr += g_id * group_size
|
|
||||||
|
|
||||||
# Convert g_id the flattened block coordinate to 2D so we can index
|
|
||||||
# into the output y_scales matrix
|
|
||||||
blocks_per_row = y_num_columns // group_size
|
|
||||||
scale_col = g_id % blocks_per_row
|
|
||||||
scale_row = g_id // blocks_per_row
|
|
||||||
y_s_ptr += scale_col * y_s_col_stride + scale_row
|
|
||||||
|
|
||||||
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
|
|
||||||
mask = cols < group_size
|
|
||||||
|
|
||||||
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
|
||||||
# Quant
|
|
||||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
|
||||||
y_s = _absmax / fp8_max
|
|
||||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
|
||||||
|
|
||||||
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
|
||||||
tl.store(y_s_ptr, y_s)
|
|
||||||
|
|
||||||
|
|
||||||
def triton_per_token_group_quant_8bit(
|
|
||||||
x: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
eps: float = 1e-10,
|
|
||||||
dtype: torch.dtype = fp8_type_,
|
|
||||||
column_major_scales: bool = False,
|
|
||||||
scale_tma_aligned: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
|
||||||
|
|
||||||
It converts the tensor values into signed float8 values and returns the
|
|
||||||
quantized tensor along with the scaling factor used for quantization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: The input tenosr with ndim >= 2.
|
|
||||||
group_size: The group size used for quantization.
|
|
||||||
eps: The minimum to avoid dividing zero.
|
|
||||||
dtype: The dype of output tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
|
|
||||||
"""
|
|
||||||
assert (
|
|
||||||
x.shape[-1] % group_size == 0
|
|
||||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
|
||||||
|
|
||||||
if dtype == torch.int8:
|
|
||||||
finfo = torch.iinfo(dtype)
|
|
||||||
else:
|
|
||||||
finfo = torch.finfo(dtype)
|
|
||||||
|
|
||||||
fp8_max = finfo.max
|
|
||||||
|
|
||||||
if _is_hip:
|
|
||||||
if dtype == torch.int8:
|
|
||||||
fp8_max = 127.0
|
|
||||||
else:
|
|
||||||
fp8_max = 224.0
|
|
||||||
|
|
||||||
fp8_min = -fp8_max
|
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
|
||||||
M = x.numel() // group_size
|
|
||||||
N = group_size
|
|
||||||
if column_major_scales:
|
|
||||||
if scale_tma_aligned:
|
|
||||||
# aligned to 4 * sizeof(float)
|
|
||||||
aligned_size = (x.shape[-2] + 3) // 4 * 4
|
|
||||||
x_s = torch.empty(
|
|
||||||
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
).permute(-1, -2)[: x.shape[-2], :]
|
|
||||||
else:
|
|
||||||
x_s = torch.empty(
|
|
||||||
(x.shape[-1] // group_size,) + x.shape[:-1],
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
).permute(-1, -2)
|
|
||||||
else:
|
|
||||||
x_s = torch.empty(
|
|
||||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
BLOCK = triton.next_power_of_2(N)
|
|
||||||
# heuristics for number of warps
|
|
||||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
|
||||||
num_stages = 1
|
|
||||||
if column_major_scales:
|
|
||||||
_per_token_group_quant_fp8_colmajor[(M,)](
|
|
||||||
x,
|
|
||||||
x_q,
|
|
||||||
x_s,
|
|
||||||
group_size,
|
|
||||||
x.shape[1],
|
|
||||||
x_s.stride(1),
|
|
||||||
eps,
|
|
||||||
fp8_min=fp8_min,
|
|
||||||
fp8_max=fp8_max,
|
|
||||||
BLOCK=BLOCK,
|
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=num_stages,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_per_token_group_quant_fp8[(M,)](
|
|
||||||
x,
|
|
||||||
x_q,
|
|
||||||
x_s,
|
|
||||||
group_size,
|
|
||||||
N,
|
|
||||||
eps,
|
|
||||||
fp8_min=fp8_min,
|
|
||||||
fp8_max=fp8_max,
|
|
||||||
BLOCK=BLOCK,
|
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=num_stages,
|
|
||||||
)
|
|
||||||
|
|
||||||
return x_q, x_s
|
|
||||||
|
|
||||||
|
|
||||||
def sglang_per_token_group_quant_8bit(
|
|
||||||
x: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
eps: float = 1e-10,
|
|
||||||
dtype: torch.dtype = fp8_type_,
|
|
||||||
column_major_scales: bool = False,
|
|
||||||
scale_tma_aligned: bool = False,
|
|
||||||
):
|
|
||||||
assert (
|
|
||||||
x.shape[-1] % group_size == 0
|
|
||||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
|
||||||
M = x.numel() // group_size
|
|
||||||
N = group_size
|
|
||||||
if column_major_scales:
|
|
||||||
if scale_tma_aligned:
|
|
||||||
# aligned to 4 * sizeof(float)
|
|
||||||
aligned_size = (x.shape[-2] + 3) // 4 * 4
|
|
||||||
x_s = torch.empty(
|
|
||||||
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
).permute(-1, -2)[: x.shape[-2], :]
|
|
||||||
else:
|
|
||||||
x_s = torch.empty(
|
|
||||||
(x.shape[-1] // group_size,) + x.shape[:-1],
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
).permute(-1, -2)
|
|
||||||
else:
|
|
||||||
x_s = torch.empty(
|
|
||||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
if dtype == torch.int8:
|
|
||||||
iinfo = torch.iinfo(dtype)
|
|
||||||
int8_max = iinfo.max
|
|
||||||
int8_min = iinfo.min
|
|
||||||
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
|
||||||
else:
|
|
||||||
f8_info = torch.finfo(dtype)
|
|
||||||
fp8_max = f8_info.max
|
|
||||||
fp8_min = f8_info.min
|
|
||||||
scale_ue8m0 = False # TODO also test true
|
|
||||||
sgl_per_token_group_quant_fp8(
|
|
||||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
|
||||||
)
|
|
||||||
|
|
||||||
return x_q, x_s
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"num_tokens, hidden_dim, group_size, dst_dtype, column_major_scales, scale_tma_aligned",
|
"num_tokens, hidden_dim, group_size, dst_dtype, flags",
|
||||||
list(
|
list(
|
||||||
itertools.product(
|
itertools.product(
|
||||||
[127, 128, 512, 1024, 4096, 8192], # num_tokens
|
[127, 128, 512, 1024, 4096, 8192], # num_tokens
|
||||||
[256, 512, 1024, 2048, 4096], # hidden_dim
|
[256, 512, 1024, 2048, 4096], # hidden_dim
|
||||||
[8, 16, 32, 64, 128], # group_size
|
[8, 16, 32, 64, 128], # group_size
|
||||||
[torch.int8, fp8_type_], # dtype
|
# TODO test int8
|
||||||
[False, True], # column_major_scales
|
[fp8_type_], # dtype
|
||||||
[False, True], # scale_tma_aligned
|
[
|
||||||
|
dict(
|
||||||
|
column_major_scales=False,
|
||||||
|
scale_tma_aligned=False,
|
||||||
|
scale_ue8m0=False,
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
column_major_scales=True,
|
||||||
|
scale_tma_aligned=False,
|
||||||
|
scale_ue8m0=False,
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
column_major_scales=True,
|
||||||
|
scale_tma_aligned=True,
|
||||||
|
scale_ue8m0=False,
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
column_major_scales=True,
|
||||||
|
scale_tma_aligned=True,
|
||||||
|
scale_ue8m0=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -281,37 +54,42 @@ def test_per_token_group_quant_with_column_major(
|
|||||||
hidden_dim,
|
hidden_dim,
|
||||||
group_size,
|
group_size,
|
||||||
dst_dtype,
|
dst_dtype,
|
||||||
column_major_scales,
|
flags,
|
||||||
scale_tma_aligned,
|
|
||||||
):
|
):
|
||||||
if not column_major_scales and scale_tma_aligned:
|
if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
|
||||||
|
pytest.skip()
|
||||||
|
return
|
||||||
|
if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
|
||||||
|
pytest.skip("scale_ue8m0 only supported on Blackwell")
|
||||||
return
|
return
|
||||||
|
|
||||||
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float16)
|
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
|
execute_kwargs = dict(
|
||||||
x,
|
x=x,
|
||||||
group_size,
|
group_size=group_size,
|
||||||
eps=1e-10,
|
eps=1e-10,
|
||||||
dtype=dst_dtype,
|
dst_dtype=dst_dtype,
|
||||||
column_major_scales=column_major_scales,
|
**flags,
|
||||||
scale_tma_aligned=scale_tma_aligned,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
|
x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
|
||||||
x,
|
x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
|
||||||
group_size,
|
|
||||||
eps=1e-10,
|
|
||||||
dtype=dst_dtype,
|
|
||||||
column_major_scales=column_major_scales,
|
|
||||||
scale_tma_aligned=scale_tma_aligned,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# torch.set_printoptions(profile="full")
|
||||||
|
# print(f"{x_q_triton=}")
|
||||||
|
# print(f"{x_s_triton=}")
|
||||||
|
# print(f"{x_q_sglang=}")
|
||||||
|
# print(f"{x_s_sglang=}")
|
||||||
|
# torch.set_printoptions(profile="default")
|
||||||
|
|
||||||
|
assert_fp8_all_close(x_q_triton, x_q_sglang)
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
|
x_s_triton.contiguous(),
|
||||||
)
|
x_s_sglang.contiguous(),
|
||||||
torch.testing.assert_close(
|
rtol=1e-3,
|
||||||
x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
|
atol=1e-5,
|
||||||
|
msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user