[Quantization][w8a8_int8] Fix weight loading issue for w8a8_int8 path with "ignore" layer list in quantization config (#7820)
This commit is contained in:
@@ -347,8 +347,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
if use_intel_amx_backend(layer):
|
if use_intel_amx_backend(layer):
|
||||||
|
|
||||||
from sglang.srt.layers.moe.topk import (
|
from sglang.srt.layers.moe.topk import (
|
||||||
select_experts,
|
|
||||||
apply_topk_weights_cpu,
|
apply_topk_weights_cpu,
|
||||||
|
select_experts,
|
||||||
)
|
)
|
||||||
|
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import importlib
|
import importlib
|
||||||
import sys
|
import sys
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
@@ -24,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
||||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -178,17 +179,18 @@ class W8A8Int8Config(QuantizationConfig):
|
|||||||
- Activation: dynamic, per-token, symmetric
|
- Activation: dynamic, per-token, symmetric
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: Dict[str, Any]):
|
def __init__(self, quant_config: Dict[str, Any] = {}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.quant_description = quant_config
|
self.quant_description = quant_config
|
||||||
self.is_dynamic = quant_config.get("is_dynamic", False)
|
self.is_dynamic = quant_config.get("is_dynamic", False)
|
||||||
if _is_npu:
|
ignore = cast(List[str], quant_config.get("ignore", []))
|
||||||
if (
|
self.ignore = ignore if ignore is not None else []
|
||||||
"packed_modules_mapping" in quant_config
|
packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
|
||||||
and quant_config["packed_modules_mapping"] is not None
|
self.packed_modules_mapping = (
|
||||||
):
|
packed_modules_mapping if packed_modules_mapping is not None else {}
|
||||||
self.packed_modules_mapping = quant_config["packed_modules_mapping"]
|
)
|
||||||
|
|
||||||
|
if _is_npu:
|
||||||
# Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
|
# Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
|
||||||
for name in self.quant_description.keys():
|
for name in self.quant_description.keys():
|
||||||
if "norm.bias" in name:
|
if "norm.bias" in name:
|
||||||
@@ -237,7 +239,7 @@ class W8A8Int8Config(QuantizationConfig):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
) -> Optional[QuantizeMethodBase]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if _is_npu:
|
if _is_npu:
|
||||||
@@ -262,12 +264,16 @@ class W8A8Int8Config(QuantizationConfig):
|
|||||||
elif isinstance(layer, FusedMoE):
|
elif isinstance(layer, FusedMoE):
|
||||||
return NPU_W8A8MoEMethod(self)
|
return NPU_W8A8MoEMethod(self)
|
||||||
return None
|
return None
|
||||||
else:
|
|
||||||
if isinstance(layer, LinearBase):
|
if should_ignore_layer(
|
||||||
return W8A8Int8LinearMethod(self)
|
prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
|
||||||
elif isinstance(layer, FusedMoE):
|
):
|
||||||
return W8A8Int8MoEMethod(self)
|
return UnquantizedLinearMethod()
|
||||||
return None
|
if isinstance(layer, LinearBase):
|
||||||
|
return W8A8Int8LinearMethod(self)
|
||||||
|
elif isinstance(layer, FusedMoE):
|
||||||
|
return W8A8Int8MoEMethod(self)
|
||||||
|
return None
|
||||||
|
|
||||||
def is_layer_skipped(
|
def is_layer_skipped(
|
||||||
self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
|
self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
|
||||||
|
|||||||
Reference in New Issue
Block a user