Create column-major and TMA-aligned x_scale for deep_gemm.gemm_fp8_fp8_bf16_nt (#4515)

Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
strgrb
2025-03-19 15:02:43 +08:00
committed by GitHub
parent 90532b7627
commit f9c53cbb42
3 changed files with 28 additions and 9 deletions

View File

@@ -168,6 +168,7 @@ def per_token_group_quant_fp8(
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
column_major_scales: bool = False,
scale_tma_aligned: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
@@ -200,11 +201,20 @@ def per_token_group_quant_fp8(
M = x.numel() // group_size
N = group_size
if column_major_scales:
x_s = torch.empty(
(x.shape[-1] // group_size,) + x.shape[:-1],
device=x.device,
dtype=torch.float32,
).permute(-1, -2)
if scale_tma_aligned:
# aligned to 4 * sizeof(float)
aligned_size = (x.shape[-2] + 3) // 4 * 4
x_s = torch.empty(
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
device=x.device,
dtype=torch.float32,
).permute(-1, -2)[: x.shape[-2], :]
else:
x_s = torch.empty(
(x.shape[-1] // group_size,) + x.shape[:-1],
device=x.device,
dtype=torch.float32,
).permute(-1, -2)
else:
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),

View File

@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple
import torch
from sglang.srt.layers.quantization.fp8_kernel import (
_enable_jit_deepgemm,
per_token_group_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
@@ -129,9 +130,17 @@ def apply_w8a8_block_fp8_linear(
)
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
if _enable_jit_deepgemm:
q_input, x_scale = per_token_group_quant_fp8(
input_2d,
block_size[1],
column_major_scales=True,
scale_tma_aligned=True,
)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
output = w8a8_block_fp8_matmul(
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
)