From f9c53cbb42444e68e9b32d8b9c98bbac5d8091d2 Mon Sep 17 00:00:00 2001 From: strgrb Date: Wed, 19 Mar 2025 15:02:43 +0800 Subject: [PATCH] Create col-major and tma-aligned x_scale for deep_gemm.gemm_fp8_fp8_bf16_nt (#4515) Co-authored-by: Zhang Kaihong --- .../srt/layers/quantization/fp8_kernel.py | 20 ++++++++++++++----- .../srt/layers/quantization/fp8_utils.py | 15 +++++++++++--- sgl-kernel/3rdparty/deepgemm | 2 +- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 3c60baf0a..e25c7c333 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -168,6 +168,7 @@ def per_token_group_quant_fp8( eps: float = 1e-10, dtype: torch.dtype = fp8_type_, column_major_scales: bool = False, + scale_tma_aligned: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. @@ -200,11 +201,20 @@ def per_token_group_quant_fp8( M = x.numel() // group_size N = group_size if column_major_scales: - x_s = torch.empty( - (x.shape[-1] // group_size,) + x.shape[:-1], - device=x.device, - dtype=torch.float32, - ).permute(-1, -2) + if scale_tma_aligned: + # aligned to 4 * sizeof(float) + aligned_size = (x.shape[-2] + 3) // 4 * 4 + x_s = torch.empty( + x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), + device=x.device, + dtype=torch.float32, + ).permute(-1, -2)[: x.shape[-2], :] + else: + x_s = torch.empty( + (x.shape[-1] // group_size,) + x.shape[:-1], + device=x.device, + dtype=torch.float32, + ).permute(-1, -2) else: x_s = torch.empty( x.shape[:-1] + (x.shape[-1] // group_size,), diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index fa69a1ffa..4b790d3ba 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -3,6 +3,7 @@ from typing import List, Optional, Tuple import torch from sglang.srt.layers.quantization.fp8_kernel import ( + _enable_jit_deepgemm, per_token_group_quant_fp8, static_quant_fp8, w8a8_block_fp8_matmul, @@ -129,9 +130,17 @@ def apply_w8a8_block_fp8_linear( ) gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output) else: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=False - ) + if _enable_jit_deepgemm: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, + block_size[1], + column_major_scales=True, + scale_tma_aligned=True, + ) + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=False + ) output = w8a8_block_fp8_matmul( q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype ) diff --git a/sgl-kernel/3rdparty/deepgemm b/sgl-kernel/3rdparty/deepgemm index bd2a77552..3b3783d06 160000 --- a/sgl-kernel/3rdparty/deepgemm +++ b/sgl-kernel/3rdparty/deepgemm @@ -1 +1 @@ -Subproject commit bd2a77552886b98c205af12f8d7d2d61247c4b27 +Subproject commit 3b3783d06cd4d06ac4ba048633e604151d1ee535