[Quantization][w8a8_int8] Fix weight loading issue for w8a8_int8 path with "ignore" layer list in quantization config (#7820)

jianan-gu
2025-07-18 13:03:56 +08:00
committed by GitHub
parent 48c1fa7bb6
commit 7891bac16b
2 changed files with 22 additions and 16 deletions
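Background for this fix: a checkpoint's quantization_config can carry an "ignore" list naming layers that were left unquantized. Before this change, the w8a8_int8 path dropped that list and assigned the int8 linear method to every matching layer, which broke weight loading for the unquantized ones. The sketch below is illustrative only: the key names ("ignore", "packed_modules_mapping", "is_dynamic") and the resulting behaviour come from the diff that follows, while the import path, the quant_method value, and the layer names are assumptions.

# Illustrative sketch, not part of the diff.
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config  # module path assumed

quant_config = {
    "quant_method": "w8a8_int8",  # hypothetical value for illustration
    "is_dynamic": True,
    # Layers listed here were kept in their original dtype, so the checkpoint
    # has no int8 weight scales for them (layer names are made up).
    "ignore": ["lm_head", "model.layers.0.mlp.gate"],
    # Fused modules and the per-shard names they were packed from.
    "packed_modules_mapping": {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    },
}

cfg = W8A8Int8Config(quant_config)
# After this commit, cfg.ignore and cfg.packed_modules_mapping are populated
# (falling back to [] / {} when the keys are missing or None), and
# get_quant_method() returns UnquantizedLinearMethod() for ignored prefixes,
# so their weights load without looking for int8 scales that do not exist.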


@@ -347,8 +347,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import (
-                select_experts,
                 apply_topk_weights_cpu,
+                select_experts,
             )
             topk_weights, topk_ids = select_experts(


@@ -3,7 +3,7 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast

 import torch
 from torch.nn.parameter import Parameter
@@ -24,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
@@ -178,17 +179,18 @@ class W8A8Int8Config(QuantizationConfig):
     - Activation: dynamic, per-token, symmetric
     """

-    def __init__(self, quant_config: Dict[str, Any]):
+    def __init__(self, quant_config: Dict[str, Any] = {}):
         super().__init__()
         self.quant_description = quant_config
         self.is_dynamic = quant_config.get("is_dynamic", False)

-        if _is_npu:
-            if (
-                "packed_modules_mapping" in quant_config
-                and quant_config["packed_modules_mapping"] is not None
-            ):
-                self.packed_modules_mapping = quant_config["packed_modules_mapping"]
+        ignore = cast(List[str], quant_config.get("ignore", []))
+        self.ignore = ignore if ignore is not None else []
+        packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
+        self.packed_modules_mapping = (
+            packed_modules_mapping if packed_modules_mapping is not None else {}
+        )
+        if _is_npu:
             # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
             for name in self.quant_description.keys():
                 if "norm.bias" in name:
@@ -237,7 +239,7 @@ class W8A8Int8Config(QuantizationConfig):
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if _is_npu:
@@ -262,12 +264,16 @@ class W8A8Int8Config(QuantizationConfig):
             elif isinstance(layer, FusedMoE):
                 return NPU_W8A8MoEMethod(self)
             return None
-        else:
-            if isinstance(layer, LinearBase):
-                return W8A8Int8LinearMethod(self)
-            elif isinstance(layer, FusedMoE):
-                return W8A8Int8MoEMethod(self)
-            return None
+
+        if should_ignore_layer(
+            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+        ):
+            return UnquantizedLinearMethod()
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8Int8MoEMethod(self)
+        return None

     def is_layer_skipped(
         self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})