diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index e9df65a15..77ab92aff 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -341,6 +341,46 @@ def create_per_token_group_quant_fp8_output_scale( ) +# TODO maybe unify int8 and fp8 code later +def per_token_group_quant_8bit( + x: torch.Tensor, + group_size: int, + dst_dtype: torch.dtype, + eps: float = 1e-10, + column_major_scales: bool = False, + scale_tma_aligned: bool = False, + scale_ue8m0: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize `x` per token group to int8 or fp8, chosen by `dst_dtype`. + + `torch.int8` is routed to `per_token_group_quant_int8`, which receives none + of `column_major_scales`, `scale_tma_aligned`, `scale_ue8m0`, so all three + must be False. Any other `dst_dtype` is routed to + `per_token_group_quant_fp8`, which is not passed `dst_dtype`. + """ + from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8 + + if dst_dtype == torch.int8: + assert not column_major_scales + assert not scale_tma_aligned + assert not scale_ue8m0 + return per_token_group_quant_int8( + x=x, + group_size=group_size, + eps=eps, + dtype=dst_dtype, + ) + + return per_token_group_quant_fp8( + x=x, + group_size=group_size, + eps=eps, + column_major_scales=column_major_scales, + scale_tma_aligned=scale_tma_aligned, + scale_ue8m0=scale_ue8m0, + ) + + def sglang_per_token_group_quant_fp8( x: torch.Tensor, group_size: int, @@ -372,6 +412,48 @@ def sglang_per_token_group_quant_fp8( return x_q, x_s +# TODO maybe unify int8 and fp8 code later +def sglang_per_token_group_quant_8bit( + x: torch.Tensor, + group_size: int, + dst_dtype: torch.dtype, + eps: float = 1e-10, + column_major_scales: bool = False, + scale_tma_aligned: bool = False, + scale_ue8m0: bool = False, +): + """Quantize `x` per token group to int8 or fp8, chosen by `dst_dtype`. + + `torch.int8` is routed to `sglang_per_token_group_quant_int8`, which + receives none of `column_major_scales`, `scale_tma_aligned`, `scale_ue8m0`, + so all three must be False. Any other `dst_dtype` is routed to + `sglang_per_token_group_quant_fp8`, which is not passed `dst_dtype`. + """ + from sglang.srt.layers.quantization.int8_kernel import ( + sglang_per_token_group_quant_int8, + ) + + if dst_dtype == torch.int8: + assert not column_major_scales + assert not scale_tma_aligned + assert not scale_ue8m0 + return sglang_per_token_group_quant_int8( + x=x, + group_size=group_size, + eps=eps, + dtype=dst_dtype, + ) + + return sglang_per_token_group_quant_fp8( + x=x, + group_size=group_size, + eps=eps, +
column_major_scales=column_major_scales, + scale_tma_aligned=scale_tma_aligned, + scale_ue8m0=scale_ue8m0, + ) + + def sglang_per_token_quant_fp8( x: torch.Tensor, dtype: torch.dtype = fp8_dtype, diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index a7be39141..df434ae0a 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -176,6 +176,36 @@ def replace_parameter( mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False)) +def assert_fp8_all_close(a: torch.Tensor, b: torch.Tensor): + """Assert two fp8 tensors match up to rare one-step byte differences. + + Both tensors are viewed as uint8. The assertion fails if any pair of + bytes differs in the sign bit or by 2 or more, or if 0.5% or more of the + elements differ at all. + """ + assert a.shape == b.shape + assert a.dtype == b.dtype + assert a.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) + + a_u8 = a.view(torch.uint8) + b_u8 = b.view(torch.uint8) + diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs() + + numel = a.numel() + + # The sign is the top bit of each byte. uint8 values are never negative, + # so it is compared via XOR rather than with `< 0`. + count_diff_sign = ((a_u8 ^ b_u8) >= 0x80).sum().item() + count_tiny_diff = (diff_u8 >= 1).sum().item() + count_large_diff = (diff_u8 >= 2).sum().item() + + assert ( + (count_diff_sign == 0) + and (numel == 0 or count_tiny_diff / numel < 0.005) + and (count_large_diff == 0) + ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=}" + + # Match dynamic rules with module name (prefix) and override quantize # config if module (prefix) matches a rule def override_config(config: QuantizationConfig, prefix: str): diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py index 5a9248982..3f37a3248 100644 --- a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py @@ -1,189 +1,68 @@ import itertools -from typing import Tuple +import time +from functools import partial +from pathlib import Path import torch import triton -import triton.language as tl -from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8 +from sglang.srt.bench_utils import bench_kineto +from
sglang.srt.layers.quantization.fp8_kernel import ( + create_per_token_group_quant_fp8_output_scale, +) +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_8bit as triton_per_token_group_quant_8bit, +) +from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit from sglang.srt.utils import is_hip _is_hip = is_hip() fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn -@triton.jit -def _per_token_group_quant_8bit( - # Pointers to inputs and output - y_ptr, - y_q_ptr, - y_s_ptr, - # Stride of input - y_stride, - # Columns of input - N, - # Avoid to divide zero - eps, - # Information for 8bit data type (int8 or fp8_type_) - max_8bit, - min_8bit, - # Meta-parameters - BLOCK: tl.constexpr, -): - """A Triton-accelerated function to perform per-token-group quantization on a - tensor. - This function converts the tensor values into 8bit values. - """ - # Map the program id to the row of X and Y it should compute. - g_id = tl.program_id(0) - y_ptr += g_id * y_stride - y_q_ptr += g_id * y_stride - y_s_ptr += g_id - - cols = tl.arange(0, BLOCK) # N <= BLOCK - mask = cols < N - - y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) - # Quant - _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / max_8bit - y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty) - - tl.store(y_q_ptr + cols, y_q, mask=mask) - tl.store(y_s_ptr, y_s) - - -def triton_per_token_group_quant_8bit( - x: torch.Tensor, - group_size: int, - dst_dtype: torch.dtype, - eps: float = 1e-10, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Function to perform per-token-group quantization on an input tensor `x`. - It converts the tensor values into signed float8 values and returns the - quantized tensor along with the scaling factor used for quantization. - Args: - x: The input tenosr with ndim >= 2. - group_size: The group size used for quantization. - eps: The minimum to avoid dividing zero. 
- dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. - Returns: - Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. - """ - assert ( - x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" - assert x.is_contiguous(), "`x` is not contiguous" - - if dst_dtype == torch.int8: - iinfo = torch.iinfo(dst_dtype) - max_8bit = iinfo.max - min_8bit = iinfo.min - else: - finfo = torch.finfo(dst_dtype) - max_8bit = finfo.max - min_8bit = finfo.min - - x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype) - M = x.numel() // group_size - N = group_size - x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size,), - device=x.device, - dtype=torch.float32, - ) - - BLOCK = triton.next_power_of_2(N) - # heuristics for number of warps - num_warps = min(max(BLOCK // 256, 1), 8) - num_stages = 1 - _per_token_group_quant_8bit[(M,)]( - x, - x_q, - x_s, - group_size, - N, - eps, - max_8bit, - min_8bit, - BLOCK=BLOCK, - num_warps=num_warps, - num_stages=num_stages, - ) - - return x_q, x_s - - -def sglang_per_token_group_quant_8bit( - x: torch.Tensor, - group_size: int, - dst_dtype: torch.dtype, - eps: float = 1e-10, -): - assert ( - x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" - assert x.is_contiguous(), "`x` is not contiguous" - - x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype) - x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size,), - device=x.device, - dtype=torch.float32, - ) - - if dst_dtype == torch.int8: - iinfo = torch.iinfo(dst_dtype) - int8_max = iinfo.max - int8_min = iinfo.min - sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max) - else: - f8_info = torch.finfo(dst_dtype) - fp8_max = f8_info.max - fp8_min = f8_info.min - sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) - - return x_q, x_s - - -def 
calculate_diff(batch_size, seq_len, group_size, dst_dtype): - device = torch.device("cuda") - hidden_dim = 7168 - - x = torch.randn( - batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16 - ) - - x_q_triton, x_s_triton = triton_per_token_group_quant_8bit( - x.clone(), group_size, dst_dtype - ) - x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit( - x.clone(), group_size, dst_dtype - ) - - if torch.allclose( - x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5 - ) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5): - print(f"✅ {dst_dtype} implementations match") - else: - print("❌ Implementations differ") - - -batch_size_range = [1, 2, 4, 8, 16, 32, 64] -seq_len_range = [64, 128, 256, 512, 1024, 2048] +num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384] +hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1 group_size_range = [128] # For DeepSeek V3/R1 -dst_dtype_range = [torch.int8, fp8_type_] +# TODO test int8 +dst_dtype_range = [fp8_type_] +flags_range = [ + dict( + column_major_scales=False, + scale_tma_aligned=False, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=False, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + ), +] + configs = list( itertools.product( - batch_size_range, seq_len_range, group_size_range, dst_dtype_range + num_tokens_range, + hidden_dim_range, + group_size_range, + dst_dtype_range, + flags_range, ) ) @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["batch_size", "seq_len", "group_size", "dst_dtype"], + x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"], x_vals=configs, line_arg="provider", line_vals=["triton", "sglang"], @@ -194,29 +73,26 @@ configs = list( args={}, ) ) -def benchmark(batch_size, seq_len, group_size, dst_dtype, 
provider): +def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider): + if flags["scale_ue8m0"] and group_size != 128: + return + device = torch.device("cuda") - hidden_dim = 7168 - x = torch.randn( - batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16 - ) + x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) - quantiles = [0.5, 0.2, 0.8] + fn, kernel_names = { + "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"), + "sglang": ( + sglang_per_token_group_quant_8bit, + "per_token_group_quant_8bit_kernel", + ), + }[provider] + bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags) - if provider == "triton": - fn = lambda: triton_per_token_group_quant_8bit(x, group_size, dst_dtype) - elif provider == "sglang": - fn = lambda: sglang_per_token_group_quant_8bit(x, group_size, dst_dtype) - - ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) - - return 1000 * ms, 1000 * max_ms, 1000 * min_ms + time_s = bench_kineto(bench_fn, kernel_names=kernel_names) + return time_s * 1e6 if __name__ == "__main__": - - calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=torch.int8) - calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=fp8_type_) - benchmark.run(print_data=True) diff --git a/sgl-kernel/tests/test_per_token_group_quant_8bit.py b/sgl-kernel/tests/test_per_token_group_quant_8bit.py index 31070d1cd..778d14d31 100644 --- a/sgl-kernel/tests/test_per_token_group_quant_8bit.py +++ b/sgl-kernel/tests/test_per_token_group_quant_8bit.py @@ -1,278 +1,51 @@ import itertools -from typing import Tuple import pytest import torch -import triton -import triton.language as tl -from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8 +from sglang.srt.layers.quantization import deep_gemm_wrapper +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_8bit as 
triton_per_token_group_quant_8bit, +) +from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit +from sglang.srt.layers.quantization.utils import assert_fp8_all_close from sglang.srt.utils import is_hip _is_hip = is_hip() fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn -@triton.jit -def _per_token_group_quant_fp8( - # Pointers to inputs and output - y_ptr, - y_q_ptr, - y_s_ptr, - # Stride of input - y_stride, - # Columns of input - N, - # Avoid to divide zero - eps, - # Information for float8 - fp8_min, - fp8_max, - # Meta-parameters - BLOCK: tl.constexpr, -): - """A Triton-accelerated function to perform per-token-group quantization on a - tensor. - - This function converts the tensor values into float8 values. - """ - # Map the program id to the row of X and Y it should compute. - g_id = tl.program_id(0) - y_ptr += g_id * y_stride - y_q_ptr += g_id * y_stride - y_s_ptr += g_id - - cols = tl.arange(0, BLOCK) # N <= BLOCK - mask = cols < N - - y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) - # Quant - _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / fp8_max - y_s_inv = 1.0 / y_s - y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) - - tl.store(y_q_ptr + cols, y_q, mask=mask) - tl.store(y_s_ptr, y_s) - - -@triton.jit -def _per_token_group_quant_fp8_colmajor( - # Pointers to inputs and output - y_ptr, - y_q_ptr, - y_s_ptr, - group_size, - # Num columns of y - y_num_columns, - # Stride from one column to the next of y_s - y_s_col_stride, - # Avoid to divide zero - eps, - # Information for float8 - fp8_min, - fp8_max, - # Meta-parameters - BLOCK: tl.constexpr, -): - """A Triton-accelerated function to perform per-token-group - quantization on a tensor. - This function converts the tensor values into float8 values. - """ - # Map the program id to the row of X and Y it should compute. 
- g_id = tl.program_id(0) - y_ptr += g_id * group_size - y_q_ptr += g_id * group_size - - # Convert g_id the flattened block coordinate to 2D so we can index - # into the output y_scales matrix - blocks_per_row = y_num_columns // group_size - scale_col = g_id % blocks_per_row - scale_row = g_id // blocks_per_row - y_s_ptr += scale_col * y_s_col_stride + scale_row - - cols = tl.arange(0, BLOCK) # group_size <= BLOCK - mask = cols < group_size - - y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) - # Quant - _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / fp8_max - y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) - - tl.store(y_q_ptr + cols, y_q, mask=mask) - tl.store(y_s_ptr, y_s) - - -def triton_per_token_group_quant_8bit( - x: torch.Tensor, - group_size: int, - eps: float = 1e-10, - dtype: torch.dtype = fp8_type_, - column_major_scales: bool = False, - scale_tma_aligned: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Function to perform per-token-group quantization on an input tensor `x`. - - It converts the tensor values into signed float8 values and returns the - quantized tensor along with the scaling factor used for quantization. - - Args: - x: The input tenosr with ndim >= 2. - group_size: The group size used for quantization. - eps: The minimum to avoid dividing zero. - dtype: The dype of output tensor. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. 
- """ - assert ( - x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" - assert x.is_contiguous(), "`x` is not contiguous" - - if dtype == torch.int8: - finfo = torch.iinfo(dtype) - else: - finfo = torch.finfo(dtype) - - fp8_max = finfo.max - - if _is_hip: - if dtype == torch.int8: - fp8_max = 127.0 - else: - fp8_max = 224.0 - - fp8_min = -fp8_max - - x_q = torch.empty_like(x, device=x.device, dtype=dtype) - M = x.numel() // group_size - N = group_size - if column_major_scales: - if scale_tma_aligned: - # aligned to 4 * sizeof(float) - aligned_size = (x.shape[-2] + 3) // 4 * 4 - x_s = torch.empty( - x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), - device=x.device, - dtype=torch.float32, - ).permute(-1, -2)[: x.shape[-2], :] - else: - x_s = torch.empty( - (x.shape[-1] // group_size,) + x.shape[:-1], - device=x.device, - dtype=torch.float32, - ).permute(-1, -2) - else: - x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size,), - device=x.device, - dtype=torch.float32, - ) - - BLOCK = triton.next_power_of_2(N) - # heuristics for number of warps - num_warps = min(max(BLOCK // 256, 1), 8) - num_stages = 1 - if column_major_scales: - _per_token_group_quant_fp8_colmajor[(M,)]( - x, - x_q, - x_s, - group_size, - x.shape[1], - x_s.stride(1), - eps, - fp8_min=fp8_min, - fp8_max=fp8_max, - BLOCK=BLOCK, - num_warps=num_warps, - num_stages=num_stages, - ) - else: - _per_token_group_quant_fp8[(M,)]( - x, - x_q, - x_s, - group_size, - N, - eps, - fp8_min=fp8_min, - fp8_max=fp8_max, - BLOCK=BLOCK, - num_warps=num_warps, - num_stages=num_stages, - ) - - return x_q, x_s - - -def sglang_per_token_group_quant_8bit( - x: torch.Tensor, - group_size: int, - eps: float = 1e-10, - dtype: torch.dtype = fp8_type_, - column_major_scales: bool = False, - scale_tma_aligned: bool = False, -): - assert ( - x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" - assert x.is_contiguous(), 
"`x` is not contiguous" - - x_q = torch.empty_like(x, device=x.device, dtype=dtype) - M = x.numel() // group_size - N = group_size - if column_major_scales: - if scale_tma_aligned: - # aligned to 4 * sizeof(float) - aligned_size = (x.shape[-2] + 3) // 4 * 4 - x_s = torch.empty( - x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), - device=x.device, - dtype=torch.float32, - ).permute(-1, -2)[: x.shape[-2], :] - else: - x_s = torch.empty( - (x.shape[-1] // group_size,) + x.shape[:-1], - device=x.device, - dtype=torch.float32, - ).permute(-1, -2) - else: - x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size,), - device=x.device, - dtype=torch.float32, - ) - - if dtype == torch.int8: - iinfo = torch.iinfo(dtype) - int8_max = iinfo.max - int8_min = iinfo.min - sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max) - else: - f8_info = torch.finfo(dtype) - fp8_max = f8_info.max - fp8_min = f8_info.min - scale_ue8m0 = False # TODO also test true - sgl_per_token_group_quant_fp8( - x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0 - ) - - return x_q, x_s - - @pytest.mark.parametrize( - "num_tokens, hidden_dim, group_size, dst_dtype, column_major_scales, scale_tma_aligned", + "num_tokens, hidden_dim, group_size, dst_dtype, flags", list( itertools.product( [127, 128, 512, 1024, 4096, 8192], # num_tokens [256, 512, 1024, 2048, 4096], # hidden_dim [8, 16, 32, 64, 128], # group_size - [torch.int8, fp8_type_], # dtype - [False, True], # column_major_scales - [False, True], # scale_tma_aligned + # TODO test int8 + [fp8_type_], # dtype + [ + dict( + column_major_scales=False, + scale_tma_aligned=False, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=False, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=False, + ), + dict( + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=True, + ), + ], ) ), ) @@ -281,37 +54,42 @@ def 
test_per_token_group_quant_with_column_major( hidden_dim, group_size, dst_dtype, - column_major_scales, - scale_tma_aligned, + flags, ): - if not column_major_scales and scale_tma_aligned: + if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)): + pytest.skip() + return + if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL: + pytest.skip("scale_ue8m0 only supported on Blackwell") return - x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float16) + x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) - x_q_triton, x_s_triton = triton_per_token_group_quant_8bit( - x, - group_size, + execute_kwargs = dict( + x=x, + group_size=group_size, eps=1e-10, - dtype=dst_dtype, - column_major_scales=column_major_scales, - scale_tma_aligned=scale_tma_aligned, + dst_dtype=dst_dtype, + **flags, ) - x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit( - x, - group_size, - eps=1e-10, - dtype=dst_dtype, - column_major_scales=column_major_scales, - scale_tma_aligned=scale_tma_aligned, - ) + x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs) + x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs) + # torch.set_printoptions(profile="full") + # print(f"{x_q_triton=}") + # print(f"{x_s_triton=}") + # print(f"{x_q_sglang=}") + # print(f"{x_s_sglang=}") + # torch.set_printoptions(profile="default") + + assert_fp8_all_close(x_q_triton, x_q_sglang) torch.testing.assert_close( - x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5 - ) - torch.testing.assert_close( - x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5 + x_s_triton.contiguous(), + x_s_sglang.contiguous(), + rtol=1e-3, + atol=1e-5, + msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}", )