diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py
index 821b1cb85..06afcb70b 100644
--- a/python/sglang/srt/layers/quantization/unquant.py
+++ b/python/sglang/srt/layers/quantization/unquant.py
@@ -347,8 +347,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import (
-                select_experts,
                 apply_topk_weights_cpu,
+                select_experts,
             )
 
             topk_weights, topk_ids = select_experts(
diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index 56ac26c57..c9af7ae29 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast
 
 import torch
 from torch.nn.parameter import Parameter
@@ -24,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
@@ -178,17 +179,18 @@ class W8A8Int8Config(QuantizationConfig):
     - Activation: dynamic, per-token, symmetric
     """
 
-    def __init__(self, quant_config: Dict[str, Any]):
+    def __init__(self, quant_config: Dict[str, Any] = {}):
         super().__init__()
         self.quant_description = quant_config
         self.is_dynamic = quant_config.get("is_dynamic", False)
 
-        if _is_npu:
-            if (
-                "packed_modules_mapping" in quant_config
-                and quant_config["packed_modules_mapping"] is not None
-            ):
-                self.packed_modules_mapping = quant_config["packed_modules_mapping"]
+        ignore = cast(List[str], quant_config.get("ignore", []))
+        self.ignore = ignore if ignore is not None else []
+        packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
+        self.packed_modules_mapping = (
+            packed_modules_mapping if packed_modules_mapping is not None else {}
+        )
+        if _is_npu:
             # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
             for name in self.quant_description.keys():
                 if "norm.bias" in name:
@@ -237,7 +239,7 @@ class W8A8Int8Config(QuantizationConfig):
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if _is_npu:
@@ -262,12 +264,16 @@ class W8A8Int8Config(QuantizationConfig):
             elif isinstance(layer, FusedMoE):
                 return NPU_W8A8MoEMethod(self)
             return None
-        else:
-            if isinstance(layer, LinearBase):
-                return W8A8Int8LinearMethod(self)
-            elif isinstance(layer, FusedMoE):
-                return W8A8Int8MoEMethod(self)
-            return None
+
+        if should_ignore_layer(
+            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+        ):
+            return UnquantizedLinearMethod()
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8Int8MoEMethod(self)
+        return None
 
     def is_layer_skipped(
         self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
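
For context, a minimal self-contained sketch of the matching behavior that `get_quant_method` now delegates to `should_ignore_layer`: a plain module is skipped when its prefix appears in the checkpoint's `ignore` list, while a fused module (e.g. `qkv_proj`) is skipped only when every shard it packs is ignored. The function name `ignore_layer_sketch` and the example prefixes are hypothetical, and the real helper in `sglang/srt/layers/quantization/compressed_tensors/utils.py` additionally supports `re:`-prefixed regex patterns, which this sketch omits.

```python
from typing import List, Mapping


def ignore_layer_sketch(
    prefix: str,
    ignore: List[str],
    fused_mapping: Mapping[str, List[str]],
) -> bool:
    """Simplified approximation of should_ignore_layer (no regex support)."""
    proj = prefix.split(".")[-1]
    if proj in fused_mapping:
        # A fused module is ignored only if all of its packed shards
        # (e.g. q_proj/k_proj/v_proj for qkv_proj) are in the ignore list.
        shards = [prefix.replace(proj, shard) for shard in fused_mapping[proj]]
        return all(shard in ignore for shard in shards)
    return prefix in ignore


# lm_head falls back to UnquantizedLinearMethod; the fused attention
# projection still gets W8A8Int8LinearMethod.
assert ignore_layer_sketch("lm_head", ["lm_head"], {})
assert not ignore_layer_sketch(
    "model.layers.0.self_attn.qkv_proj",
    ["lm_head"],
    {"qkv_proj": ["q_proj", "k_proj", "v_proj"]},
)
```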