From 766392c6bda2558b61ce6d1c1bfd8081a549e1f1 Mon Sep 17 00:00:00 2001 From: ronnie_zheng Date: Thu, 10 Jul 2025 19:17:37 +0300 Subject: [PATCH] [feature]Ascend quantization support (#7791) Co-authored-by: ichernob Co-authored-by: liupeng --- python/sglang/srt/configs/model_config.py | 4 +- python/sglang/srt/layers/linear.py | 10 + python/sglang/srt/layers/moe/ep_moe/layer.py | 3 +- .../srt/layers/moe/fused_moe_triton/layer.py | 2 +- python/sglang/srt/layers/moe/topk.py | 2 +- .../srt/layers/quantization/moe_wna16.py | 3 +- .../srt/layers/quantization/w8a8_int8.py | 752 +++++++++++++++++- python/sglang/srt/mem_cache/memory_pool.py | 6 +- python/sglang/srt/model_loader/loader.py | 33 +- python/sglang/srt/models/llama.py | 2 + python/sglang/srt/models/mixtral_quant.py | 4 + python/sglang/srt/models/qwen2.py | 2 + python/sglang/srt/utils.py | 100 ++- 13 files changed, 889 insertions(+), 34 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 85ec5bd80..a6f563cf0 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -413,7 +413,9 @@ class ModelConfig: quant_cfg = self._parse_quant_hf_config() if quant_cfg is not None: - quant_method = quant_cfg.get("quant_method", "").lower() + quant_method = quant_cfg.get( + "quant_method", "" if not self.quantization else self.quantization + ).lower() # Detect which checkpoint is it for _, method in QUANTIZATION_METHODS.items(): diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 3fa012ce8..2ce049359 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -34,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import ( from sglang.srt.utils import ( cpu_has_amx_support, is_cpu, + is_npu, set_weight_attrs, use_intel_amx_backend, ) @@ -60,6 +61,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() +_is_npu = is_npu() def adjust_marlin_shard(param, shard_size, shard_offset): @@ -297,6 +299,14 @@ class ReplicatedLinear(LinearBase): if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) + # The per-tensor quant-scale must be 1 dimension + if _is_npu: + if param.size() != loaded_weight.size() and param.size(0) == 1: + if torch.allclose(loaded_weight, loaded_weight[0]): + loaded_weight = loaded_weight[:1] + else: + raise ValueError(f"{loaded_weight} are not all equal") + assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 568337fe9..353f131c9 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -12,7 +12,6 @@ from sglang.srt.distributed import ( ) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo -from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe from sglang.srt.layers.moe.ep_moe.kernels import ( ep_gather, ep_scatter, @@ -65,6 +64,8 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if not _is_npu: from sgl_kernel import silu_and_mul + from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe + if _is_hip: from vllm._custom_ops import scaled_fp8_quant diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py 
b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index c460b2850..6122e0ded 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -850,7 +850,7 @@ class FusedMoE(torch.nn.Module): return # Case weight scales and zero_points - if "scale" in weight_name or "zero" in weight_name: + if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name: # load the weight scales and zp based on the quantization scheme # supported weight scales/zp can be found in # FusedMoeWeightScaleSupported diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index ebb959aba..18f3dea8d 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -308,7 +308,7 @@ def biased_grouped_topk_gpu( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, - compiled: bool = True, + compiled: bool = not _is_npu, num_fused_shared_experts: int = 0, routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py index 4f3bc716e..0bae43435 100644 --- a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -116,8 +116,7 @@ class MoeWNA16Config(QuantizationConfig): @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) - if can_convert and user_quant == "moe_wna16": + if user_quant == "moe_wna16" and cls.is_moe_wna16_compatible(hf_quant_cfg): return cls.get_name() return None diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index db0351052..49e6f0e8c 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -1,21 +1,37 @@ -from typing import Any, Callable, Dict, List, Optional +import importlib +import sys +from types import MappingProxyType +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import torch from torch.nn.parameter import Parameter -from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading -from sglang.srt.layers.linear import LinearMethodBase -from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter +from sglang.srt.layers.linear import ( + LinearMethodBase, + RowParallelLinear, + UnquantizedLinearMethod, +) +from sglang.srt.layers.parameter import ( + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.srt.utils import ( + apply_module_patch, cpu_has_amx_support, is_cpu, is_cuda, + is_npu, set_weight_attrs, use_intel_amx_backend, ) @@ -25,6 +41,134 @@ _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() if _is_cuda: from sgl_kernel import int8_scaled_mm +_is_npu = is_npu() + +if _is_npu: + import torch_npu + + try: + from mindie_turbo import _ops as ops + from mindie_turbo.quantize.quant_utils import quant_per_tensor + 
except ImportError: + useMindIETurbo = False + else: + useMindIETurbo = True + + +# func refers to RMSNorm.__init__ +def npu_wrapper_rmsnorm_init(func): + def init(self, hidden_size: int, **extra_args) -> None: + func(self, hidden_size, **extra_args) + self.ignore_anti = True + # The Ascend w8a8_int8 quantization requires adding a bias in rmsnorm + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False) + + return init + + +# func refers to RMSNorm.forward_oot +def npu_wrapper_rmsnorm_forward(func): + def _rmsnorm_forward_oot( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not x.is_contiguous(): + x = x.contiguous() + original_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(original_dtype) + + x = ( + torch_npu.npu_rms_norm( + x, self.weight.to(torch.float32), self.variance_epsilon + )[0] + + self.bias + ) + + if residual is None: + return x.to(original_dtype) + return x.to(original_dtype), residual + + return _rmsnorm_forward_oot + + +def npu_fused_experts( + hidden_states: torch.Tensor, + w13: torch.Tensor, + w13_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, +): + original_shape = hidden_states.shape + original_dtype = hidden_states.dtype + scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32 + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + num_tokens = hidden_states.shape[0] + num_experts = w13.shape[0] + row_idx_len = num_tokens * top_k + row_idx = ( + torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device) + .view(top_k, -1) + .permute(1, 0) + .contiguous() + ) + hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing( + hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens + ) + ) + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, num_experts + ) + expert_tokens = expert_tokens.to(torch.int64) + # gmm1: gate_up_proj + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w13], + scale=[w13_scale.to(scale_dtype)], + per_token_scale=[pertoken_scale], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + output_dtype=original_dtype, + )[0] + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale.to(scale_dtype)], + per_token_scale=[pertoken_scale], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + output_dtype=original_dtype, + )[0] + + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states class W8A8Int8Config(QuantizationConfig): @@ -34,16 +178,47 @@ class W8A8Int8Config(QuantizationConfig): - Activation: dynamic, per-token, symmetric """ - def __init__(self): - pass + 
def __init__(self, quant_config: Dict[str, Any]):
+        super().__init__()
+        self.quant_description = quant_config
+        self.is_dynamic = quant_config.get("is_dynamic", False)
+        if _is_npu:
+            if (
+                "packed_modules_mapping" in quant_config
+                and quant_config["packed_modules_mapping"] is not None
+            ):
+                self.packed_modules_mapping = quant_config["packed_modules_mapping"]
+
+            # Ascend w8a8_int8 quantization adds a bias in RMSNorm; use wrappers to isolate the effect between models
+            for name in self.quant_description.keys():
+                if "norm.bias" in name:
+                    apply_module_patch(
+                        "sglang.srt.layers.layernorm.RMSNorm",
+                        "__init__",
+                        [npu_wrapper_rmsnorm_init],
+                    )
+                    apply_module_patch(
+                        "sglang.srt.layers.layernorm.RMSNorm",
+                        "forward_npu",
+                        [npu_wrapper_rmsnorm_forward],
+                    )

     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        return [torch.float16, torch.bfloat16]
+        return (
+            [torch.float16, torch.bfloat16]
+            if not _is_npu
+            else [torch.int8, torch.float16, torch.bfloat16]
+        )

     @classmethod
     def get_min_capability(cls) -> int:
-        return 75
+        if _is_npu:
+            raise NotImplementedError(
+                'NPU hardware does not support the "get_min_capability" feature.'
+            )
+        else:
+            return 75

     @classmethod
     def get_name(self) -> str:
@@ -55,7 +230,7 @@ class W8A8Int8Config(QuantizationConfig):

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
-        return cls()
+        return cls(config)

     def get_quant_method(
         self,
@@ -65,11 +240,65 @@ class W8A8Int8Config(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

-        if isinstance(layer, LinearBase):
-            return W8A8Int8LinearMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return W8A8Int8MoEMethod(self)
-        return None
+        if _is_npu:
+            if isinstance(layer, LinearBase):
+                prefix_in_quant_config = prefix
+                proj_name = prefix.split(".")[-1]
+                if proj_name in self.packed_modules_mapping:
+                    prefix_in_quant_config = prefix.replace(
+                        proj_name, self.packed_modules_mapping[proj_name][0]
+                    )
+                self.is_dynamic = (
+                    self.quant_description[prefix_in_quant_config + ".weight"]
+                    == "W8A8_DYNAMIC"
+                )
+                if self.is_layer_skipped(prefix, self.packed_modules_mapping):
+                    return UnquantizedLinearMethod()
+                return (
+                    NPU_W8A8DynamicLinearMethod(self)
+                    if self.is_dynamic
+                    else NPU_W8A8LinearMethod(self)
+                )
+            elif isinstance(layer, FusedMoE):
+                return NPU_W8A8MoEMethod(self)
+            return None
+        else:
+            if isinstance(layer, LinearBase):
+                return W8A8Int8LinearMethod(self)
+            elif isinstance(layer, FusedMoE):
+                return W8A8Int8MoEMethod(self)
+            return None
+
+    def is_layer_skipped(
+        self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+    ):
+        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
+        proj_name = prefix.split(".")[-1]
+        if proj_name in fused_mapping:
+            shard_prefixes = [
+                prefix.replace(proj_name, shard_proj_name)
+                for shard_proj_name in fused_mapping[proj_name]
+            ]
+
+            is_skipped = None
+            for shard_prefix in shard_prefixes:
+                is_shard_skipped = (
+                    self.quant_description[shard_prefix + ".weight"] == "FLOAT"
+                )
+
+                if is_skipped is None:
+                    is_skipped = is_shard_skipped
+                elif is_shard_skipped != is_skipped:
+                    raise ValueError(
+                        f"Detected some but not all shards of {prefix} "
+                        "are quantized. All shards of fused layers "
+                        "must have the same precision."
+ ) + else: + is_skipped = self.quant_description[prefix + ".weight"] == "FLOAT" + + assert is_skipped is not None + return is_skipped def get_scaled_act_names(self) -> List[str]: return [] @@ -321,3 +550,498 @@ class W8A8Int8MoEMethod: no_combine=no_combine, routed_scaling_factor=routed_scaling_factor, ) + + +class NPU_W8A8LinearMethodImpl: + """Linear method for NPU W8A8.""" + + def __init__(self) -> None: + # aclnn quant matmul requires to transpose matrix B, set to true by default. + self.transpose_weight = True + + @staticmethod + def get_weight( + input_size: int, + output_size: int, + params_dtype: torch.dtype = torch.bfloat16, + ) -> Dict[str, Any]: + params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)} + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = {} + params_dict["input_scale"] = torch.empty(1, dtype=params_dtype) + params_dict["input_offset"] = torch.empty(1, dtype=torch.int8) + return params_dict + + @staticmethod + def get_perchannel_param( + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + params_dict = {} + params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32) + if params_dtype == torch.bfloat16: + params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32) + elif params_dtype == torch.float16: + params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64) + params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype) + return params_dict + + @staticmethod + def apply( + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = 0, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype != torch.int8: + x = torch_npu.npu_quantize( + x, + layer.aclnn_input_scale, + layer.aclnn_input_offset, + torch.qint8, + -1, + True, + ) + + quant_bias = layer.quant_bias if tp_rank == 0 else None + return torch_npu.npu_quant_matmul( + x, + layer.weight, + layer.deq_scale, + bias=quant_bias, + output_dtype=original_dtype, + ) + + def process_weights_after_loading(self, layer): + expanding_factor = layer.weight.data.shape[1] + layer.aclnn_input_scale = torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, + ) + layer.aclnn_input_offset = torch.nn.Parameter( + layer.input_offset.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, + ) + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight_scale.data = torch.flatten(layer.weight_scale.data) + layer.weight_offset.data = torch.flatten(layer.weight_offset.data) + + +class NPU_W8A8LinearMethodMTImpl: + """Linear method for NPU W8A8.""" + + def __init__(self) -> None: + self.transpose_weight = True + + @staticmethod + def get_weight( + input_size: int, + output_size: int, + params_dtype: torch.dtype = torch.bfloat16, + ) -> Dict[str, Any]: + params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)} + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = {} + params_dict["input_scale"] = torch.empty(1, dtype=params_dtype) + params_dict["input_offset"] = torch.empty(1, dtype=torch.int8) + return params_dict + + @staticmethod + def get_perchannel_param( + output_size: int, + params_dtype: torch.dtype, + 
) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
+        if params_dtype == torch.bfloat16:
+            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
+        elif params_dtype == torch.float16:
+            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
+        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
+        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
+        return params_dict
+
+    @staticmethod
+    def apply(
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        tp_rank: Optional[int] = 0,
+    ) -> torch.Tensor:
+        original_dtype = x.dtype
+        if original_dtype != torch.int8:
+            x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
+
+        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        return ops.quant_matmul(
+            x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
+        )
+
+    def process_weights_after_loading(self, layer):
+        layer.aclnn_deq_scale = torch.nn.Parameter(
+            torch_npu.npu_trans_quant_param(layer.deq_scale.npu()).to(device="npu"),
+            requires_grad=False,
+        )
+
+
+class NPU_W8A8LinearMethod(LinearMethodBase):
+    """Linear method for NPU quantization.
+
+    This class searches for the specific quantization
+    implementation supported on NPU hardware for linear methods.
+
+    Args:
+        quant_config: The NPU quantization config.
+    """
+
+    def __init__(self, quantization_config: W8A8Int8Config) -> None:
+        self.quantization_config = quantization_config
+        self.quant_method = (
+            NPU_W8A8LinearMethodMTImpl()
+            if useMindIETurbo
+            else NPU_W8A8LinearMethodImpl()
+        )
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        weight_dict = self.quant_method.get_weight(
+            input_size_per_partition, output_size_per_partition, params_dtype
+        )
+        for weight_name, weight_param in weight_dict.items():
+            param = torch.nn.Parameter(weight_param, requires_grad=False)
+            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
+            layer.register_parameter(weight_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+
+        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
+        for pertensor_name, pertensor_param in pertensor_dict.items():
+            param = PerTensorScaleParameter(
+                data=pertensor_param, weight_loader=weight_loader
+            )
+            # disable warning
+            param.ignore_warning = True
+            layer.register_parameter(pertensor_name, param)
+
+        perchannel_dict = self.quant_method.get_perchannel_param(
+            output_size_per_partition, params_dtype
+        )
+        for perchannel_name, perchannel_param in perchannel_dict.items():
+            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
+            set_weight_attrs(param, {"output_dim": 0})
+            layer.register_parameter(perchannel_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if hasattr(self.quant_method, "process_weights_after_loading"):
+            self.quant_method.process_weights_after_loading(layer)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if isinstance(layer, RowParallelLinear):
+            tp_rank = get_tensor_model_parallel_rank()
+            return self.quant_method.apply(layer, x, bias, tp_rank)
+        return self.quant_method.apply(layer, x, bias)
+
+
+class NPU_W8A8DynamicLinearMethodImpl:
+    """Linear method for NPU W8A8_DYNAMIC."""
+
+    def __init__(self):
+        self.transpose_weight = True
+
+    @staticmethod
+    def get_weight(
+        input_size: int, output_size: int, params_dtype: torch.dtype
+    ) -> Dict[str, Any]:
+        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
+        return params_dict
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        return {}
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
+        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
+        return params_dict
+
+    @staticmethod
+    def apply(
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        tp_rank: Optional[int] = 0,
+    ) -> torch.Tensor:
+        original_dtype = x.dtype
+        # use ATB quantize
+        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+        return torch_npu.npu_quant_matmul(
+            quant_out,
+            layer.weight,
+            layer.weight_scale,
+            pertoken_scale=dynamic_scale,
+            bias=bias,
+            output_dtype=original_dtype,
+        )
+
+    def process_weights_after_loading(self, layer):
+        if self.transpose_weight:
+            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight_scale.data = layer.weight_scale.data.flatten()
+        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
+        layer.weight_offset.data = layer.weight_offset.data.flatten()
+
+
+class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
+    """Linear method for NPU quantization.
+
+    This class searches for the specific quantization
+    implementation supported on NPU hardware for linear methods.
+
+    Args:
+        quant_config: The NPU quantization config.
+ """ + + def __init__(self, quantization_config: W8A8Int8Config) -> None: + self.quantization_config = quantization_config + self.quant_method = NPU_W8A8DynamicLinearMethodImpl() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + weight_dict = self.quant_method.get_weight( + input_size_per_partition, output_size_per_partition, params_dtype + ) + for weight_name, weight_param in weight_dict.items(): + param = torch.nn.Parameter(weight_param, requires_grad=False) + set_weight_attrs(param, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter(weight_name, param) + set_weight_attrs(param, extra_weight_attrs) + + pertensor_dict = self.quant_method.get_pertensor_param(params_dtype) + for pertensor_name, pertensor_param in pertensor_dict.items(): + param = PerTensorScaleParameter( + data=pertensor_param, weight_loader=weight_loader + ) + # disable warning + param.ignore_warning = True + layer.register_parameter(pertensor_name, param) + + perchannel_dict = self.quant_method.get_perchannel_param( + output_size_per_partition, params_dtype + ) + for perchannel_name, perchannel_param in perchannel_dict.items(): + param = torch.nn.Parameter(perchannel_param, requires_grad=False) + set_weight_attrs(param, {"output_dim": 0}) + layer.register_parameter(perchannel_name, param) + set_weight_attrs(param, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if hasattr(self.quant_method, "process_weights_after_loading"): + self.quant_method.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(layer, RowParallelLinear): + tp_rank = get_tensor_model_parallel_rank() + return self.quant_method.apply(layer, x, bias, tp_rank) + return self.quant_method.apply(layer, x, bias) + + +class NPU_W8A8MoEMethod: + """MoE method for NPU quantization. + + This class search for specific quantization + implementations supported on NPU hardware for moe methods. + + Args: + quant_config: The NPU quantization config. 
+ """ + + def __init__(self, quantization_config: W8A8Int8Config) -> None: + self.quantization_config = quantization_config + self.quant_method = self + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: List[int], + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + self.num_experts = num_experts + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + # weight + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + # scale + w13_weight_scale = torch.nn.Parameter( + torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + w2_weight_scale = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + # offset + w13_weight_offset = torch.nn.Parameter( + torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_offset", w13_weight_offset) + set_weight_attrs(w13_weight_offset, extra_weight_attrs) + w2_weight_offset = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset", w2_weight_offset) + set_weight_attrs(w2_weight_offset, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.w13_weight = Parameter( + layer.w13_weight.data.transpose(1, 2).contiguous(), requires_grad=False + ) + layer.w2_weight = Parameter( + layer.w2_weight.data.transpose(1, 2).contiguous(), requires_grad=False + ) + layer.w13_weight_scale = Parameter( + layer.w13_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False + ) + layer.w2_weight_scale = Parameter( + layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False + ) + layer.w13_weight_offset = Parameter( + layer.w13_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False + ) + layer.w2_weight_offset = Parameter( + layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False + ) + + def apply( + self, + layer, + x, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + num_fused_shared_experts, + custom_routing_function, + correction_bias, + activation, + apply_router_weight_on_input, + routed_scaling_factor, + **kwargs, + ) -> torch.Tensor: + from sglang.srt.layers.moe.topk import select_experts + + global_num_experts = router_logits.shape[-1] + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if global_num_experts == 256: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, + bias=correction_bias, + 
k_group=topk_group, + group_count=num_expert_group, + group_select_mode=1, + renorm=0, + norm_type=1, + routed_scaling_factor=1, + eps=float(1e-20), + ) + else: + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + torch_native=True, + routed_scaling_factor=routed_scaling_factor, + ) + topk_ids = topk_ids.to(torch.int32) + topk_weights = topk_weights.to(x.dtype) + return npu_fused_experts( + hidden_states=x, + w13=layer.w13_weight, + w13_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + ) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index b8d280dfa..cc2d8e20c 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -34,16 +34,18 @@ import torch import torch.distributed as dist import triton import triton.language as tl -from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2 +from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2 logger = logging.getLogger(__name__) GB = 1024 * 1024 * 1024 _is_cuda = is_cuda() +_is_npu = is_npu() +if not _is_npu: + from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla class ReqToTokenPool: diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 5b267caf2..733e6df9e 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -64,10 +64,13 @@ from sglang.srt.model_loader.weight_utils import ( from sglang.srt.utils import ( get_bool_env_var, get_device_capability, + is_npu, is_pin_memory_available, set_weight_attrs, ) +_is_npu = is_npu() + @contextmanager def device_loading_context(module: torch.nn.Module, target_device: torch.device): @@ -127,18 +130,19 @@ def _get_quantization_config( # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3 if quant_config is None: return None - major, minor = get_device_capability() + if not _is_npu: + major, minor = get_device_capability() - if major is not None and minor is not None: - assert 0 <= minor < 10 - capability = major * 10 + minor - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} " - "is not supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}." - ) + if major is not None and minor is not None: + assert 0 <= minor < 10 + capability = major * 10 + minor + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}." 
+ ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( @@ -157,6 +161,13 @@ def _initialize_model( """Initialize a model with the given configurations.""" model_class, _ = get_model_architecture(model_config) packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {}) + if _is_npu: + packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] + packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"] quant_config = _get_quantization_config( model_config, load_config, packed_modules_mapping ) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 24a16bf21..f8cfe859b 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -575,6 +575,8 @@ class LlamaForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index cfb23cdf3..5b84c90dd 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -407,6 +407,8 @@ class QuantMixtralForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -418,6 +420,8 @@ class QuantMixtralForCausalLM(nn.Module): # Skip experts that are not assigned to this worker. if "block_sparse_moe.experts." in name and name not in params_dict: continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 987204d83..e3670bb55 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -538,6 +538,8 @@ class Qwen2ForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 1c062c788..ce159a4da 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -197,7 +197,7 @@ def get_int_env_var(name: str, default: int = 0) -> int: def support_triton(backend: str) -> bool: - return backend not in ["torch_native", "intel_amx"] + return backend not in ["torch_native", "intel_amx", "ascend"] try: @@ -2782,3 +2782,101 @@ def lru_cache_frozenset(maxsize=128): return wrapper return decorator + + +def apply_module_patch(target_module, target_function, wrappers): + original_module, original_function = parse_module_path( + target_module, target_function, False + ) + + original_function_id = id(original_function) + + candidate = original_function + for wrapper in wrappers: + candidate = wrapper(candidate) + if target_function is not None: + setattr(original_module, target_function, candidate) + + for key, value in sys.modules.copy().items(): + if ( + target_function is not None + and hasattr(value, target_function) + and id(getattr(value, target_function)) == original_function_id + ): + setattr(value, target_function, candidate) + + +def parse_module_path(module_path, function_name, create_dummy): + from importlib.machinery import ModuleSpec + + def create_dummy_module(full_path, parent=None): + """Create and register a placeholder module""" + dummy = types.ModuleType(full_path) + dummy.__file__ = "vllm_ascend.dummy_module.py" + dummy.__spec__ = ModuleSpec(full_path, None) + sys.modules[full_path] = dummy + if parent: + setattr(parent, full_path.split(".")[-1], dummy) + return dummy + + def create_placeholder_function(func_name): + """Create dummy function that raises when called""" + + def placeholder(*args, **kwargs): + raise NotImplementedError(f"Function {func_name} is a placeholder") + + placeholder.__name__ = func_name + return placeholder + + modules = module_path.split(".") + current_module = None + processed_path = [] + + for idx, part in enumerate(modules): + current_path = ".".join(modules[: idx + 1]) + parent_path = ".".join(modules[:idx]) if idx > 0 else None + + try: + current_module = importlib.import_module(current_path) + except ModuleNotFoundError: + # Handle missing module + parent = importlib.import_module(parent_path) if parent_path else None + if parent and hasattr(parent, part): + # Use existing attribute from parent + current_module = getattr(parent, part) + # Check for early function resolution + if function_name and hasattr(current_module, function_name): + return current_module, getattr(current_module, function_name) + if function_name and create_dummy: + ph_func = create_placeholder_function(function_name) + setattr(current_module, function_name, ph_func) + return current_module, ph_func + if function_name: + raise AttributeError( + f"Function {function_name} missing in {current_path}" + ) + else: + if not create_dummy: + raise + # Create and register dummy module + current_module = create_dummy_module( + current_path, + parent=( + importlib.import_module(parent_path) if parent_path else None + ), + ) + + processed_path.append(part) + + # Final function handling + final_module = sys.modules[module_path] + if function_name is not None: + if not hasattr(final_module, function_name): + if create_dummy: + ph_func = 
create_placeholder_function(function_name)
+                setattr(final_module, function_name, ph_func)
+            else:
+                setattr(final_module, function_name, None)
+        return final_module, getattr(final_module, function_name)
+
+    return final_module, None