Apply sgl w8a8 fp8 kernel (#3148)
@@ -16,9 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
-    apply_fp8_linear,
     convert_to_channelwise,
-    cutlass_fp8_supported,
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
@@ -29,14 +27,21 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.parameter import (
+    BlockQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
-    BlockQuantScaleParameter,
+    apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
+    cutlass_fp8_supported,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.utils import (
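Two of the newly imported helpers carry the quantization logic in the hunks below: per_token_group_quant_fp8 (a fused kernel in sglang's fp8_kernel module) and input_to_float8 (plain per-tensor quantization). As a rough, hedged sketch of the latter's contract in plain PyTorch, not the sglang source, assuming a torch build with float8_e4m3fn support:

import torch

def input_to_float8_sketch(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # One scale for the whole tensor: map the largest |x| to the top of
    # the fp8 range, saturate, and cast.
    finfo = torch.finfo(dtype)
    amax = x.abs().max().clamp(min=1e-12)
    scale = finfo.max / amax
    x_q = (x.float() * scale).clamp(finfo.min, finfo.max).to(dtype)
    # Return the dequantization scale (reciprocal of the quant scale),
    # which is what downstream matmuls multiply by.
    return x_q, scale.reciprocal()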
@@ -305,15 +310,15 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If the checkpoint is not serialized as fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-
-            # If using marlin (w8a16), the kernel uses channelwise weights,
-            # so extend the weight scales to be channelwise.
-            if self.use_marlin:
-                assert weight_scale.numel() == 1
-                weight_scale = convert_to_channelwise(
-                    weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
-                )
+            if self.cutlass_fp8_supported or self.use_marlin:
+                # Apply per-channel quantization by default, as the cutlass
+                # sgl-kernel and marlin only support per-channel scales.
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
+                )
+                weight_scale = weight_scale.t().contiguous()
+            else:
+                # Per-tensor quantization.
+                qweight, weight_scale = input_to_float8(layer.weight)

             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
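The shape trick in this hunk deserves a note: per_token_group_quant_fp8 is called with the group size set to layer.weight.shape[-1], so each output row forms a single group and the result is one FP8 scale per output channel, which is what the cutlass sgl-kernel and marlin accept. A minimal plain-PyTorch sketch of that behavior follows; per_channel_quant_fp8 is a hypothetical stand-in, not the fused kernel itself:

import torch

def per_channel_quant_fp8(weight: torch.Tensor, dtype=torch.float8_e4m3fn):
    # With group size == K, per-group quantization of an (N, K) weight
    # collapses to per-channel: each of the N rows gets one scale.
    finfo = torch.finfo(dtype)
    amax = weight.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12)
    scale = amax / finfo.max  # (N, 1) dequantization scales
    qweight = (weight.float() / scale).clamp(finfo.min, finfo.max).to(dtype)
    return qweight, scale

w = torch.randn(8, 16)
qw, s = per_channel_quant_fp8(w)
# The diff then transposes both tensors (qweight.t() and
# weight_scale.t().contiguous()) to match the layout the kernel consumes.
qw_t, s_t = qw.t(), s.t().contiguous()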
@@ -330,23 +335,19 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.input_scale = torch.nn.Parameter(
                 layer.input_scale.data, requires_grad=False
             )
-            # If using marlin (w8a16), the kernel uses channelwise weights,
-            # so extend the weight scales to be channelwise.
-            if self.use_marlin:
+            # The cutlass sgl-kernel and marlin only support per-channel scales.
+            if self.cutlass_fp8_supported or self.use_marlin:
                 weight = layer.weight
                 weight_scale = convert_to_channelwise(
                     layer.weight_scale, layer.logical_widths
                 )

-            # If using w8a8, torch._scaled_mm needs per-tensor scales, so
-            # requantize the logical shards as a single weight.
             else:
-                # Dequant -> Quant with max scale so we can run per tensor.
                 weight = layer.weight
                 weight_scale = layer.weight_scale

                 # If ROCm, normalize the weights and scales to e4m3fnuz.
-                if is_hip_:
+                if is_hip():
                     weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
                         weight_scale=weight_scale,
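The is_hip() branch exists because AMD's FP8 format is e4m3fnuz, not e4m3fn: the exponent bias differs by one, so the same bit pattern decodes to half the value, and the pattern e4m3fn uses for -0.0 (0x80) is e4m3fnuz's only NaN encoding. A hedged sketch of what normalize_e4m3fn_to_e4m3fnuz must therefore do, as an illustrative reimplementation under those assumptions rather than the sglang source:

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
    assert weight.dtype == torch.float8_e4m3fn
    as_int8 = weight.view(torch.int8)
    # Bit pattern 0b10000000 (-128 as int8) is -0.0 in e4m3fn but NaN in
    # e4m3fnuz; rewrite it to +0.0 before reinterpreting the bits.
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # The same bits decode to half the value in e4m3fnuz, so double the
    # scales to keep the dequantized weights and activations unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale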