[Quantization][w8a8_int8] Fix weight loading issue for w8a8_int8 path with "ignore" layer list in quantization config (#7820)

@@ -347,8 +347,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import (
-                select_experts,
+                apply_topk_weights_cpu,
+                select_experts,
             )

             topk_weights, topk_ids = select_experts(

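For readers unfamiliar with the MoE routing call involved here, the following is a minimal, framework-agnostic sketch of what top-k expert selection computes. It is illustrative only: sglang's select_experts and apply_topk_weights_cpu have their own signatures and CPU-specific optimizations, and the names below (router_logits, top_k, renormalize) are generic placeholders rather than the actual API.

# Illustrative only: generic top-k MoE routing, not sglang's select_experts.
from typing import Tuple

import torch


def naive_select_experts(
    router_logits: torch.Tensor,  # [num_tokens, num_experts]
    top_k: int,
    renormalize: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Softmax over experts, then keep the k largest probabilities per token.
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        # Re-scale the kept weights so they sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# Example: 4 tokens routed over 8 experts, 2 experts per token.
weights, ids = naive_select_experts(torch.randn(4, 8), top_k=2)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
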
@@ -3,7 +3,7 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast

 import torch
 from torch.nn.parameter import Parameter

@@ -24,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (

@@ -178,17 +179,18 @@ class W8A8Int8Config(QuantizationConfig):
     - Activation: dynamic, per-token, symmetric
     """

-    def __init__(self, quant_config: Dict[str, Any]):
+    def __init__(self, quant_config: Dict[str, Any] = {}):
         super().__init__()
         self.quant_description = quant_config
         self.is_dynamic = quant_config.get("is_dynamic", False)
-        if _is_npu:
-            if (
-                "packed_modules_mapping" in quant_config
-                and quant_config["packed_modules_mapping"] is not None
-            ):
-                self.packed_modules_mapping = quant_config["packed_modules_mapping"]
+        ignore = cast(List[str], quant_config.get("ignore", []))
+        self.ignore = ignore if ignore is not None else []
+        packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
+        self.packed_modules_mapping = (
+            packed_modules_mapping if packed_modules_mapping is not None else {}
+        )
+
+        if _is_npu:
             # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
             for name in self.quant_description.keys():
                 if "norm.bias" in name:

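For context, here is a rough sketch of how the reworked constructor consumes a checkpoint's quantization config. Only the "is_dynamic", "ignore", and "packed_modules_mapping" keys are read by the code above; the layer names are made up, and the import path of W8A8Int8Config is assumed rather than taken from this diff.

# Hypothetical quantization config as it might appear in a checkpoint;
# layer names are made up for illustration.
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config  # path assumed

quant_config = {
    "is_dynamic": True,
    # Layers listed here keep their original (unquantized) weights.
    "ignore": ["lm_head", "model.layers.0.mlp.gate"],
    # Maps fused module names to the per-shard names they pack.
    "packed_modules_mapping": {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    },
}

cfg = W8A8Int8Config(quant_config)
assert cfg.ignore == ["lm_head", "model.layers.0.mlp.gate"]
assert cfg.packed_modules_mapping["gate_up_proj"] == ["gate_proj", "up_proj"]

# A config without these keys (or with explicit nulls) now falls back to an
# empty list / empty dict instead of leaving the attributes unset.
assert W8A8Int8Config({}).ignore == []

Before this change, the non-NPU path never read the "ignore" list and only the NPU path populated packed_modules_mapping, so layers meant to stay unquantized were still handed to W8A8Int8LinearMethod.
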
@@ -237,7 +239,7 @@ class W8A8Int8Config(QuantizationConfig):
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if _is_npu:

@@ -262,12 +264,16 @@ class W8A8Int8Config(QuantizationConfig):
             elif isinstance(layer, FusedMoE):
                 return NPU_W8A8MoEMethod(self)
             return None
-        else:
-            if isinstance(layer, LinearBase):
-                return W8A8Int8LinearMethod(self)
-            elif isinstance(layer, FusedMoE):
-                return W8A8Int8MoEMethod(self)
-            return None
+
+        if should_ignore_layer(
+            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+        ):
+            return UnquantizedLinearMethod()
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8Int8MoEMethod(self)
+        return None

     def is_layer_skipped(
         self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})

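To make the new early return concrete, below is a simplified, self-contained sketch of the kind of prefix matching should_ignore_layer performs against the ignore list, including the fused-module case. The real helper in sglang.srt.layers.quantization.compressed_tensors.utils supports regex-style patterns and stricter consistency checks, so treat this as an illustration of the idea, not its actual logic.

from typing import List, Mapping


def ignore_layer_sketch(
    prefix: str,
    ignore: List[str],
    fused_mapping: Mapping[str, List[str]] = {},
) -> bool:
    """Simplified stand-in for should_ignore_layer (illustration only).

    A fused module (e.g. "...qkv_proj") is treated as ignored when all of the
    per-shard names it packs (q_proj/k_proj/v_proj) appear in the ignore list;
    plain modules are matched on the prefix itself.
    """
    parent, _, leaf = prefix.rpartition(".")
    if leaf in fused_mapping:
        shards = [f"{parent}.{s}" if parent else s for s in fused_mapping[leaf]]
        return all(any(s.endswith(i) for i in ignore) for s in shards)
    return any(prefix.endswith(i) for i in ignore)


# With an ignore list naming the unfused projections, the fused "qkv_proj"
# module of layer 0 resolves to ignored, while layer 1 does not.
ignore = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
print(ignore_layer_sketch("model.layers.0.self_attn.qkv_proj", ignore, mapping))  # True
print(ignore_layer_sketch("model.layers.1.self_attn.qkv_proj", ignore, mapping))  # False

When the matching reports a hit, get_quant_method returns UnquantizedLinearMethod(), so the ignored layer's weights are loaded in their original precision instead of going through the int8 path, which is the weight loading issue this commit fixes.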