Support Llama4 fp8 inference (#5194)

Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: zhyncs <me@zhyncs.com>
commit 4065248214 (parent 86a876d883)
Author: HandH1998
Date:   2025-04-09 20:14:34 +08:00 (committed by GitHub)

14 changed files with 537 additions and 106 deletions


@@ -77,6 +77,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_ignore_list: List[str],
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
+        packed_modules_mapping: Dict[str, List[str]] = {},
     ):
         super().__init__()
         self.ignore = ignore
@@ -87,6 +88,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.sparsity_scheme_map = sparsity_scheme_map
         self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
+        self.packed_modules_mapping = packed_modules_mapping

     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -136,6 +138,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
             config=config
         )
+        packed_modules_mapping = config.get("packed_modules_mapping", {})

         return cls(
             target_scheme_map=target_scheme_map,
@@ -144,6 +147,7 @@ class CompressedTensorsConfig(QuantizationConfig):
             sparsity_scheme_map=sparsity_scheme_map,
             sparsity_ignore_list=sparsity_ignore_list,
             config=config,
+            packed_modules_mapping=packed_modules_mapping,
         )

     @classmethod
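The hunks above thread an optional packed_modules_mapping from the checkpoint's quantization config through from_config into the CompressedTensorsConfig constructor, so the weight loader knows which checkpoint sub-modules are fused into one runtime module. A minimal sketch of the shape of that mapping; the module names below are illustrative, not copied from a real Llama4 config:

    # Hypothetical quantization-config snippet: each fused runtime module
    # lists the checkpoint sub-modules that get packed into it.
    quant_config_dict = {
        "packed_modules_mapping": {
            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
            "gate_up_proj": ["gate_proj", "up_proj"],
        },
        # ... plus the usual compressed-tensors fields (config_groups, ignore, ...)
    }

    # Mirrors the new from_config line: default to an empty mapping if absent.
    packed_modules_mapping = quant_config_dict.get("packed_modules_mapping", {})
    assert packed_modules_mapping["qkv_proj"] == ["q_proj", "k_proj", "v_proj"]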


@@ -103,16 +103,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             "input_activations"
         )

-        if not (
-            self.weight_quant.strategy == QuantizationStrategy.TENSOR
-            and self.input_quant.strategy == QuantizationStrategy.TENSOR
-        ):
-            raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
-                f"{self.weight_quant}, {self.input_quant}"
-            )
-
         self.static_input_scales = not self.input_quant.dynamic

     def create_weights(
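The deleted guard made CompressedTensorsW8A8Fp8MoEMethod reject anything except per-tensor scales for both weights and activations. With this commit, per-channel weight scales are accepted; unsupported strategies now raise later, in create_weights, and static input scales stay per-tensor (see the assert in the next hunk). A condensed, hypothetical helper summarizing the new invariants, assuming the upstream compressed-tensors enum; it is not part of the commit itself:

    from compressed_tensors.quantization import QuantizationStrategy

    def moe_fp8_strategy_ok(weight_strategy, input_is_dynamic, input_strategy) -> bool:
        if weight_strategy not in (QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL):
            return False  # create_weights raises ValueError for these
        if not input_is_dynamic and input_strategy != QuantizationStrategy.TENSOR:
            return False  # static input scales remain per-tensor only
        return True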
@@ -154,27 +144,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
+        # per-tensor quantization
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    1,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+                requires_grad=False,
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
+        else:
+            raise ValueError(
+                f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
+            )
+
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
+        extra_weight_attrs.update({"quant_method": weight_quant_method})
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

         # INPUT_SCALES
         if self.static_input_scales:
+            assert (
+                self.input_quant.strategy == QuantizationStrategy.TENSOR
+            ), "Only per-tensor quantization is supported for static input scales"
             w13_input_scale = torch.nn.Parameter(
                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
             )
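For reference, the scale shapes allocated above for a hypothetical layer with E local experts, intermediate size I per partition, and hidden size H; a sketch with illustrative sizes, not code from the diff:

    import torch

    E, I, H = 8, 1024, 4096  # illustrative sizes

    # TENSOR strategy: one scalar per (expert, shard) for w13 -- combined to
    # one scalar per expert after loading -- and one scalar per expert for w2.
    w13_scale_tensor = torch.ones(E, 2, dtype=torch.float32)
    w2_scale_tensor = torch.ones(E, dtype=torch.float32)

    # CHANNEL strategy: one scale per output channel, broadcast over the
    # input dimension; w1 and w3 output rows are stacked, hence 2 * I.
    w13_scale_channel = torch.ones(E, 2 * I, 1, dtype=torch.float32)
    w2_scale_channel = torch.ones(E, H, 1, dtype=torch.float32)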
@@ -241,31 +254,37 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(
                 w2_input_scale, requires_grad=False
             )

-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start : start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id],
-                )
-                if _is_cuda:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                else:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                start += shard_size
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
+                    )
+                    if _is_cuda:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    else:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = vllm_ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id]
+                        )
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )

     def apply(
         self,
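The per-tensor branch above keeps the pre-existing behavior: the fused FP8 MoE kernel wants a single w13 scale per expert, so the w1 and w3 shards are dequantized with their own scales and requantized under the shared max. Per-channel scales skip this step entirely, since each shard already owns its rows of the (E, 2 * I, 1) scale tensor. A self-contained numeric sketch of the requant step, emulating the fp8 cast with clamping (the real code calls sgl_scaled_fp8_quant / vllm_ops.scaled_fp8_quant):

    import torch

    FP8_MAX = 448.0  # finfo(torch.float8_e4m3fn).max

    def quant(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # stand-in for an fp8 cast: scale, then clamp to the representable range
        return torch.clamp(w / scale, -FP8_MAX, FP8_MAX)

    def dequant(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return q * scale

    # w1 and w3 shards of one expert, quantized with their own per-tensor scales
    w1, w3 = torch.randn(4, 8), torch.randn(4, 8)
    s1 = w1.abs().max() / FP8_MAX
    s3 = w3.abs().max() / FP8_MAX
    q1, q3 = quant(w1, s1), quant(w3, s3)

    # requantize both shards under the shared max scale the kernel expects
    s_max = torch.max(s1, s3)
    rq1 = quant(dequant(q1, s1), s_max)
    rq3 = quant(dequant(q3, s3), s_max)

    # with this float emulation the round trip is exact; a real fp8 cast would
    # re-round the smaller-scale shard at the coarser shared scale
    assert torch.allclose(dequant(rq1, s_max), w1)
    assert torch.allclose(dequant(rq3, s_max), w3)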
@@ -311,6 +330,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy
+            == QuantizationStrategy.CHANNEL,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
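Finally, apply collapses the strategy enum into the boolean per_channel_quant that the fused kernel takes. Numerically, the difference is only where the scale lives; a minimal sketch with illustrative shapes:

    import torch

    FP8_MAX = 448.0
    w = torch.randn(4, 8)

    # per-tensor: a single scalar scale for the whole weight
    s_tensor = w.abs().max() / FP8_MAX
    q_tensor = torch.clamp(w / s_tensor, -FP8_MAX, FP8_MAX)

    # per-channel: one scale per output row, shape (4, 1), broadcast over columns
    s_channel = w.abs().amax(dim=1, keepdim=True) / FP8_MAX
    q_channel = torch.clamp(w / s_channel, -FP8_MAX, FP8_MAX)

    # rows with small magnitudes use the full fp8 range under per-channel,
    # while per-tensor leaves them compressed near zero
    print(q_tensor.abs().amax(dim=1), q_channel.abs().amax(dim=1))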