from __future__ import annotations
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
import torch
from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.parameter import (
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.utils import (
apply_module_patch,
cpu_has_amx_support,
is_cpu,
is_cuda,
is_npu,
set_weight_attrs,
use_intel_amx_backend,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import int8_scaled_mm
_is_npu = is_npu()
if _is_npu:
import torch_npu
try:
from mindie_turbo import _ops as ops
from mindie_turbo.quantize.quant_utils import quant_per_tensor
except ImportError:
useMindIETurbo = False
else:
useMindIETurbo = True
# func refers to RMSNorm.__init__
def npu_wrapper_rmsnorm_init(func):
def init(self, hidden_size: int, **extra_args) -> None:
func(self, hidden_size, **extra_args)
self.ignore_anti = True
# The Ascend w8a8_int8 quantization requires adding a bias in rmsnorm
self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
return init
# func refers to RMSNorm.forward_oot
def npu_wrapper_rmsnorm_forward(func):
def _rmsnorm_forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
x = x.contiguous()
if residual is not None:
out, _, residual_out = torch_npu.npu_add_rms_norm(
residual, x, self.weight.data, self.variance_epsilon
)
out = out + self.bias
return out.to(x.dtype), residual_out
out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
out = out + self.bias
return out.to(x.dtype)
return _rmsnorm_forward_oot
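# Illustrative sketch (not executed): W8A8Int8Config.__init__ applies these
# wrappers through apply_module_patch, roughly equivalent to:
#   RMSNorm.__init__ = npu_wrapper_rmsnorm_init(RMSNorm.__init__)
#   RMSNorm.forward_npu = npu_wrapper_rmsnorm_forward(RMSNorm.forward_npu)
# so every RMSNorm gains the zero-initialized bias the Ascend W8A8 kernels expect.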
def npu_fused_experts(
hidden_states: torch.Tensor,
w13: torch.Tensor,
w13_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
**kwargs,
):
w13_offset = kwargs.get("w13_offset", None)
w2_offset = kwargs.get("w2_offset", None)
use_wna16 = kwargs.get("use_wna16", False)
original_shape = hidden_states.shape
original_dtype = hidden_states.dtype
scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens = hidden_states.shape[0]
num_experts = w13.shape[0]
row_idx_len = num_tokens * top_k
row_idx = (
torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
.view(top_k, -1)
.permute(1, 0)
.contiguous()
)
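    # row_idx assigns each (token, top-k slot) pair a destination row in the
    # expanded activation buffer consumed by npu_moe_init_routing.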
hidden_states, expanded_row_idx, expanded_expert_idx = (
torch_npu.npu_moe_init_routing(
hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
)
)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts
)
expert_tokens = expert_tokens.to(torch.int64)
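    # With group_list_type=0 below, npu_grouped_matmul treats expert_tokens as
    # cumulative per-expert row counts, so expert i multiplies the slice of
    # sorted rows ending at expert_tokens[i].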
# gmm1: gate_up_proj
if not use_wna16:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
scale_args13 = {
"scale": [w13_scale.to(scale_dtype)],
"per_token_scale": [pertoken_scale],
}
else:
scale_args13 = {
"antiquant_scale": [w13_scale],
"antiquant_offset": [w13_offset],
}
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w13],
**scale_args13,
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
output_dtype=original_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
if not use_wna16:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
scale_args2 = {
"scale": [w2_scale.to(scale_dtype)],
"per_token_scale": [pertoken_scale],
}
else:
scale_args2 = {"antiquant_scale": [w2_scale], "antiquant_offset": [w2_offset]}
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
**scale_args2,
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
output_dtype=original_dtype,
)[0]
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
class W8A8Int8Config(QuantizationConfig):
"""Config class for W8A8 Int8 Quantization.
- Weight: static, per-channel, symmetric
- Activation: dynamic, per-token, symmetric
"""
    def __init__(self, quant_config: Optional[Dict[str, Any]] = None):
        super().__init__()
        quant_config = quant_config if quant_config is not None else {}
        self.quant_description = quant_config
self.is_dynamic = quant_config.get("is_dynamic", False)
ignore = cast(List[str], quant_config.get("ignore", []))
self.ignore = ignore if ignore is not None else []
packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
self.packed_modules_mapping = (
packed_modules_mapping if packed_modules_mapping is not None else {}
)
if _is_npu:
            # Ascend W8A8 quantization folds a bias into RMSNorm; patch via
            # wrappers so the change stays isolated to models that need it.
for name in self.quant_description.keys():
if "norm.bias" in name:
apply_module_patch(
"sglang.srt.layers.layernorm.RMSNorm",
"__init__",
[npu_wrapper_rmsnorm_init],
)
apply_module_patch(
"sglang.srt.layers.layernorm.RMSNorm",
"forward_npu",
[npu_wrapper_rmsnorm_forward],
)
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return (
[torch.float16, torch.bfloat16]
if not _is_npu
else [torch.int8, torch.float16, torch.bfloat16]
)
@classmethod
def get_min_capability(cls) -> int:
if _is_npu:
raise NotImplementedError(
'NPU hardware does not support "get_min_capability" feature.'
)
else:
return 75
    @classmethod
    def get_name(cls) -> str:
        return "w8a8_int8"
@classmethod
def get_config_filenames(cls) -> List[str]:
filenames = []
if _is_npu:
filenames.append("quant_model_description.json")
return filenames
@classmethod
def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
return cls(config)
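    # Illustrative usage (not executed), assuming the keys read in __init__:
    #   cfg = W8A8Int8Config.from_config(
    #       {"is_dynamic": False, "ignore": ["lm_head"], "packed_modules_mapping": {}}
    #   )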
def get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if _is_npu:
if isinstance(layer, LinearBase):
key = "model"
if "vision_model" in prefix:
key = "vision_model"
elif "visual" in prefix:
key = "visual"
packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {})
prefix_in_quant_config = prefix
proj_name = prefix.split(".")[-1]
if proj_name in packed_modules_mapping_subset:
prefix_in_quant_config = prefix.replace(
proj_name, packed_modules_mapping_subset[proj_name][0]
)
self.is_dynamic = (
self.quant_description[prefix_in_quant_config + ".weight"]
== "W8A8_DYNAMIC"
)
if self.is_layer_skipped(prefix, packed_modules_mapping_subset):
return UnquantizedLinearMethod()
return (
NPU_W8A8DynamicLinearMethod(self)
if self.is_dynamic
else NPU_W8A8LinearMethod(self)
)
elif isinstance(layer, FusedMoE):
return NPU_W8A8MoEMethod(self)
return None
if should_ignore_layer(
prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return W8A8Int8MoEMethod(self)
return None
def is_layer_skipped(
self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
):
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
proj_name = prefix.split(".")[-1]
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = (
self.quant_description[shard_prefix + ".weight"] == "FLOAT"
)
if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "must have the same precision."
                    )
else:
is_skipped = self.quant_description[prefix + ".weight"] == "FLOAT"
assert is_skipped is not None
return is_skipped
def get_scaled_act_names(self) -> List[str]:
return []
class W8A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: W8A8Int8Config):
self.quantization_config = quantization_config
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _is_cpu:
assert (
_is_cpu_amx_available
), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["weight"])
else:
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight_loader = extra_weight_attrs.get("weight_loader")
self.logical_widths = output_partition_sizes
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
):
if use_intel_amx_backend(layer):
return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
x,
layer.weight,
layer.weight_scale,
bias,
x.dtype,
True, # is_vnni
)
x_q, x_scale = per_token_quant_int8(x)
x_q_2d = x_q.view(-1, x_q.shape[-1])
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
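        # int8_scaled_mm fuses the GEMM with dequantization:
        # y ~= (x_q @ w_q) * x_scale (per token) * weight_scale (per channel).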
output = int8_scaled_mm(
x_q_2d,
layer.weight,
x_scale_2d,
layer.weight_scale,
out_dtype=x.dtype,
bias=bias,
)
return output.view(output_shape)
class W8A8Int8MoEMethod(FusedMoEMethodBase):
"""MoE method for INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: W8A8Int8Config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
tp_size = get_tensor_model_parallel_world_size()
# WEIGHTS
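        # w13 stacks the fused gate and up projections, hence the factor of 2
        # in each expert's output dimension.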
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=torch.int8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _is_cpu:
assert (
_is_cpu_amx_available
), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
else:
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
output = torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
                False,  # inplace; see [Note]: inplace should be False in fused_experts
True, # use_int8_w8a8
False, # use_fp8_w8a16
layer.w13_weight_scale, # w1_scale
layer.w2_weight_scale, # w2_scale
None, # block_size
layer.w13_input_scale, # a1_scale
layer.w2_input_scale, # a2_scale
True, # is_vnni
)
return StandardCombineInput(hidden_states=output)
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_int8_w8a8=True,
per_channel_quant=True,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
return self.runner.run(dispatch_output, quant_info)
class NPU_W8A8LinearMethodImpl:
"""Linear method for NPU W8A8."""
def __init__(self) -> None:
        # The aclnn quant matmul requires matrix B to be transposed; enabled by default.
self.transpose_weight = True
@staticmethod
def get_weight(
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
return params_dict
@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
params_dict = {}
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
if params_dtype == torch.bfloat16:
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
elif params_dtype == torch.float16:
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
return params_dict
@staticmethod
def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# To prevent import loops
from sglang.srt.layers.linear import RowParallelLinear
original_dtype = x.dtype
if original_dtype != torch.int8:
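            # With div_mode=False, npu_quantize multiplies by the given scale,
            # so the reciprocal is passed to compute x / input_scale.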
x = torch_npu.npu_quantize(
x,
layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset,
torch.qint8,
-1,
False,
)
        # Only fuse the bias add into the GEMM on rank 0, so the bias is not
        # added more than once in the attention TP > 1 case.
if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
quant_bias = None
else:
quant_bias = layer.quant_bias
return torch_npu.npu_quant_matmul(
x,
layer.weight,
layer.deq_scale,
bias=quant_bias,
output_dtype=original_dtype,
)
    def process_weights_after_loading(self, layer):
        # Broadcast the per-tensor input scale/offset across the input width
        # expected by npu_quantize.
        expanding_factor = layer.weight.data.shape[1]
        input_scale = layer.input_scale.data.repeat(expanding_factor).to(device="npu")
        layer.aclnn_input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
        layer.aclnn_input_scale_reciprocal = torch.nn.Parameter(
            1 / input_scale, requires_grad=False
        )
        layer.aclnn_input_offset = torch.nn.Parameter(
            layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
            requires_grad=False,
        )
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
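        # Format 29 corresponds to the NPU's blocked FRACTAL_NZ weight layout
        # (assumption based on common torch_npu usage).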
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8LinearMethodMTImpl:
"""Linear method for NPU W8A8."""
def __init__(self) -> None:
self.transpose_weight = True
@staticmethod
def get_weight(
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
return params_dict
@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
params_dict = {}
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
if params_dtype == torch.bfloat16:
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
elif params_dtype == torch.float16:
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
return params_dict
@staticmethod
def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# To prevent import loops
from sglang.srt.layers.linear import RowParallelLinear
original_dtype = x.dtype
if original_dtype != torch.int8:
x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
        # Only fuse the bias add into the GEMM on rank 0, so the bias is not
        # added more than once in the attention TP > 1 case.
if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
quant_bias = None
else:
quant_bias = layer.quant_bias
return ops.quant_matmul(
x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
)
def process_weights_after_loading(self, layer):
layer.aclnn_deq_scale = torch.nn.Parameter(
torch_npu.npu_trans_quant_param(layer.deq_scale.npu()).to(device="npu"),
requires_grad=False,
)
class NPU_W8A8LinearMethod(LinearMethodBase):
"""Linear method for NPU quantization.
This class search for specific quantization
implementation supported on NPU hardware for linear methods.
Args:
quant_config: The NPU quantization config.
"""
def __init__(self, quantization_config: W8A8Int8Config) -> None:
self.quantization_config = quantization_config
self.quant_method = (
NPU_W8A8LinearMethodMTImpl()
if useMindIETurbo
else NPU_W8A8LinearMethodImpl()
)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
weight_dict = self.quant_method.get_weight(
input_size_per_partition, output_size_per_partition, params_dtype
)
for weight_name, weight_param in weight_dict.items():
param = torch.nn.Parameter(weight_param, requires_grad=False)
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(
data=pertensor_param, weight_loader=weight_loader
)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)
perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype
)
for perchannel_name, perchannel_param in perchannel_dict.items():
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.quant_method.apply(layer, x, bias)
class NPU_W8A8DynamicLinearMethodImpl:
"""Linear method for NPU W8A8_DYNAMIC."""
def __init__(self):
self.transpose_weight = True
@staticmethod
def get_weight(
input_size: int, output_size: int, params_dtype: torch.dtype
) -> Dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
return params_dict
@staticmethod
def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
original_dtype = x.dtype
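        # npu_dynamic_quant performs per-token symmetric int8 quantization on
        # the fly, returning the quantized tensor and one float scale per token.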
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
return torch_npu.npu_quant_matmul(
quant_out,
layer.weight,
layer.weight_scale,
pertoken_scale=dynamic_scale,
bias=bias,
output_dtype=original_dtype,
)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
"""Linear method for NPU quantization.
This class search for specific quantization
implementations supported on NPU hardware for linear methods.
Args:
quant_config: The NPU quantization config.
"""
def __init__(self, quantization_config: W8A8Int8Config) -> None:
self.quantization_config = quantization_config
self.quant_method = NPU_W8A8DynamicLinearMethodImpl()
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
weight_dict = self.quant_method.get_weight(
input_size_per_partition, output_size_per_partition, params_dtype
)
for weight_name, weight_param in weight_dict.items():
param = torch.nn.Parameter(weight_param, requires_grad=False)
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(
data=pertensor_param, weight_loader=weight_loader
)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)
perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype
)
for perchannel_name, perchannel_param in perchannel_dict.items():
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.quant_method.apply(layer, x, bias)
class NPU_W8A8MoEMethod(FusedMoEMethodBase):
"""MoE method for NPU quantization.
This class search for specific quantization
implementations supported on NPU hardware for moe methods.
Args:
quant_config: The NPU quantization config.
"""
def __init__(self, quantization_config: W8A8Int8Config) -> None:
self.quantization_config = quantization_config
self.quant_method = self
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
self.num_experts = num_experts
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
# weight
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=torch.int8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# scale
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# offset
w13_weight_offset = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
),
requires_grad=False,
)
layer.register_parameter("w13_weight_offset", w13_weight_offset)
set_weight_attrs(w13_weight_offset, extra_weight_attrs)
w2_weight_offset = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_offset", w2_weight_offset)
set_weight_attrs(w2_weight_offset, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
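        # npu_grouped_matmul expects per-expert weights in
        # (num_experts, in_features, out_features) layout, hence the transpose;
        # scale/offset tensors are squeezed to one value per output channel.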
layer.w13_weight = Parameter(
layer.w13_weight.data.transpose(1, 2).contiguous(), requires_grad=False
)
layer.w2_weight = Parameter(
layer.w2_weight.data.transpose(1, 2).contiguous(), requires_grad=False
)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
)
layer.w13_weight_offset = Parameter(
layer.w13_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
)
layer.w2_weight_offset = Parameter(
layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
output = npu_fused_experts(
hidden_states=x,
w13=layer.w13_weight,
w13_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=topk_ids.shape[1],
)
return StandardCombineInput(hidden_states=output)