Support Online Quantization for W8A8 (#4485)
This commit is contained in:
@@ -9,9 +9,11 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
apply_fp8_linear,
|
||||
cutlass_fp8_supported,
|
||||
input_to_float8,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
@@ -22,12 +24,24 @@ _is_hip = is_hip()
|
||||
class W8A8Fp8Config(QuantizationConfig):
|
||||
"""Config class for W8A8 FP8 Quantization.
|
||||
|
||||
- Weight: static, per-channel, symmetric
|
||||
- Activation: dynamic, per-token, symmetric
|
||||
Weight Quantization:
|
||||
- Method: Static quantization
|
||||
- Granularity: Per-channel
|
||||
- Type: Symmetric
|
||||
|
||||
Activation Quantization:
|
||||
- Method: Dynamic quantization
|
||||
- Granularity: Per-token
|
||||
- Type: Symmetric
|
||||
|
||||
Note:
|
||||
- For models without offline quantization, weights will be quantized during model loading
|
||||
- If CUTLASS is supported: Per-channel weight quantization is used
|
||||
- If CUTLASS is not supported: Falls back to per-token weight quantization
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, is_checkpoint_fp8_serialized: bool = False):
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
@@ -47,7 +61,9 @@ class W8A8Fp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
|
||||
return cls()
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
|
||||
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)
|
||||
|
||||
def get_quant_method(
|
||||
self,
|
||||
@@ -72,13 +88,35 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale.detach()
|
||||
if _is_hip:
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight, weight_scale=weight_scale
|
||||
)
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
if self.quantization_config.is_checkpoint_fp8_serialized:
|
||||
weight_scale = layer.weight_scale.detach()
|
||||
# If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
|
||||
if _is_hip:
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight, weight_scale=weight_scale
|
||||
)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
else:
|
||||
# If checkpoint not offline quantized, quantize the weights with per-channel quantization.
|
||||
if self.cutlass_fp8_supported:
|
||||
# if cutlass supported, we use cutlass_scaled_mm
|
||||
# which requires per-channel quantization on weight
|
||||
qweight, weight_scale = per_token_group_quant_fp8(
|
||||
layer.weight, layer.weight.shape[-1]
|
||||
)
|
||||
weight_scale = weight_scale.t().contiguous()
|
||||
else:
|
||||
# if cutlass not supported, we fall back to use torch._scaled_mm
|
||||
# which requires per tensor quantization on weight
|
||||
qweight, weight_scale = input_to_float8(layer.weight)
|
||||
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.input_scale = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -90,6 +128,11 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs
|
||||
):
|
||||
weight_dtype = (
|
||||
torch.float8_e4m3fn
|
||||
if self.quantization_config.is_checkpoint_fp8_serialized
|
||||
else params_dtype
|
||||
)
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
self.logical_widths = output_partition_sizes
|
||||
@@ -98,7 +141,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
@@ -106,12 +149,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
if self.quantization_config.is_checkpoint_fp8_serialized:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
else:
|
||||
layer.weight_scale = None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user