From 39c237f02cb31c206bbc5a52fe968f1b6bf2b543 Mon Sep 17 00:00:00 2001 From: ErvinXie Date: Fri, 24 Oct 2025 03:08:05 +0800 Subject: [PATCH] Add AWQ quantization support for NPU. (#10158) Co-authored-by: Alisehen <814073252@qq.com> Co-authored-by: Yaochen Han <48639761+Alisehen@users.noreply.github.com> Co-authored-by: Zhengda Qin --- python/sglang/srt/layers/linear.py | 1 + python/sglang/srt/layers/quantization/awq.py | 180 +++++++++++++++++- .../srt/layers/quantization/awq_triton.py | 29 +++ .../srt/layers/quantization/w8a8_int8.py | 34 +++- python/sglang/srt/model_loader/loader.py | 2 + python/sglang/srt/models/deepseek_v2.py | 6 +- python/sglang/srt/utils/common.py | 2 + 7 files changed, 243 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 3aaf301bb..23b2635d2 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -51,6 +51,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", + "AWQLinearAscendMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", "BlockInt8LinearMethod", diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 5c195516c..5b4a75367 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -31,6 +31,7 @@ from sglang.srt.layers.quantization.marlin_utils import ( ) from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter +from sglang.srt.layers.quantization.w8a8_int8 import npu_fused_experts if TYPE_CHECKING: from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig @@ -39,11 +40,16 @@ if TYPE_CHECKING: CombineInput, ) -from sglang.srt.utils import is_cuda, is_hip, is_xpu +from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu _is_cuda = is_cuda() _is_hip = is_hip() _is_xpu = is_xpu() +_is_npu = is_npu() + +if _is_npu: + import torch_npu + if _is_cuda: from sgl_kernel import ( awq_dequantize, @@ -117,12 +123,17 @@ class AWQConfig(QuantizationConfig): return "awq" def get_supported_act_dtypes(self) -> List[torch.dtype]: - return [torch.half] + return [torch.float16] if not _is_npu else [torch.float16, torch.bfloat16] @classmethod def get_min_capability(cls) -> int: # The AWQ kernel only supports Turing or newer GPUs. - return 75 + if _is_npu: + raise NotImplementedError( + 'NPU hardware does not support "get_min_capability" feature.' + ) + else: + return 75 @staticmethod def get_config_filenames() -> List[str]: @@ -146,6 +157,16 @@ class AWQConfig(QuantizationConfig): self, layer: torch.nn.Module, prefix: str ) -> Optional[LinearMethodBase]: from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + + if _is_npu: + if isinstance(layer, LinearBase): + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + return AWQLinearAscendMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEAscendMethod(self) + return None if isinstance(layer, LinearBase): if is_layer_skipped_awq(prefix, self.modules_to_not_convert): @@ -575,6 +596,64 @@ class AWQMarlinLinearMethod(LinearMethodBase): ) +class AWQLinearAscendMethod(AWQLinearMethod): + """Linear method for AWQ on Ascend. + + Args: + quant_config: The AWQ quantization config. 
+    """
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+        qweight_tmp = torch.zeros_like(layer.qweight.data)
+        qzeros_tmp = layer.qzeros.data
+        qzeros_list = []
+        shifts = [0, 4, 1, 5, 2, 6, 3, 7]
+
+        for i in range(0, self.quant_config.pack_factor):
+            shift_num = shifts[i] * 4
+            qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF)
+            qweight_tmp.bitwise_or_(
+                ((layer.qweight.data >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))
+            )
+
+        qweight_tmp.bitwise_xor_(0x88888888)
+
+        qzeros_tmp = torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1)
+        qzeros_tmp = -(qzeros_tmp - 8)
+        qzeros_tmp = qzeros_tmp.to(layer.scales.data.dtype)
+
+        layer.qzeros = torch.nn.Parameter(qzeros_tmp, requires_grad=False)
+        layer.qweight = torch.nn.Parameter(qweight_tmp, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
+        pack_factor = self.quant_config.pack_factor
+        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        if bias is not None and bias.dtype == torch.bfloat16:
+            bias = bias.float()
+
+        out = torch_npu.npu_weight_quant_batchmatmul(
+            reshaped_x,
+            qweight,
+            antiquant_scale=scales,
+            antiquant_offset=qzeros,
+            antiquant_group_size=self.quant_config.group_size,
+            bias=bias,
+        )
+
+        return out.reshape(out_shape)
+
+
 class AWQMoEMethod(FusedMoEMethodBase):
 
     def __init__(self, quant_config: AWQMarlinConfig):
@@ -677,7 +756,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_qzeros, extra_weight_attrs)
 
         device = layer.w13_qweight.device
-        layer.workspace = marlin_make_workspace(device, 4)
+        if not _is_npu:
+            layer.workspace = marlin_make_workspace(device, 4)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         num_experts = layer.w13_qweight.shape[0]
@@ -785,3 +865,95 @@ class AWQMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
         ).to(orig_dtype)
         return StandardCombineInput(hidden_states=output)
+
+
+class AWQMoEAscendMethod(AWQMoEMethod):
+    def __init__(self, quant_config: AWQConfig):
+        self.quant_config = quant_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        w13_qweight_tmp = torch.zeros_like(layer.w13_qweight.data)
+        w2_qweight_tmp = torch.zeros_like(layer.w2_qweight.data)
+        w13_qzeros_list = []
+        w2_qzeros_list = []
+        shifts = [0, 4, 1, 5, 2, 6, 3, 7]
+        for i in range(0, self.quant_config.pack_factor):
+            shift_num = shifts[i] * 4
+            w13_qzeros_list.append(
+                (layer.w13_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF
+            )
+            w2_qzeros_list.append(
+                (layer.w2_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF
+            )
+            w13_qweight_tmp.bitwise_or_(
+                ((layer.w13_qweight.data >> shift_num) * (2 ** (4 * i)))
+                & (0xF << (4 * i))
+            )
+            w2_qweight_tmp.bitwise_or_(
+                ((layer.w2_qweight.data >> shift_num) * (2 ** (4 * i)))
+                & (0xF << (4 * i))
+            )
+
+        w13_qweight_tmp.bitwise_xor_(0x88888888)
+        w2_qweight_tmp.bitwise_xor_(0x88888888)
+
+        w13_qzeros_tmp = torch.cat(w13_qzeros_list, dim=-1).reshape(
+            layer.w13_qzeros.shape[0], layer.w13_qzeros.shape[1], -1
+        )
+        w13_qzeros_tmp = -(w13_qzeros_tmp - 8)
+        w13_qzeros_tmp = w13_qzeros_tmp.to(layer.w13_scales.data.dtype)
+        w2_qzeros_tmp = torch.cat(w2_qzeros_list, dim=-1).reshape(
+            layer.w2_qzeros.shape[0], layer.w2_qzeros.shape[1], -1
+        )
+        w2_qzeros_tmp = -(w2_qzeros_tmp - 8)
+        w2_qzeros_tmp = w2_qzeros_tmp.to(layer.w2_scales.data.dtype)
+
+        layer.register_parameter(
+            "w13_qzeros", torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False)
+        )
+        layer.register_parameter(
+            "w13_qweight", torch.nn.Parameter(w13_qweight_tmp, requires_grad=False)
+        )
+        layer.register_parameter(
+            "w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, requires_grad=False)
+        )
+        layer.register_parameter(
+            "w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False)
+        )
+
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: StandardDispatchOutput,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        assert (
+            self.moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        topk_weights, topk_ids, _ = topk_output
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = topk_weights.to(x.dtype)
+        output = npu_fused_experts(
+            hidden_states=x,
+            w13=layer.w13_qweight,
+            w13_scale=layer.w13_scales,
+            w13_offset=layer.w13_qzeros,
+            w2=layer.w2_qweight,
+            w2_scale=layer.w2_scales,
+            w2_offset=layer.w2_qzeros,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            top_k=topk_ids.shape[1],
+            use_wna16=True,
+        )
+        return StandardCombineInput(hidden_states=output)
diff --git a/python/sglang/srt/layers/quantization/awq_triton.py b/python/sglang/srt/layers/quantization/awq_triton.py
index 13352efdb..b83dd79fb 100644
--- a/python/sglang/srt/layers/quantization/awq_triton.py
+++ b/python/sglang/srt/layers/quantization/awq_triton.py
@@ -337,3 +337,32 @@ def awq_gemm_triton(
     result = result.sum(0)
 
     return result
+
+
+def awq_dequantize_decomposition(
+    qweight: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+) -> torch.Tensor:
+    qweight_tmp = qweight
+    qzeros_tmp = zeros
+    qweight_list = []
+    qzeros_list = []
+    shifts = [0, 4, 1, 5, 2, 6, 3, 7]
+    for i in range(0, 8):
+        shift_num = shifts[i] * 4
+        qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF)
+        qweight_list.append((qweight_tmp.reshape(-1, 1) >> shift_num) & 0xF)
+    qzeros_tmp = (
+        torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1).to(scales.dtype)
+    )
+    qweight_tmp = (
+        torch.cat(qweight_list, dim=-1)
+        .reshape(qweight_tmp.shape[0], -1)
+        .to(scales.dtype)
+    )
+    res = (
+        qweight_tmp.reshape(qzeros_tmp.shape[0], -1, qzeros_tmp.shape[1])
+        - qzeros_tmp.unsqueeze(1)
+    ) * scales.unsqueeze(1)
+    return res.reshape(qweight_tmp.shape[0], -1)
diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index 77be31163..5ceba2f67 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -102,7 +102,12 @@ def npu_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     top_k: int,
+    **kwargs,
 ):
+    w13_offset = kwargs.get("w13_offset", None)
+    w2_offset = kwargs.get("w2_offset", None)
+    use_wna16 = kwargs.get("use_wna16", False)
+
     original_shape = hidden_states.shape
     original_dtype = hidden_states.dtype
     scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32
@@ -127,12 +132,22 @@
     )
     expert_tokens = expert_tokens.to(torch.int64)
     # gmm1: gate_up_proj
-    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
+    if not use_wna16:
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
+        scale_args13 = {
+            "scale": [w13_scale.to(scale_dtype)],
+            "per_token_scale": [pertoken_scale],
+        }
+    else:
+        scale_args13 = {
+            "antiquant_scale": [w13_scale],
+            "antiquant_offset": [w13_offset],
+        }
+
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w13],
-        scale=[w13_scale.to(scale_dtype)],
-        per_token_scale=[pertoken_scale],
+        **scale_args13,
         split_item=2,
         group_list_type=0,
         group_type=0,
@@ -141,13 +156,20 @@
     )[0]
     # act_fn: swiglu
     hidden_states = torch_npu.npu_swiglu(hidden_states)
-    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
+    if not use_wna16:
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+        scale_args2 = {
+            "scale": [w2_scale.to(scale_dtype)],
+            "per_token_scale": [pertoken_scale],
+        }
+    else:
+        scale_args2 = {"antiquant_scale": [w2_scale], "antiquant_offset": [w2_offset]}
     # gmm2: down_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w2],
-        scale=[w2_scale.to(scale_dtype)],
-        per_token_scale=[pertoken_scale],
+        **scale_args2,
         split_item=2,
         group_list_type=0,
         group_type=0,
diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 06ecb5041..6463db307 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -612,6 +612,8 @@ class DefaultModelLoader(BaseModelLoader):
                 # parameters onto device for processing and back off after.
                 with device_loading_context(module, target_device):
                     quant_method.process_weights_after_loading(module)
+        if _is_npu:
+            torch.npu.empty_cache()
 
 
 class LayeredModelLoader(DefaultModelLoader):
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 27b7bfbd3..779cd8853 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -189,6 +189,10 @@ elif _is_npu:
     import custom_ops  # noqa: F401
     import sgl_kernel_npu  # noqa: F401
     import torch_npu  # noqa: F401
+
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_decomposition as awq_dequantize,
+    )
 else:
     pass
 
@@ -2965,7 +2969,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 if hasattr(self_attn.kv_b_proj, "qweight"):
                     # AWQ compatible
-                    if _is_cuda or _is_hip:
+                    if _is_cuda or _is_hip or _is_npu:
                         w = awq_dequantize(
                             self_attn.kv_b_proj.qweight,
                             self_attn.kv_b_proj.scales,
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 9d09f187c..3855dc0f7 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -510,6 +510,8 @@ def get_available_gpu_memory(
                 f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ",
                 "which may cause useless memory allocation for torch NPU context.",
             )
+        if empty_cache:
+            torch.npu.empty_cache()
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
 
     if distributed:
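
Note on the nibble repacking in AWQLinearAscendMethod and AWQMoEAscendMethod
above: AWQ packs eight unsigned 4-bit weights into each int32 word in the
interleaved order [0, 2, 4, 6, 1, 3, 5, 7], so element k sits at nibble
[0, 4, 1, 5, 2, 6, 3, 7][k], which is exactly the shifts table in the code.
Reading the nibbles back through that table restores sequential order, and
XOR-ing the repacked word with 0x88888888 flips each nibble's high bit,
turning an unsigned value u in [0, 15] into the two's-complement int4 value
u - 8. The zero points are rewritten to match, as floating-point offsets
-(z - 8), so an antiquant step of the form (w + offset) * scale reproduces
AWQ's (u - z) * scale. The CPU-only sketch below mirrors the transform and
checks that it round-trips. The helper names are illustrative, not sglang or
torch_npu APIs, and int64 tensors are used so the 32-bit constants stay in
range under CPU scalar-overflow checks; the patch itself operates on the
int32 qweight tensors on the NPU.

import torch

AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]  # packed nibble i holds element AWQ_ORDER[i]
SHIFTS = [0, 4, 1, 5, 2, 6, 3, 7]  # hence element k is found at nibble SHIFTS[k]


def pack_awq(u4: torch.Tensor) -> torch.Tensor:
    # Pack unsigned int4 values (last dim = 8) into one 32-bit word, AWQ order.
    word = torch.zeros(u4.shape[:-1], dtype=torch.int64)
    for i, elem in enumerate(AWQ_ORDER):
        word |= (u4[..., elem].to(torch.int64) & 0xF) << (4 * i)
    return word


def repack_for_ascend(qweight: torch.Tensor) -> torch.Tensor:
    # Sequential nibble order plus signed values, as in
    # process_weights_after_loading. The masked-shift below is algebraically
    # equivalent to the patch's ((q >> shift) * 2 ** (4 * k)) & (0xF << (4 * k)).
    out = torch.zeros_like(qweight)
    for k in range(8):
        out |= ((qweight >> (SHIFTS[k] * 4)) & 0xF) << (4 * k)
    # XOR 0x8 per nibble flips the high bit: unsigned u becomes two's-complement u - 8.
    return out ^ 0x88888888


def unpack_sequential_signed(word: torch.Tensor) -> torch.Tensor:
    # Read the eight signed int4 nibbles back out for inspection.
    nibs = torch.stack([(word >> (4 * k)) & 0xF for k in range(8)], dim=-1)
    return torch.where(nibs >= 8, nibs - 16, nibs)


u = torch.randint(0, 16, (4, 8))  # toy unsigned int4 weights, one word per row
w = unpack_sequential_signed(repack_for_ascend(pack_awq(u)))
assert torch.equal(w, u - 8)  # the stored signed value is u - 8, as expected

awq_dequantize_decomposition in awq_triton.py relies on the same layout: it
unpacks qweight and qzeros through the same shift table and computes
(u - z) * scale directly, which is why deepseek_v2.py can import it as a
drop-in replacement for the CUDA awq_dequantize kernel when dequantizing
kv_b_proj on NPU.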