enable marlin fp8 blockwise (#8990)
This commit is contained in:
committed by
GitHub
parent
720cd308ba
commit
e483ab6d20
@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
apply_fp8_linear,
|
apply_fp8_linear,
|
||||||
|
can_auto_enable_marlin_fp8,
|
||||||
cutlass_fp8_supported,
|
cutlass_fp8_supported,
|
||||||
dispatch_w8a8_block_fp8_linear,
|
dispatch_w8a8_block_fp8_linear,
|
||||||
input_to_float8,
|
input_to_float8,
|
||||||
@@ -209,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
|
|
||||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||||
# kernel for fast weight-only FP8 quantization
|
# kernel for fast weight-only FP8 quantization
|
||||||
self.use_marlin = (
|
self.use_marlin = False
|
||||||
get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
|
if _is_cuda and MARLIN_FP8_AVAILABLE:
|
||||||
)
|
force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
|
||||||
# Disable marlin for ROCm
|
auto_enable = can_auto_enable_marlin_fp8()
|
||||||
if _is_hip:
|
self.use_marlin = force_marlin or auto_enable
|
||||||
self.use_marlin = False
|
|
||||||
|
|
||||||
self.block_quant = self.quant_config.weight_block_size is not None
|
self.block_quant = self.quant_config.weight_block_size is not None
|
||||||
if self.block_quant:
|
|
||||||
# Marlin doesn't support block-wise fp8
|
|
||||||
self.use_marlin = False
|
|
||||||
|
|
||||||
self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
|
self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
|
||||||
|
|
||||||
@@ -332,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
layer.register_parameter("input_scale", None)
|
layer.register_parameter("input_scale", None)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
# Block quant doesn't need to process weights after loading
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if _is_fp8_fnuz:
|
if _is_fp8_fnuz:
|
||||||
@@ -342,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
weight_scale=layer.weight_scale_inv,
|
weight_scale=layer.weight_scale_inv,
|
||||||
input_scale=None,
|
input_scale=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
layer.input_scale = None
|
layer.input_scale = None
|
||||||
elif _is_cpu:
|
elif _is_cpu:
|
||||||
assert (
|
assert (
|
||||||
@@ -352,90 +347,94 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
|
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
|
||||||
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
|
layer.weight = Parameter(weight, requires_grad=False)
|
||||||
layer.weight_scale_inv = torch.nn.Parameter(
|
layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
|
||||||
weight_scale, requires_grad=False
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
|
||||||
|
|
||||||
# If checkpoint not serialized fp8, quantize the weights.
|
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
if self.cutlass_fp8_supported or self.use_marlin:
|
|
||||||
# apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
|
|
||||||
qweight, weight_scale = per_token_group_quant_fp8(
|
|
||||||
layer.weight, layer.weight.shape[-1]
|
|
||||||
)
|
|
||||||
weight_scale = weight_scale.t().contiguous()
|
|
||||||
else:
|
|
||||||
# per-tensor quantization
|
|
||||||
qweight, weight_scale = input_to_float8(layer.weight)
|
|
||||||
|
|
||||||
# Update the layer with the new values.
|
|
||||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
|
||||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
|
||||||
layer.input_scale = None
|
|
||||||
|
|
||||||
# If checkpoint is fp8, handle that there are N scales for N
|
|
||||||
# shards in a fused module
|
|
||||||
else:
|
else:
|
||||||
layer.weight_scale = torch.nn.Parameter(
|
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||||
layer.weight_scale.data, requires_grad=False
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
hasattr(self.quant_config, "activation_scheme")
|
|
||||||
and self.quant_config.activation_scheme == "static"
|
|
||||||
) or (
|
|
||||||
hasattr(self.quant_config, "linear_activation_scheme")
|
|
||||||
and self.quant_config.linear_activation_scheme == "static"
|
|
||||||
):
|
|
||||||
layer.input_scale = torch.nn.Parameter(
|
|
||||||
layer.input_scale.data, requires_grad=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# cutlass sgl-kernel and marlin only support per-channel scale
|
# If checkpoint not serialized fp8, quantize the weights.
|
||||||
if self.cutlass_fp8_supported or self.use_marlin:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
weight = layer.weight
|
if self.cutlass_fp8_supported or self.use_marlin:
|
||||||
weight_scale = convert_to_channelwise(
|
# apply per-channel quantization default as
|
||||||
layer.weight_scale, layer.logical_widths
|
# cutlass sgl-kernel and marlin only support per-channel scale
|
||||||
)
|
qweight, weight_scale = per_token_group_quant_fp8(
|
||||||
|
layer.weight, layer.weight.shape[-1]
|
||||||
|
)
|
||||||
|
weight_scale = weight_scale.t().contiguous()
|
||||||
|
else:
|
||||||
|
# per-tensor quantization
|
||||||
|
qweight, weight_scale = input_to_float8(layer.weight)
|
||||||
|
|
||||||
|
# Update the layer with the new values.
|
||||||
|
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||||
|
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||||
|
layer.input_scale = None
|
||||||
|
|
||||||
|
# If checkpoint is fp8, handle that there are N scales for N
|
||||||
|
# shards in a fused module
|
||||||
else:
|
else:
|
||||||
# Dequant -> Quant with max scale so we can run per tensor.
|
layer.weight_scale = Parameter(
|
||||||
weight = layer.weight
|
layer.weight_scale.data, requires_grad=False
|
||||||
weight_scale = layer.weight_scale
|
)
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
if (
|
||||||
if _is_fp8_fnuz:
|
hasattr(self.quant_config, "activation_scheme")
|
||||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
and self.quant_config.activation_scheme == "static"
|
||||||
|
) or (
|
||||||
|
hasattr(self.quant_config, "linear_activation_scheme")
|
||||||
|
and self.quant_config.linear_activation_scheme == "static"
|
||||||
|
):
|
||||||
|
layer.input_scale = Parameter(
|
||||||
|
layer.input_scale.data, requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass sgl-kernel and marlin only support per-channel scale
|
||||||
|
if self.cutlass_fp8_supported or self.use_marlin:
|
||||||
|
weight = layer.weight
|
||||||
|
weight_scale = convert_to_channelwise(
|
||||||
|
layer.weight_scale, layer.logical_widths
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Dequant -> Quant with max scale so we can run per tensor.
|
||||||
|
weight = layer.weight
|
||||||
|
weight_scale = layer.weight_scale
|
||||||
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
|
if _is_fp8_fnuz:
|
||||||
|
weight, weight_scale, input_scale = (
|
||||||
|
normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
weight=weight,
|
||||||
|
weight_scale=weight_scale,
|
||||||
|
input_scale=layer.input_scale,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if input_scale is not None:
|
||||||
|
layer.input_scale = Parameter(
|
||||||
|
input_scale, requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_scale, weight = requantize_with_max_scale(
|
||||||
weight=weight,
|
weight=weight,
|
||||||
weight_scale=weight_scale,
|
weight_scale=weight_scale,
|
||||||
input_scale=layer.input_scale,
|
logical_widths=layer.logical_widths,
|
||||||
)
|
)
|
||||||
if input_scale is not None:
|
|
||||||
layer.input_scale = Parameter(input_scale, requires_grad=False)
|
|
||||||
|
|
||||||
weight_scale, weight = requantize_with_max_scale(
|
# Update layer with new values.
|
||||||
weight=weight,
|
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||||
weight_scale=weight_scale,
|
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||||
logical_widths=layer.logical_widths,
|
if (
|
||||||
)
|
hasattr(self.quant_config, "activation_scheme")
|
||||||
|
and self.quant_config.activation_scheme == "static"
|
||||||
# Update layer with new values.
|
) or (
|
||||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
hasattr(self.quant_config, "linear_activation_scheme")
|
||||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
and self.quant_config.linear_activation_scheme == "static"
|
||||||
if (
|
):
|
||||||
hasattr(self.quant_config, "activation_scheme")
|
layer.input_scale = Parameter(
|
||||||
and self.quant_config.activation_scheme == "static"
|
layer.input_scale.max(), requires_grad=False
|
||||||
) or (
|
)
|
||||||
hasattr(self.quant_config, "linear_activation_scheme")
|
|
||||||
and self.quant_config.linear_activation_scheme == "static"
|
|
||||||
):
|
|
||||||
layer.input_scale = Parameter(
|
|
||||||
layer.input_scale.max(), requires_grad=False
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
prepare_fp8_layer_for_marlin(layer)
|
if self.block_quant:
|
||||||
|
layer.weight_block_size = self.quant_config.weight_block_size
|
||||||
|
prepare_fp8_layer_for_marlin(layer, not self.block_quant)
|
||||||
# Activations not quantized for marlin.
|
# Activations not quantized for marlin.
|
||||||
del layer.input_scale
|
del layer.input_scale
|
||||||
|
|
||||||
|
|||||||
@@ -789,3 +789,12 @@ def apply_fp8_linear(
|
|||||||
bias,
|
bias,
|
||||||
input.dtype,
|
input.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def can_auto_enable_marlin_fp8() -> bool:
|
||||||
|
try:
|
||||||
|
major, minor = get_device_capability()
|
||||||
|
sm = major * 10 + minor
|
||||||
|
return 80 <= sm < 89
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|||||||
Reference in New Issue
Block a user