Support Llama4 fp8 inference (#5194)
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn> Co-authored-by: sleepcoo <sleepcoo@gmail.com> Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -77,6 +77,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
sparsity_ignore_list: List[str],
|
||||
kv_cache_scheme: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
packed_modules_mapping: Dict[str, List[str]] = {},
|
||||
):
|
||||
super().__init__()
|
||||
self.ignore = ignore
|
||||
@@ -87,6 +88,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self.sparsity_scheme_map = sparsity_scheme_map
|
||||
self.sparsity_ignore_list = sparsity_ignore_list
|
||||
self.config = config
|
||||
self.packed_modules_mapping = packed_modules_mapping
|
||||
|
||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
@@ -136,6 +138,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
|
||||
config=config
|
||||
)
|
||||
packed_modules_mapping = config.get("packed_modules_mapping", {})
|
||||
|
||||
return cls(
|
||||
target_scheme_map=target_scheme_map,
|
||||
@@ -144,6 +147,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
sparsity_scheme_map=sparsity_scheme_map,
|
||||
sparsity_ignore_list=sparsity_ignore_list,
|
||||
config=config,
|
||||
packed_modules_mapping=packed_modules_mapping,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -103,16 +103,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
"input_activations"
|
||||
)
|
||||
|
||||
if not (
|
||||
self.weight_quant.strategy == QuantizationStrategy.TENSOR
|
||||
and self.input_quant.strategy == QuantizationStrategy.TENSOR
|
||||
):
|
||||
raise ValueError(
|
||||
"For FP8 Fused MoE layers, only per-tensor scales "
|
||||
"for weights and activations are supported. Found "
|
||||
f"{self.weight_quant}, {self.input_quant}"
|
||||
)
|
||||
|
||||
self.static_input_scales = not self.input_quant.dynamic
|
||||
|
||||
def create_weights(
|
||||
@@ -154,27 +144,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
# per-tensor quantization
|
||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
|
||||
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
|
||||
)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||
)
|
||||
extra_weight_attrs.update({"quant_method": weight_quant_method})
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.static_input_scales:
|
||||
assert (
|
||||
self.input_quant.strategy == QuantizationStrategy.TENSOR
|
||||
), "Only per-tensor quantization is supported for static input scales"
|
||||
w13_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
@@ -241,31 +254,37 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
w2_input_scale, requires_grad=False
|
||||
)
|
||||
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max then dequant and requant each expert.
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id],
|
||||
)
|
||||
|
||||
if _is_cuda:
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
|
||||
sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max then dequant and requant each expert.
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id],
|
||||
)
|
||||
else:
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
|
||||
vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
)
|
||||
start += shard_size
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
|
||||
if _is_cuda:
|
||||
(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
_,
|
||||
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
else:
|
||||
(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
_,
|
||||
) = vllm_ops.scaled_fp8_quant(
|
||||
dq_weight, max_w13_scales[expert_id]
|
||||
)
|
||||
start += shard_size
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
max_w13_scales, requires_grad=False
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -311,6 +330,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=self.weight_quant.strategy
|
||||
== QuantizationStrategy.CHANNEL,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
|
||||
@@ -217,6 +217,15 @@ def block_quant_to_tensor_quant(
|
||||
return x_q_tensor, scale
|
||||
|
||||
|
||||
def channel_quant_to_tensor_quant(
|
||||
x_q_channel: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x_dq_channel = x_q_channel.to(torch.float32) * x_s
|
||||
x_q_tensor, scale = input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
|
||||
return x_q_tensor, scale
|
||||
|
||||
|
||||
def apply_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
input_to_float8,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import is_hip, set_weight_attrs
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
@@ -62,7 +62,9 @@ class W8A8Fp8Config(QuantizationConfig):
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
|
||||
is_checkpoint_fp8_serialized = (
|
||||
"compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
|
||||
)
|
||||
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)
|
||||
|
||||
def get_quant_method(
|
||||
@@ -71,9 +73,12 @@ class W8A8Fp8Config(QuantizationConfig):
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return W8A8Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return W8A8FP8MoEMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
@@ -131,7 +136,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight_dtype = (
|
||||
torch.float8_e4m3fn
|
||||
@@ -177,3 +182,148 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
||||
bias=bias,
|
||||
cutlass_fp8_supported=self.cutlass_fp8_supported,
|
||||
)
|
||||
|
||||
|
||||
class W8A8FP8MoEMethod:
|
||||
"""MoE method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
activation scaling. The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
|
||||
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||
)
|
||||
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
w13_input_scale = None
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
|
||||
w2_input_scale = None
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(
|
||||
layer.w13_weight_scale.data, requires_grad=False
|
||||
)
|
||||
layer.w2_weight_scale = Parameter(
|
||||
layer.w2_weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
# Expert selection
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=True,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
@@ -260,6 +260,7 @@ class W8A8Int8MoEMethod:
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_int8_w8a8=True,
|
||||
per_channel_quant=True,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
|
||||
Reference in New Issue
Block a user