Support Llama4 fp8 inference (#5194)

Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
HandH1998
2025-04-09 20:14:34 +08:00
committed by GitHub
parent 86a876d883
commit 4065248214
14 changed files with 537 additions and 106 deletions

View File

@@ -77,6 +77,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_ignore_list: List[str],
kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
packed_modules_mapping: Dict[str, List[str]] = {},
):
super().__init__()
self.ignore = ignore
@@ -87,6 +88,7 @@ class CompressedTensorsConfig(QuantizationConfig):
self.sparsity_scheme_map = sparsity_scheme_map
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config
self.packed_modules_mapping = packed_modules_mapping
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
@@ -136,6 +138,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config
)
packed_modules_mapping = config.get("packed_modules_mapping", {})
return cls(
target_scheme_map=target_scheme_map,
@@ -144,6 +147,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
packed_modules_mapping=packed_modules_mapping,
)
@classmethod

View File

@@ -103,16 +103,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"input_activations"
)
if not (
self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR
):
raise ValueError(
"For FP8 Fused MoE layers, only per-tensor scales "
"for weights and activations are supported. Found "
f"{self.weight_quant}, {self.input_quant}"
)
self.static_input_scales = not self.input_quant.dynamic
def create_weights(
@@ -154,27 +144,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
# per-tensor quantization
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
else:
raise ValueError(
f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
extra_weight_attrs.update({"quant_method": weight_quant_method})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.static_input_scales:
assert (
self.input_quant.strategy == QuantizationStrategy.TENSOR
), "Only per-tensor quantization is supported for static input scales"
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
@@ -241,31 +254,37 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
if _is_cuda:
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
else:
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
if _is_cuda:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
else:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = vllm_ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]
)
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)
def apply(
self,
@@ -311,6 +330,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
inplace=inplace,
activation=activation,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy
== QuantizationStrategy.CHANNEL,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,

View File

@@ -217,6 +217,15 @@ def block_quant_to_tensor_quant(
return x_q_tensor, scale
def channel_quant_to_tensor_quant(
x_q_channel: torch.Tensor,
x_s: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
x_dq_channel = x_q_channel.to(torch.float32) * x_s
x_q_tensor, scale = input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
return x_q_tensor, scale
def apply_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
input_to_float8,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import is_hip
from sglang.srt.utils import is_hip, set_weight_attrs
_is_hip = is_hip()
@@ -62,7 +62,9 @@ class W8A8Fp8Config(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
is_checkpoint_fp8_serialized = (
"compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
)
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)
def get_quant_method(
@@ -71,9 +73,12 @@ class W8A8Fp8Config(QuantizationConfig):
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
return W8A8Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return W8A8FP8MoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
@@ -131,7 +136,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs
**extra_weight_attrs,
):
weight_dtype = (
torch.float8_e4m3fn
@@ -177,3 +182,148 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
)
class W8A8FP8MoEMethod:
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
use_fp8_w8a8=True,
per_channel_quant=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
)

View File

@@ -260,6 +260,7 @@ class W8A8Int8MoEMethod:
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,