Create column-major and TMA-aligned x_scale for deep_gemm.gemm_fp8_fp8_bf16_nt (#4515)
Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
This commit is contained in:
@@ -168,6 +168,7 @@ def per_token_group_quant_fp8(
|
|||||||
eps: float = 1e-10,
|
eps: float = 1e-10,
|
||||||
dtype: torch.dtype = fp8_type_,
|
dtype: torch.dtype = fp8_type_,
|
||||||
column_major_scales: bool = False,
|
column_major_scales: bool = False,
|
||||||
|
scale_tma_aligned: bool = False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||||
|
|
||||||
@@ -200,6 +201,15 @@ def per_token_group_quant_fp8(
|
|||||||
M = x.numel() // group_size
|
M = x.numel() // group_size
|
||||||
N = group_size
|
N = group_size
|
||||||
if column_major_scales:
|
if column_major_scales:
|
||||||
|
if scale_tma_aligned:
|
||||||
|
# Round x.shape[-2] up to a multiple of 4 so the padded dimension spans a multiple of 4 * sizeof(float) = 16 bytes
|
||||||
|
aligned_size = (x.shape[-2] + 3) // 4 * 4
|
||||||
|
x_s = torch.empty(
|
||||||
|
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
|
||||||
|
device=x.device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
).permute(-1, -2)[: x.shape[-2], :]
|
||||||
|
else:
|
||||||
x_s = torch.empty(
|
x_s = torch.empty(
|
||||||
(x.shape[-1] // group_size,) + x.shape[:-1],
|
(x.shape[-1] // group_size,) + x.shape[:-1],
|
||||||
device=x.device,
|
device=x.device,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
_enable_jit_deepgemm,
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
static_quant_fp8,
|
static_quant_fp8,
|
||||||
w8a8_block_fp8_matmul,
|
w8a8_block_fp8_matmul,
|
||||||
@@ -128,6 +129,14 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
device=q_input.device,
|
device=q_input.device,
|
||||||
)
|
)
|
||||||
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
|
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
|
||||||
|
else:
|
||||||
|
if _enable_jit_deepgemm:
|
||||||
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
|
input_2d,
|
||||||
|
block_size[1],
|
||||||
|
column_major_scales=True,
|
||||||
|
scale_tma_aligned=True,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
q_input, x_scale = per_token_group_quant_fp8(
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
input_2d, block_size[1], column_major_scales=False
|
input_2d, block_size[1], column_major_scales=False
|
||||||
|
|||||||
2
sgl-kernel/3rdparty/deepgemm
vendored
2
sgl-kernel/3rdparty/deepgemm
vendored
Submodule sgl-kernel/3rdparty/deepgemm updated: bd2a775528...3b3783d06c
Reference in New Issue
Block a user