enable marlin fp8 blockwise (#8990)

commit e483ab6d20
parent 720cd308ba
committed by GitHub
@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
+    can_auto_enable_marlin_fp8,
     cutlass_fp8_supported,
     dispatch_w8a8_block_fp8_linear,
     input_to_float8,
@@ -209,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase):

-        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-        # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (
-            get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
-        )
-        # Disable marlin for ROCm
-        if _is_hip:
-            self.use_marlin = False
+        self.use_marlin = False
+        if _is_cuda and MARLIN_FP8_AVAILABLE:
+            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+            auto_enable = can_auto_enable_marlin_fp8()
+            self.use_marlin = force_marlin or auto_enable

         self.block_quant = self.quant_config.weight_block_size is not None
-        if self.block_quant:
-            # Marlin doesn't support block-wise fp8
-            self.use_marlin = False

         self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()

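The gating change above reduces to a small decision: Marlin is off by default, and is turned on either by the env var or by the new capability auto-check. A minimal standalone sketch of that logic (the helper names below are stand-ins for sglang's own utilities, stubbed here so the snippet runs on its own):

import os

MARLIN_FP8_AVAILABLE = True  # assumption: the marlin fp8 kernels are importable

def get_bool_env_var(name: str) -> bool:
    # Stand-in for sglang's helper of the same name.
    return os.environ.get(name, "false").lower() in ("1", "true")

def resolve_use_marlin(is_cuda: bool, can_auto_enable: bool) -> bool:
    # Hypothetical free-function version of the constructor logic in the diff:
    # off by default; enabled by SGLANG_FORCE_FP8_MARLIN or by the auto-check,
    # and only on CUDA builds where the marlin kernels are available.
    use_marlin = False
    if is_cuda and MARLIN_FP8_AVAILABLE:
        force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
        use_marlin = force_marlin or can_auto_enable
    return use_marlin

Note that the old ROCm special case disappears: gating on _is_cuda makes the explicit `if _is_hip` disable redundant.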
@@ -332,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.register_parameter("input_scale", None)

     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if _is_fp8_fnuz:
@@ -342,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
-
                 layer.input_scale = None
             elif _is_cpu:
                 assert (
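Background on the normalize_e4m3fn_to_e4m3fnuz call in this hunk: ROCm's native FP8 format is e4m3fnuz, whose exponent bias is 8 instead of e4m3fn's 7, so an identical bit pattern decodes to half the value. A rough sketch of the usual normalization trick (an illustration of the idea under those assumptions, not sglang's exact implementation):

import torch

def normalize_fn_to_fnuz_sketch(weight, weight_scale, input_scale=None):
    # Reinterpret the e4m3fn bytes as e4m3fnuz. The bit pattern 0x80 is -0.0
    # in e4m3fn but NaN in e4m3fnuz, so clear it first.
    as_int8 = weight.view(torch.int8).clone()
    as_int8[as_int8 == -128] = 0
    fnuz_weight = as_int8.view(torch.float8_e4m3fnuz)
    # The same bits decode to half the value under fnuz's larger exponent
    # bias, so the scales double to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return fnuz_weight, weight_scale, input_scale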
@@ -352,90 +347,94 @@ class Fp8LinearMethod(LinearMethodBase):
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
-                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    weight_scale, requires_grad=False
-                )
-            return

-        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
-
-        # If checkpoint not serialized fp8, quantize the weights.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            if self.cutlass_fp8_supported or self.use_marlin:
-                # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
-                qweight, weight_scale = per_token_group_quant_fp8(
-                    layer.weight, layer.weight.shape[-1]
-                )
-                weight_scale = weight_scale.t().contiguous()
-            else:
-                # per-tensor quantization
-                qweight, weight_scale = input_to_float8(layer.weight)
-
-            # Update the layer with the new values.
-            layer.weight = Parameter(qweight.t(), requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            layer.input_scale = None
-
-        # If checkpoint is fp8, handle that there are N scales for N
-        # shards in a fused module
+            layer.weight = Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
         else:
-            layer.weight_scale = torch.nn.Parameter(
-                layer.weight_scale.data, requires_grad=False
-            )
-            if (
-                hasattr(self.quant_config, "activation_scheme")
-                and self.quant_config.activation_scheme == "static"
-            ) or (
-                hasattr(self.quant_config, "linear_activation_scheme")
-                and self.quant_config.linear_activation_scheme == "static"
-            ):
-                layer.input_scale = torch.nn.Parameter(
-                    layer.input_scale.data, requires_grad=False
-                )
             layer.weight = Parameter(layer.weight.data, requires_grad=False)

-            # cutlass sgl-kernel and marlin only support per-channel scale
-            if self.cutlass_fp8_supported or self.use_marlin:
-                weight = layer.weight
-                weight_scale = convert_to_channelwise(
-                    layer.weight_scale, layer.logical_widths
-                )
+            # If checkpoint not serialized fp8, quantize the weights.
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                if self.cutlass_fp8_supported or self.use_marlin:
+                    # apply per-channel quantization default as
+                    # cutlass sgl-kernel and marlin only support per-channel scale
+                    qweight, weight_scale = per_token_group_quant_fp8(
+                        layer.weight, layer.weight.shape[-1]
+                    )
+                    weight_scale = weight_scale.t().contiguous()
+                else:
+                    # per-tensor quantization
+                    qweight, weight_scale = input_to_float8(layer.weight)
+
+                # Update the layer with the new values.
+                layer.weight = Parameter(qweight.t(), requires_grad=False)
+                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+                layer.input_scale = None
+
+            # If checkpoint is fp8, handle that there are N scales for N
+            # shards in a fused module
             else:
-                # Dequant -> Quant with max scale so we can run per tensor.
-                weight = layer.weight
-                weight_scale = layer.weight_scale
-                # If ROCm, normalize the weights and scales to e4m3fnuz
-                if _is_fp8_fnuz:
-                    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                layer.weight_scale = Parameter(
+                    layer.weight_scale.data, requires_grad=False
+                )
+                if (
+                    hasattr(self.quant_config, "activation_scheme")
+                    and self.quant_config.activation_scheme == "static"
+                ) or (
+                    hasattr(self.quant_config, "linear_activation_scheme")
+                    and self.quant_config.linear_activation_scheme == "static"
+                ):
+                    layer.input_scale = Parameter(
+                        layer.input_scale.data, requires_grad=False
+                    )
+
+                # cutlass sgl-kernel and marlin only support per-channel scale
+                if self.cutlass_fp8_supported or self.use_marlin:
+                    weight = layer.weight
+                    weight_scale = convert_to_channelwise(
+                        layer.weight_scale, layer.logical_widths
+                    )
+                else:
+                    # Dequant -> Quant with max scale so we can run per tensor.
+                    weight = layer.weight
+                    weight_scale = layer.weight_scale
+                    # If ROCm, normalize the weights and scales to e4m3fnuz
+                    if _is_fp8_fnuz:
+                        weight, weight_scale, input_scale = (
+                            normalize_e4m3fn_to_e4m3fnuz(
+                                weight=weight,
+                                weight_scale=weight_scale,
+                                input_scale=layer.input_scale,
+                            )
+                        )
+                        if input_scale is not None:
+                            layer.input_scale = Parameter(
+                                input_scale, requires_grad=False
+                            )
+
+                    weight_scale, weight = requantize_with_max_scale(
                         weight=weight,
                         weight_scale=weight_scale,
-                        input_scale=layer.input_scale,
+                        logical_widths=layer.logical_widths,
                     )
-                    if input_scale is not None:
-                        layer.input_scale = Parameter(input_scale, requires_grad=False)

-                weight_scale, weight = requantize_with_max_scale(
-                    weight=weight,
-                    weight_scale=weight_scale,
-                    logical_widths=layer.logical_widths,
-                )
-
-            # Update layer with new values.
-            layer.weight = Parameter(weight.t(), requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            if (
-                hasattr(self.quant_config, "activation_scheme")
-                and self.quant_config.activation_scheme == "static"
-            ) or (
-                hasattr(self.quant_config, "linear_activation_scheme")
-                and self.quant_config.linear_activation_scheme == "static"
-            ):
-                layer.input_scale = Parameter(
-                    layer.input_scale.max(), requires_grad=False
-                )
+                # Update layer with new values.
+                layer.weight = Parameter(weight.t(), requires_grad=False)
+                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+                if (
+                    hasattr(self.quant_config, "activation_scheme")
+                    and self.quant_config.activation_scheme == "static"
+                ) or (
+                    hasattr(self.quant_config, "linear_activation_scheme")
+                    and self.quant_config.linear_activation_scheme == "static"
+                ):
+                    layer.input_scale = Parameter(
+                        layer.input_scale.max(), requires_grad=False
+                    )

         if self.use_marlin:
-            prepare_fp8_layer_for_marlin(layer)
+            if self.block_quant:
+                layer.weight_block_size = self.quant_config.weight_block_size
+            prepare_fp8_layer_for_marlin(layer, not self.block_quant)
             # Activations not quantized for marlin.
             del layer.input_scale

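For orientation on what layer.weight_block_size describes: block-wise FP8 checkpoints (e.g. a [128, 128] block size) carry one scale per weight tile in weight_scale_inv, rather than one per tensor or per output channel. A toy dequantization showing how the tensors relate, assuming the dimensions divide evenly (a hypothetical helper for illustration, not the kernel's actual packing):

import torch

def dequant_block_fp8_sketch(qweight, weight_scale_inv, block=(128, 128)):
    # qweight: (N, K) float8_e4m3fn tensor.
    # weight_scale_inv: (N // bn, K // bk) float32 scales, one per tile.
    bn, bk = block
    n, k = qweight.shape
    w = qweight.to(torch.float32)
    for i in range(0, n, bn):
        for j in range(0, k, bk):
            # Multiply each tile by its own dequantization scale.
            w[i : i + bn, j : j + bk] *= weight_scale_inv[i // bn, j // bk]
    return w

Per-tensor and per-channel schemes collapse this to a single scale or one scale per output channel, which is why the non-block cutlass/marlin path above goes through convert_to_channelwise instead.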
@@ -789,3 +789,12 @@ def apply_fp8_linear(
         bias,
         input.dtype,
     )
+
+
+def can_auto_enable_marlin_fp8() -> bool:
+    try:
+        major, minor = get_device_capability()
+        sm = major * 10 + minor
+        return 80 <= sm < 89
+    except Exception:
+        return False
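The 80 <= sm < 89 window is worth a note (background, not stated in the diff): Ampere parts have no FP8 tensor cores, so the weight-only Marlin kernel is the profitable default there, while Ada and Hopper have native FP8 and are left to the regular kernels.

# Illustrative values of the capability gate (device names are examples):
#   A100     -> (8, 0) -> sm = 80 -> 80 <= 80 < 89 -> True  (auto-enable Marlin)
#   RTX 3090 -> (8, 6) -> sm = 86 -> True  (auto-enable Marlin)
#   RTX 4090 -> (8, 9) -> sm = 89 -> False (native FP8 preferred)
#   H100     -> (9, 0) -> sm = 90 -> False (native FP8 preferred)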