[Quantization][w8a8_int8] Fix weight loading issue for w8a8_int8 path with "ignore" layer list in quantization config (#7820)

This commit is contained in:
jianan-gu
2025-07-18 13:03:56 +08:00
committed by GitHub
parent 48c1fa7bb6
commit 7891bac16b
2 changed files with 22 additions and 16 deletions

View File

@@ -347,8 +347,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
if use_intel_amx_backend(layer): if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import ( from sglang.srt.layers.moe.topk import (
select_experts,
apply_topk_weights_cpu, apply_topk_weights_cpu,
select_experts,
) )
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import importlib import importlib
import sys import sys
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
@@ -24,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.utils import ( from sglang.srt.utils import (
@@ -178,17 +179,18 @@ class W8A8Int8Config(QuantizationConfig):
- Activation: dynamic, per-token, symmetric - Activation: dynamic, per-token, symmetric
""" """
def __init__(self, quant_config: Dict[str, Any]): def __init__(self, quant_config: Dict[str, Any] = {}):
super().__init__() super().__init__()
self.quant_description = quant_config self.quant_description = quant_config
self.is_dynamic = quant_config.get("is_dynamic", False) self.is_dynamic = quant_config.get("is_dynamic", False)
if _is_npu: ignore = cast(List[str], quant_config.get("ignore", []))
if ( self.ignore = ignore if ignore is not None else []
"packed_modules_mapping" in quant_config packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
and quant_config["packed_modules_mapping"] is not None self.packed_modules_mapping = (
): packed_modules_mapping if packed_modules_mapping is not None else {}
self.packed_modules_mapping = quant_config["packed_modules_mapping"] )
if _is_npu:
# Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
for name in self.quant_description.keys(): for name in self.quant_description.keys():
if "norm.bias" in name: if "norm.bias" in name:
@@ -237,7 +239,7 @@ class W8A8Int8Config(QuantizationConfig):
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, prefix: str,
) -> Optional[QuantizeMethodBase]: ) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if _is_npu: if _is_npu:
@@ -262,12 +264,16 @@ class W8A8Int8Config(QuantizationConfig):
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return NPU_W8A8MoEMethod(self) return NPU_W8A8MoEMethod(self)
return None return None
else:
if isinstance(layer, LinearBase): if should_ignore_layer(
return W8A8Int8LinearMethod(self) prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
elif isinstance(layer, FusedMoE): ):
return W8A8Int8MoEMethod(self) return UnquantizedLinearMethod()
return None if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return W8A8Int8MoEMethod(self)
return None
def is_layer_skipped( def is_layer_skipped(
self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})