Create col-major and tma-aligned x_scale for deep_gemm.gemm_fp8_fp8_bf16_nt (#4515)
Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
@@ -168,6 +168,7 @@ def per_token_group_quant_fp8(
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype = fp8_type_,
|
||||
column_major_scales: bool = False,
|
||||
scale_tma_aligned: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
|
||||
@@ -200,11 +201,20 @@ def per_token_group_quant_fp8(
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
if column_major_scales:
|
||||
x_s = torch.empty(
|
||||
(x.shape[-1] // group_size,) + x.shape[:-1],
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
).permute(-1, -2)
|
||||
if scale_tma_aligned:
|
||||
# aligned to 4 * sizeof(float)
|
||||
aligned_size = (x.shape[-2] + 3) // 4 * 4
|
||||
x_s = torch.empty(
|
||||
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
).permute(-1, -2)[: x.shape[-2], :]
|
||||
else:
|
||||
x_s = torch.empty(
|
||||
(x.shape[-1] // group_size,) + x.shape[:-1],
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
).permute(-1, -2)
|
||||
else:
|
||||
x_s = torch.empty(
|
||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
_enable_jit_deepgemm,
|
||||
per_token_group_quant_fp8,
|
||||
static_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
@@ -129,9 +130,17 @@ def apply_w8a8_block_fp8_linear(
|
||||
)
|
||||
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=False
|
||||
)
|
||||
if _enable_jit_deepgemm:
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
)
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=False
|
||||
)
|
||||
output = w8a8_block_fp8_matmul(
|
||||
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user