enable aiter gemm_a8w8_bpreshuffle for ptpc gemm (#8555)
@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
 from sglang.srt.utils import get_bool_env_var, is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
 
 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
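The module-level gate introduced above is opt-in: the aiter path is taken only when the SGLANG_USE_AITER environment variable is truthy and the build targets ROCm (is_hip()). A minimal standalone sketch of the same check, assuming get_bool_env_var treats "1"/"true" (case-insensitive) as truthy:

    import os

    def aiter_enabled(on_rocm: bool) -> bool:
        # Mirrors the _use_aiter gate above: aiter kernels are considered
        # only on ROCm builds with SGLANG_USE_AITER explicitly set.
        flag = os.environ.get("SGLANG_USE_AITER", "false").strip().lower()
        return flag in ("1", "true") and on_rocm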
@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         else:
             weight_scale = layer.weight_scale.data
 
-        layer.weight = Parameter(weight.t(), requires_grad=False)
+        if _use_aiter:
+            layer.weight = Parameter(
+                shuffle_weight(weight, (16, 16)), requires_grad=False
+            )
+        else:
+            layer.weight = Parameter(weight.t(), requires_grad=False)
 
         # required by torch.compile to be torch.nn.Parameter
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
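This load-time change fixes the layout contract that the runtime GEMM (fourth hunk below) relies on: with aiter enabled, the fp8 weight keeps its (n, k) orientation but is reordered by shuffle_weight(weight, (16, 16)) into aiter's 16x16-tiled layout for gemm_a8w8_bpreshuffle; otherwise the weight is stored transposed as (k, n), the B-operand layout torch._scaled_mm expects. A condensed sketch of that decision, with shuffle_weight treated as an opaque callable from aiter.ops.shuffle:

    import torch
    from torch.nn import Parameter

    def store_fp8_weight(layer, weight, use_aiter, shuffle_weight=None):
        # weight arrives as (n, k) fp8
        if use_aiter:
            # reordered into aiter's 16x16 tile order (layout details are aiter-internal)
            layer.weight = Parameter(shuffle_weight(weight, (16, 16)), requires_grad=False)
        else:
            # transposed (k, n) for torch._scaled_mm
            layer.weight = Parameter(weight.t(), requires_grad=False)

Because both the stored layout and the apply path branch on the same _use_aiter flag, the GEMM always sees the layout it expects.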
@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
@@ -642,25 +642,49 @@ def apply_fp8_linear(
             use_per_token_if_dynamic
             and not per_tensor_weights
            and not per_tensor_activations
-            and USE_ROWWISE_TORCH_SCALED_MM
+            and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
         ):
-            # For now validated on ROCm platform
-            # fp8 rowwise scaling in torch._scaled_mm is introduced in
-            # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
-            # and ROCm 6.3, which only exists in torch 2.7 and above.
-            # For CUDA platform please validate if the
-            # torch._scaled_mm support rowwise scaled GEMM
-            # Fused GEMM_DQ Rowwise GEMM
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale.t(),
-                bias=bias,
-            )
-            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+            # Reaching this branch means dynamic per-token-per-channel quant is used:
+            # per-token scales for the input matrix, one scale factor per row (token);
+            # per-channel scales for the weight matrix, one scale factor per column (channel).
+            if _use_aiter:
+                # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+                # XQ -> input tensor, shape = (m, k)
+                # WQ -> weight tensor, shape = (n, k), preshuffled for better perf
+                # x_scale -> input scale tensor, shape = (m, 1)
+                # w_scale -> weight scale tensor, shape = (n, 1)
+                # dtype -> output dtype
+                output = gemm_a8w8_bpreshuffle(
+                    XQ=qinput,
+                    WQ=weight,
+                    x_scale=x_scale,
+                    w_scale=weight_scale,
+                    dtype=input.dtype,
+                )
+                if bias is not None:
+                    output += bias
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+                )
+            else:
+                # For now validated on the ROCm platform.
+                # fp8 rowwise scaling in torch._scaled_mm was introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For the CUDA platform, please validate whether
+                # torch._scaled_mm supports rowwise scaled GEMM.
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, output_shape
+                )
         else:
             # Fallback for channelwise case, where we use unfused DQ
             # due to limitations with scaled_mm
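For intuition, the following unfused PyTorch reference mirrors what the fused aiter call computes under the shape contract documented in the comments above (XQ: (m, k), x_scale: (m, 1), WQ: (n, k), w_scale: (n, 1)). It is a sketch, not an sglang or aiter API: the quantization helpers are hypothetical, the sketch uses float8_e4m3fn (ROCm builds may use the e4m3fnuz variant instead), and it ignores the preshuffled memory layout, which affects performance but not the math:

    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

    def quant_per_token(x):
        # one scale per row (token) of the (m, k) activation
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
        return (x / scale).to(torch.float8_e4m3fn), scale  # (m, k) fp8, (m, 1) fp32

    def quant_per_channel(w):
        # one scale per output channel, i.e. per row of the (n, k) weight
        scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
        return (w / scale).to(torch.float8_e4m3fn), scale  # (n, k) fp8, (n, 1) fp32

    def gemm_ptpc_reference(xq, wq, x_scale, w_scale, dtype, bias=None):
        # unfused equivalent: upcast, GEMM against WQ^T, then rescale
        acc = xq.to(torch.float32) @ wq.to(torch.float32).t()  # (m, n)
        out = acc * x_scale * w_scale.t()                      # broadcast (m, 1) * (1, n)
        if bias is not None:
            out = out + bias
        return out.to(dtype)

    x, w = torch.randn(4, 64), torch.randn(8, 64)
    xq, xs = quant_per_token(x)
    wq, ws = quant_per_channel(w)
    out = gemm_ptpc_reference(xq, wq, xs, ws, torch.bfloat16)  # (4, 8)

Note that the bias handling matches the aiter branch above: gemm_a8w8_bpreshuffle takes no bias argument, so the bias is added to the output afterwards, whereas torch._scaled_mm fuses it via its bias parameter.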