diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 210a24f69..a157ebc3e 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
 
 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
 
@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         else:
             weight_scale = layer.weight_scale.data
 
-        layer.weight = Parameter(weight.t(), requires_grad=False)
+        if _use_aiter:
+            layer.weight = Parameter(
+                shuffle_weight(weight, (16, 16)), requires_grad=False
+            )
+        else:
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+
         # required by torch.compile to be torch.nn.Parameter
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
 
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index e4bcbe23c..998423b86 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
@@ -642,25 +642,49 @@ def apply_fp8_linear(
             use_per_token_if_dynamic
             and not per_tensor_weights
             and not per_tensor_activations
-            and USE_ROWWISE_TORCH_SCALED_MM
+            and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
         ):
-            # For now validated on ROCm platform
-            # fp8 rowwise scaling in torch._scaled_mm is introduced in
-            # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
-            # and ROCm 6.3, which only exists in torch 2.7 and above.
-            # For CUDA platform please validate if the
-            # torch._scaled_mm support rowwise scaled GEMM
-            # Fused GEMM_DQ Rowwise GEMM
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale.t(),
-                bias=bias,
-            )
-            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
-
+            # Reaching this branch means dynamic per-token-per-channel quantization:
+            # per-token scales for the input matrix, one scale factor per row (token);
+            # per-channel scales for the weight matrix, one scale factor per column (channel).
+            if _use_aiter:
+                # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+                # XQ -> input tensor, shape = (m, k)
+                # WQ -> weight tensor, shape = (n, k), preshuffled for better performance
+                # x_scale -> input scale tensor, shape = (m, 1)
+                # w_scale -> weight scale tensor, shape = (n, 1)
+                # dtype -> output dtype
+                output = gemm_a8w8_bpreshuffle(
+                    XQ=qinput,
+                    WQ=weight,
+                    x_scale=x_scale,
+                    w_scale=weight_scale,
+                    dtype=input.dtype,
+                )
+                if bias is not None:
+                    output += bias
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+                )
+            else:
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if the
+                # torch._scaled_mm support rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, output_shape
+                )
         else:
            # Fallback for channelwise case, where we use unfused DQ
            # due to limitations with scaled_mm
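Usage note (illustrative, not part of the patch): a minimal sketch of the aiter fast path wired up above, assuming a ROCm build with aiter installed and a GPU that supports the fp8 e4m3fnuz format. The tensor sizes and the naive abs-max quantization are made up for illustration; only the shuffle_weight and gemm_a8w8_bpreshuffle calls and their shape conventions mirror the patched code.

    import torch
    from aiter import gemm_a8w8_bpreshuffle
    from aiter.ops.shuffle import shuffle_weight

    m, k, n = 32, 4096, 8192  # tokens, input features, output features (made-up sizes)

    # Weight side: per-channel scales of shape (n, 1); the (n, k) fp8 weight is
    # preshuffled once at load time, as in process_weights_after_loading above.
    w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
    w_scale = w.abs().amax(dim=1, keepdim=True).float() / 240.0  # 240 = e4m3fnuz max
    wq = (w / w_scale).to(torch.float8_e4m3fnuz)
    wq = shuffle_weight(wq, (16, 16))

    # Activation side: per-token scales of shape (m, 1) computed at runtime.
    x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
    x_scale = x.abs().amax(dim=1, keepdim=True).float() / 240.0
    xq = (x / x_scale).to(torch.float8_e4m3fnuz)

    # Fused fp8 GEMM which, like the torch._scaled_mm rowwise path it parallels,
    # should produce the dequantized (m, n) result in the requested dtype.
    out = gemm_a8w8_bpreshuffle(XQ=xq, WQ=wq, x_scale=x_scale, w_scale=w_scale, dtype=torch.bfloat16)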