sglang/python/sglang/srt/layers/quantization/w8a8_fp8.py

from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    ModelWeightParameter,
)
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import (
    fp8_dtype,
    is_fp8_fnuz,
    per_token_group_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
    apply_fp8_linear,
    cutlass_fp8_supported,
    input_to_float8,
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import set_weight_attrs
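
# Whether the platform's native FP8 format is e4m3fnuz (e.g. on ROCm); if so,
# offline-quantized e4m3fn checkpoints are converted via normalize_e4m3fn_to_e4m3fnuz.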
_is_fp8_fnuz = is_fp8_fnuz()


class W8A8Fp8Config(QuantizationConfig):
    """Config class for W8A8 FP8 Quantization.

    Weight Quantization:
        - Method: Static quantization
        - Granularity: Per-channel
        - Type: Symmetric

    Activation Quantization:
        - Method: Dynamic quantization
        - Granularity: Per-token
        - Type: Symmetric

    Note:
        - For models without offline quantization, weights will be quantized during
          model loading.
        - If CUTLASS is supported: per-channel weight quantization is used.
        - If CUTLASS is not supported: falls back to per-tensor weight quantization.
    """

    def __init__(self, is_checkpoint_fp8_serialized: bool = False):
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_name(self) -> str:
        return "w8a8_fp8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = (
            "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
        )
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            return W8A8Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return W8A8FP8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class W8A8Fp8LinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: W8A8Fp8Config):
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.quantization_config = quantization_config
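
    # Runs once after checkpoint loading: either adopt the FP8 weights/scales already
    # in the checkpoint, or quantize the freshly loaded high-precision weights here.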
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        weight = layer.weight

        if self.quantization_config.is_checkpoint_fp8_serialized:
            # The checkpoint was quantized offline with w8a8_fp8: load the weight
            # and weight_scale directly.
            weight_scale = layer.weight_scale.detach()
            if _is_fp8_fnuz:
                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight, weight_scale=weight_scale
                )
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        else:
            # The checkpoint was not quantized offline: quantize the weights during
            # loading.
            if self.cutlass_fp8_supported:
                # CUTLASS is available, so use cutlass_scaled_mm, which requires
                # per-channel quantization of the weight.
                qweight, weight_scale = per_token_group_quant_fp8(
                    layer.weight, layer.weight.shape[-1]
                )
                weight_scale = weight_scale.t().contiguous()
            else:
                # CUTLASS is not available, so fall back to torch._scaled_mm, which
                # requires per-tensor quantization of the weight.
                qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None
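
    # Weights are allocated directly in FP8 only when the checkpoint is already
    # FP8-serialized; otherwise they are created in params_dtype and quantized later
    # in process_weights_after_loading.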
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quantization_config.is_checkpoint_fp8_serialized
            else params_dtype
        )

        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes

        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quantization_config.is_checkpoint_fp8_serialized:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight_scale", weight_scale)
        else:
            layer.weight_scale = None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        return apply_fp8_linear(
            x,
            layer.weight,
            layer.weight_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
        )


class W8A8FP8MoEMethod:
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling; the weight scaling factors are then initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """
    def __new__(cls, *args, **kwargs):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config):
        self.quant_config = quant_config
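
    # By fused-MoE convention, w13 holds each expert's fused gate/up (w1/w3)
    # projections with shape (num_experts, 2 * intermediate_size, hidden_size),
    # and w2 is the down projection with shape (num_experts, hidden_size,
    # intermediate_size).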
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
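
        # WEIGHT SCALES: one FP32 scale per output channel of each expert weight
        # (hence the trailing dimension of 1), matching
        # FusedMoeWeightScaleSupported.CHANNEL below.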
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
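
        # INPUT SCALES: activations are quantized dynamically per token at runtime,
        # so no static input scales are stored.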
        w13_input_scale = None
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = None
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(
            layer.w13_weight_scale.data, requires_grad=False
        )
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data, requires_grad=False
        )
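
    # Routing first (select_experts), then the fused FP8 grouped GEMM (fused_experts)
    # with per-channel weight scales and dynamic per-token activation quantization.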
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts

        # Expert selection
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation,
            use_fp8_w8a8=True,
            per_channel_quant=True,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            no_combine=no_combine,
        )