Support Llama4 fp8 inference (#5194)
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: zhyncs <me@zhyncs.com>

@@ -77,6 +77,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_ignore_list: List[str],
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
+        packed_modules_mapping: Dict[str, List[str]] = {},
     ):
         super().__init__()
         self.ignore = ignore

@@ -87,6 +88,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.sparsity_scheme_map = sparsity_scheme_map
         self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
+        self.packed_modules_mapping = packed_modules_mapping

     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)

@@ -136,6 +138,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
             config=config
         )
+        packed_modules_mapping = config.get("packed_modules_mapping", {})

         return cls(
             target_scheme_map=target_scheme_map,

@@ -144,6 +147,7 @@ class CompressedTensorsConfig(QuantizationConfig):
             sparsity_scheme_map=sparsity_scheme_map,
             sparsity_ignore_list=sparsity_ignore_list,
             config=config,
+            packed_modules_mapping=packed_modules_mapping,
         )

     @classmethod

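For context, `packed_modules_mapping` describes which original projection modules a fused module packs together, so the quantization config can resolve scales for the fused names. A rough illustration in plain Python; the module names are assumptions for a Llama-style model, not taken from this commit:

# Hypothetical mapping: fused module -> projections it packs together.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

config = {"packed_modules_mapping": packed_modules_mapping}
# Mirrors the lookup added in from_config() above:
print(config.get("packed_modules_mapping", {}))
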
@@ -103,16 +103,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             "input_activations"
         )

-        if not (
-            self.weight_quant.strategy == QuantizationStrategy.TENSOR
-            and self.input_quant.strategy == QuantizationStrategy.TENSOR
-        ):
-            raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
-                f"{self.weight_quant}, {self.input_quant}"
-            )
-
         self.static_input_scales = not self.input_quant.dynamic

     def create_weights(

@@ -154,27 +144,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        # per-tensor quantization
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    1,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+                requires_grad=False,
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
+        else:
+            raise ValueError(
+                f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
+            )

-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
+        extra_weight_attrs.update({"quant_method": weight_quant_method})
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

         # INPUT_SCALES
         if self.static_input_scales:
+            assert (
+                self.input_quant.strategy == QuantizationStrategy.TENSOR
+            ), "Only per-tensor quantization is supported for static input scales"
             w13_input_scale = torch.nn.Parameter(
                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
             )

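For reference, a minimal sketch of the scale shapes the two branches above allocate; the sizes below are placeholders, not values from the commit:

import torch

# Illustrative sizes only (assumptions, not taken from the commit).
num_experts, intermediate_size_per_partition, hidden_size = 8, 1024, 2048

# TENSOR strategy: one scale per expert; w13 temporarily keeps two scales
# (for w1 and w3) until they are merged after weight loading.
w13_scale_tensor = torch.ones(num_experts, 2, dtype=torch.float32)
w2_scale_tensor = torch.ones(num_experts, dtype=torch.float32)

# CHANNEL strategy: one scale per output channel of each expert matrix.
w13_scale_channel = torch.ones(
    num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
)
w2_scale_channel = torch.ones(num_experts, hidden_size, 1, dtype=torch.float32)

print(w13_scale_tensor.shape, w13_scale_channel.shape)
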
@@ -241,31 +254,37 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(
                 w2_input_scale, requires_grad=False
             )

-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start : start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id],
-                )
-
-                if _is_cuda:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
+                    )
-                else:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+                    if _is_cuda:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    else:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = vllm_ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id]
+                        )
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )

     def apply(
         self,

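The per-tensor path above collapses the two w1/w3 scales into a single per-expert scale by dequantizing each shard with its own scale and requantizing with the per-expert maximum. A standalone toy sketch of that idea in plain PyTorch (not the sgl_scaled_fp8_quant / vllm_ops kernels used in the diff, and using float storage instead of real fp8 for clarity):

import torch

def merge_w13_scales(w13_weight: torch.Tensor, w13_scales: torch.Tensor):
    """Toy version: dequantize each of the two shards (w1 and w3) with its own
    per-tensor scale, then requantize both with the max of the two scales."""
    fp8_max = 448.0  # max magnitude representable in float8_e4m3fn
    shard = w13_weight.shape[0] // 2
    max_scale = w13_scales.max()
    out = torch.empty_like(w13_weight)
    for shard_id in range(2):
        rows = slice(shard_id * shard, (shard_id + 1) * shard)
        dq = w13_weight[rows] * w13_scales[shard_id]          # per-tensor dequantize
        out[rows] = torch.clamp(dq / max_scale, -fp8_max, fp8_max)  # requantize
    return out, max_scale

w13 = torch.randn(8, 16)
scales = torch.tensor([0.02, 0.05])
merged, scale = merge_w13_scales(w13, scales)
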
@@ -311,6 +330,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy
+            == QuantizationStrategy.CHANNEL,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,

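With CHANNEL weight quantization, the fused MoE kernel is informed via the new `per_channel_quant` flag that weight scales are per output channel rather than per tensor. A toy sketch of how such a per-channel FP8 scale could be derived and why its shape is (N, 1); the helper below is illustrative, not part of the codebase:

import torch

def per_channel_fp8_scale(w: torch.Tensor) -> torch.Tensor:
    """Toy per-output-channel scale: one scale per row of an (N, K) weight,
    giving shape (N, 1), matching the CHANNEL branch in create_weights."""
    fp8_max = 448.0  # max magnitude representable in float8_e4m3fn
    return w.abs().amax(dim=1, keepdim=True) / fp8_max

w = torch.randn(4, 8)
print(per_channel_fp8_scale(w).shape)  # torch.Size([4, 1])
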