Apply sgl w8a8 fp8 kernel (#3148)

Author: HandH1998
Date: 2025-03-09 16:03:32 +08:00 (committed by GitHub)
Parent: 9fb48f951f
Commit: 0dd6cda288
13 changed files with 523 additions and 37 deletions

@@ -16,9 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
-    apply_fp8_linear,
     convert_to_channelwise,
-    cutlass_fp8_supported,
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
@@ -29,14 +27,21 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.parameter import (
+    BlockQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
-    BlockQuantScaleParameter,
+    apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
+    cutlass_fp8_supported,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.utils import (
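Among the names newly imported from fp8_utils, input_to_float8 is the per-tensor quantizer used further down: the whole tensor shares one scale derived from its absolute maximum. A minimal PyTorch sketch of that scheme, assuming the conventional amax-based recipe (the _sketch helper is illustrative, not sglang's implementation):

    import torch

    def input_to_float8_sketch(x: torch.Tensor, dtype=torch.float8_e4m3fn):
        # One scale for the entire tensor, derived from its absolute maximum.
        finfo = torch.finfo(dtype)
        amax = x.abs().max().clamp(min=1e-12)
        scale = finfo.max / amax
        # Scale into the representable fp8 range, then cast.
        x_q = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
        # Return the dequantization scale (the reciprocal), matching the
        # weight_scale convention used in the hunks below.
        return x_q, scale.reciprocal().float()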
@@ -305,15 +310,15 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-            # If using marlin (w8a16), kernel uses channelwise weights,
-            # so extend the weight scales to be channelwise.
-            if self.use_marlin:
-                assert weight_scale.numel() == 1
-                weight_scale = convert_to_channelwise(
-                    weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
+            if self.cutlass_fp8_supported or self.use_marlin:
+                # Apply per-channel quantization by default, as cutlass sgl-kernel and marlin only support per-channel scales.
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
                 )
+                weight_scale = weight_scale.t().contiguous()
+            else:
+                # per-tensor quantization
+                qweight, weight_scale = input_to_float8(layer.weight)
             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
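For intuition: the per-channel branch reuses per_token_group_quant_fp8 with the group size set to the full row width, so each output channel (each row of the [out_features, in_features] weight) gets exactly one scale, and the trailing transposes mirror the layout stored above (qweight.t(), weight_scale of shape [1, out]). A rough PyTorch equivalent, offered as a sketch rather than the actual fused kernel:

    import torch

    def per_channel_quant_fp8_sketch(weight: torch.Tensor):
        # weight: [out_features, in_features]; one scale per output channel.
        finfo = torch.finfo(torch.float8_e4m3fn)
        amax = weight.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)  # [out, 1]
        scale = amax.float() / finfo.max  # dequantization scale
        qweight = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
        # Mirror the layout above: transposed weight, scale transposed to [1, out].
        return qweight.t(), scale.t().contiguous()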
@@ -330,23 +335,19 @@
                 layer.input_scale = torch.nn.Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
-            # If using marlin (w8a16), kernel uses channelwise weights,
-            # so extend the weight scales to be channelwise.
-            if self.use_marlin:
+            # cutlass sgl-kernel and marlin only support per-channel scales
+            if self.cutlass_fp8_supported or self.use_marlin:
                 weight = layer.weight
                 weight_scale = convert_to_channelwise(
                     layer.weight_scale, layer.logical_widths
                 )
-            # If using w8a8, torch._scaled_mm needs per tensor, so
-            # requantize the logical shards as a single weight.
             else:
-                # Dequant -> Quant with max scale so we can run per tensor.
                 weight = layer.weight
                 weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if is_hip():
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
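The ROCm branch is unchanged apart from the is_hip() call: e4m3fn weights and scales are normalized to the e4m3fnuz variant used on AMD hardware. A sketch of what normalize_e4m3fn_to_e4m3fnuz does, based on the upstream helper of the same name (mechanics assumed, not copied verbatim):

    import torch

    def normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
        # The byte 0x80 encodes -0.0 in e4m3fn but NaN in e4m3fnuz, so zero
        # those bytes before reinterpreting the raw bits.
        as_int8 = weight.view(torch.int8).clone()
        as_int8[as_int8 == -128] = 0
        weight = as_int8.view(torch.float8_e4m3fnuz)
        # The same bit pattern represents half the value in e4m3fnuz
        # (exponent bias 8 vs. 7), so double the scales to compensate.
        weight_scale = weight_scale * 2.0
        if input_scale is not None:
            input_scale = input_scale * 2.0
        return weight, weight_scale, input_scale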