diff --git a/docs/source/installation.md b/docs/source/installation.md index 93c13a6..5dc99fd 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -61,6 +61,7 @@ docker run --rm \ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ -v /etc/ascend_install.info:/etc/ascend_install.info \ + -v /root/.cache:/root/.cache \ -it $IMAGE bash ``` @@ -123,7 +124,7 @@ First install system dependencies: ```bash apt update -y -apt install -y gcc g++ cmake libnuma-dev +apt install -y gcc g++ cmake libnuma-dev wget ``` Current version depends on a unreleased `torch-npu`, you need to install manually: @@ -144,6 +145,7 @@ cd pta wget https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/Daily/v2.5.1/20250320.3/pytorch_v2.5.1_py310.tar.gz tar -xvf pytorch_v2.5.1_py310.tar.gz pip install ./torch_npu-2.5.1.dev20250320-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl +cd .. ``` Then you can install `vllm` and `vllm-ascend` from **pre-built wheel**: @@ -152,6 +154,8 @@ Then you can install `vllm` and `vllm-ascend` from **pre-built wheel**: :substitutions: # Install vllm-project/vllm from pypi +# There was a vLLM v0.8.4 installation bug, please use "Build from source code" +# https://github.com/vllm-project/vllm-ascend/issues/581 pip install vllm==|pip_vllm_version| # Install vllm-project/vllm-ascend from pypi. @@ -168,11 +172,13 @@ or build from **source code**: git clone --depth 1 --branch |vllm_version| https://github.com/vllm-project/vllm cd vllm VLLM_TARGET_DEVICE=empty pip install . --extra-index https://download.pytorch.org/whl/cpu/ +cd .. # Install vLLM Ascend git clone --depth 1 --branch |vllm_ascend_version| https://github.com/vllm-project/vllm-ascend.git cd vllm-ascend pip install -e . --extra-index https://download.pytorch.org/whl/cpu/ +cd .. 
``` ::: @@ -216,6 +222,7 @@ docker run --rm \ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ -v /etc/ascend_install.info:/etc/ascend_install.info \ + -v /root/.cache:/root/.cache \ -it $IMAGE bash ``` diff --git a/tests/singlecard/test_offline_inference.py b/tests/singlecard/test_offline_inference.py index efa205d..9249c33 100644 --- a/tests/singlecard/test_offline_inference.py +++ b/tests/singlecard/test_offline_inference.py @@ -30,7 +30,9 @@ from tests.conftest import VllmRunner MODELS = [ "Qwen/Qwen2.5-0.5B-Instruct", + "vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8", ] +os.environ["VLLM_USE_MODELSCOPE"] = "True" os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" diff --git a/vllm_ascend/quantization/func_wrapper.py b/vllm_ascend/quantization/func_wrapper.py new file mode 100644 index 0000000..77ecca2 --- /dev/null +++ b/vllm_ascend/quantization/func_wrapper.py @@ -0,0 +1,151 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# func refers to RMSNorm.__init__
def wrapper_rmsnorm_init(func):
    """Build a replacement for ``RMSNorm.__init__`` that adds extra state.

    Args:
        func: The original ``RMSNorm.__init__``.

    Returns:
        An ``__init__`` that first calls ``func`` and then sets
        ``ignore_anti`` to ``True`` and attaches a frozen, zero-initialised
        ``bias`` parameter with ``hidden_size`` elements.
    """

    def init(self, hidden_size: int, *args, **extra_args) -> None:
        # Forward positional extras as well, so call sites that pass further
        # constructor arguments positionally keep working once patched.
        func(self, hidden_size, *args, **extra_args)
        # Defaults to the unquantized path; wrapper_load_model flips this flag
        # for norms that feed a quantized projection.
        self.ignore_anti = True
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)

    return init


# func refers to RMSNorm.forward_oot
def wrapper_rmsnorm_forward_oot(func):
    """Build a replacement for ``RMSNorm.forward_oot``.

    Args:
        func: The original ``RMSNorm.forward_oot``.

    Returns:
        A forward function. When ``self.ignore_anti`` is false it calls
        ``torch_npu._npu_quant_rms_norm`` with the layer's ``weight``,
        ``bias``, ``input_scale``, ``input_offset`` and ``variance_epsilon``;
        otherwise it calls ``func`` and adds ``self.bias`` in place.
    """

    def _rmsnorm_forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if not self.ignore_anti:
            norm_input = x
            if residual is not None:
                # The residual is updated in place and then normalised.
                residual += x
                norm_input = residual
            out = torch_npu._npu_quant_rms_norm(
                norm_input,
                self.weight,
                self.bias,
                self.input_scale,
                self.input_offset,
                self.variance_epsilon,
            )
            if residual is not None:
                return out, residual
            return out

        if residual is not None:
            x, residual = func(self, x, residual)
            return x.add_(self.bias), residual

        return func(self, x).add_(self.bias)

    return _rmsnorm_forward_oot


# Per model class name: which sub-module holds the projection ("layer_attr" /
# "proj_attr"), which norm precedes it ("norm_attr"), and which quant-method
# type marks the projection as unquantized ("unquantized_type").
MODEL_LAYER_MAPPING = {
    "LlamaModel": {
        "attn": {
            "layer_attr": "self_attn",
            "proj_attr": "qkv_proj",
            "norm_attr": "input_layernorm",
            "unquantized_type": UnquantizedLinearMethod,
        },
        "mlp": {
            "layer_attr": "mlp",
            "proj_attr": "gate_up_proj",
            "norm_attr": "post_attention_layernorm",
            "unquantized_type": UnquantizedLinearMethod,
        },
    },
}


def wrapper_load_model(func):
    """Build a replacement for a model runner's ``load_model``.

    Args:
        func: The original ``load_model`` method.

    Returns:
        A method that calls ``func`` and then, for every decoder layer of a
        model listed in ``MODEL_LAYER_MAPPING``, sets each norm's
        ``ignore_anti`` flag from the quant method of the projection that
        follows it and copies that projection's ``input_scale`` /
        ``input_offset`` onto the norm.
    """

    def postprocess_loading(self) -> None:
        func(self)

        def process_layer(layer, idx, mapping):

            def process_module(module_cfg, layer_obj):
                if module_cfg is None:
                    return

                module_obj = getattr(layer_obj, module_cfg["layer_attr"],
                                     None)
                if module_obj is None:
                    return

                # "proj_attr" is either an attribute name or a callable that
                # picks the projection from the module and the layer index.
                proj_attr = module_cfg["proj_attr"]
                if callable(proj_attr):
                    proj = proj_attr(module_obj, idx)
                else:
                    proj = getattr(module_obj, proj_attr, None)

                norm = getattr(layer_obj, module_cfg["norm_attr"], None)

                if proj is None or norm is None:
                    return

                norm.ignore_anti = isinstance(proj.quant_method,
                                              module_cfg["unquantized_type"])
                if norm.ignore_anti:
                    return
                for param_name in ("input_scale", "input_offset"):
                    if hasattr(proj, param_name):
                        param = getattr(proj, param_name)
                        norm.register_parameter(
                            param_name,
                            torch.nn.Parameter(param.clone(),
                                               requires_grad=False))

            process_module(mapping.get("attn"), layer)
            process_module(mapping.get("mlp"), layer)

        inner_model = self.model.model
        model_type = inner_model.__class__.__name__
        mapping = MODEL_LAYER_MAPPING.get(model_type)

        if not mapping:
            # Emitted at WARNING level so it is not lost among info logs.
            logger.warning(
                "Model type '%s' not found in MODEL_LAYER_MAPPING. "
                "Skipping layer mapping.", model_type)
            return

        for idx, layer in enumerate(inner_model.layers):
            process_layer(layer, idx, mapping)

        # The final norm has no following projection in the mapping, so it
        # always stays on the unquantized path. Models without a ``norm``
        # attribute are tolerated.
        final_norm = getattr(inner_model, "norm", None)
        if isinstance(final_norm, RMSNorm):
            final_norm.ignore_anti = True

    return postprocess_loading
class VLLMAscendQuantizer:
    """Built-in quantizer used when ``mindie_turbo`` cannot be imported.

    Concrete subclasses are looked up by quant type in
    ``SUPPORT_ASCEND_QUANTIZER_TYPE`` and instantiated once per subclass.
    """
    # Instance cache; get_quantizer stores the instance on the concrete
    # subclass, so every subclass ends up with its own entry.
    _instance: Optional[object] = None
    # Set once the (optional) RMSNorm / load_model patches have been handled.
    patched = False

    def __init__(self, quant_description):
        """Install the RMSNorm and load_model patches once per process.

        Args:
            quant_description: Mapping of weight names to quant types. The
                patches are applied only if some key contains ``norm.bias``.
        """
        if VLLMAscendQuantizer.patched:
            return
        if any("norm.bias" in name for name in quant_description):
            VLLMAscendQuantizer.apply_patch(
                "vllm.model_executor.layers.layernorm.RMSNorm", "__init__",
                [wrapper_rmsnorm_init])
            VLLMAscendQuantizer.apply_patch(
                "vllm.model_executor.layers.layernorm.RMSNorm",
                "forward_oot", [wrapper_rmsnorm_forward_oot])
            VLLMAscendQuantizer.apply_patch(
                "vllm_ascend.worker.model_runner.NPUModelRunnerBase",
                "load_model", [wrapper_load_model])
        VLLMAscendQuantizer.patched = True
        logger.info("Using the vLLM Ascend Quantizer version now!")

    @staticmethod
    def apply_patch(target_module, target_function, wrappers):
        """Replace ``target_function`` on ``target_module`` with a wrapped one.

        Args:
            target_module: Dot-separated path of a module, or of an attribute
                of a module such as a class.
            target_function: Name of the attribute to replace; ``None`` makes
                this a no-op apart from resolving the path.
            wrappers: Callables applied in order; each receives the current
                function and returns its replacement.

        Every object in ``sys.modules`` that exposes the original function
        under the same name is rebound as well, so ``from x import f`` aliases
        see the patch too.
        """
        original_module, original_function = VLLMAscendQuantizer.parse_path(
            target_module, target_function, False)

        candidate = original_function
        for wrapper in wrappers:
            candidate = wrapper(candidate)
        if target_function is None:
            return
        setattr(original_module, target_function, candidate)

        missing = object()
        # Iterate over a snapshot: attribute access may trigger imports that
        # mutate sys.modules.
        for module in list(sys.modules.values()):
            if getattr(module, target_function,
                       missing) is original_function:
                setattr(module, target_function, candidate)

    @staticmethod
    def parse_path(module_path, function_name, create_dummy):
        """
        Parse module path and resolve/create modules as needed.

        Args:
            module_path: Dot-separated module path
            function_name: Target function name (None for module only)
            create_dummy: Create dummy modules/functions when missing

        Returns:
            Tuple of (resolved module, target function/none)

        Raises:
            ModuleNotFoundError: If module path is invalid and create_dummy=False
            AttributeError: If function is missing and create_dummy=False
        """
        from importlib.machinery import ModuleSpec

        def create_dummy_module(full_path, parent=None):
            """Create and register a placeholder module"""
            dummy = types.ModuleType(full_path)
            dummy.__file__ = "vllm_ascend.dummy_module.py"
            dummy.__spec__ = ModuleSpec(full_path, None)
            sys.modules[full_path] = dummy
            if parent:
                setattr(parent, full_path.split(".")[-1], dummy)
            return dummy

        def create_placeholder_function(func_name):
            """Create dummy function that raises when called"""

            def placeholder(*args, **kwargs):
                raise NotImplementedError(
                    f"Function {func_name} is a placeholder")

            placeholder.__name__ = func_name
            return placeholder

        modules = module_path.split(".")
        current_module = None

        for idx, part in enumerate(modules):
            current_path = ".".join(modules[:idx + 1])
            parent_path = ".".join(modules[:idx]) if idx > 0 else None

            try:
                current_module = importlib.import_module(current_path)
            except ModuleNotFoundError:
                # Not importable: it may be an attribute (e.g. a class) of the
                # parent module, or missing altogether.
                parent = importlib.import_module(
                    parent_path) if parent_path else None
                if parent and hasattr(parent, part):
                    current_module = getattr(parent, part)
                    # Check for early function resolution
                    if function_name and hasattr(current_module,
                                                 function_name):
                        return current_module, getattr(current_module,
                                                       function_name)
                    if function_name and create_dummy:
                        ph_func = create_placeholder_function(function_name)
                        setattr(current_module, function_name, ph_func)
                        return current_module, ph_func
                    if function_name:
                        raise AttributeError(
                            f"Function {function_name} missing in {current_path}"
                        )
                else:
                    if not create_dummy:
                        raise
                    # Create and register dummy module
                    current_module = create_dummy_module(current_path,
                                                         parent=parent)

        # Use the object resolved by the loop rather than
        # sys.modules[module_path]: a path that ends in a module attribute
        # (such as a class) is never registered in sys.modules.
        final_module = current_module
        if function_name is None:
            return final_module, None

        if not hasattr(final_module, function_name):
            if not create_dummy:
                # Fail loudly, as documented, instead of planting ``None`` on
                # the module and handing ``None`` to the wrappers.
                raise AttributeError(
                    f"Function {function_name} missing in {module_path}")
            setattr(final_module, function_name,
                    create_placeholder_function(function_name))
        return final_module, getattr(final_module, function_name)

    @staticmethod
    def build_linear_method():
        """Return the linear method; subclasses override this."""
        raise NotImplementedError(
            "Linear method is not implemented for the current quant type.")

    @staticmethod
    def build_moe_method():
        """Return the fused-MoE method; subclasses override this."""
        raise NotImplementedError(
            "MoE method is not implemented for the current quant type.")

    @staticmethod
    def build_attention_method():
        """Return the attention method; subclasses override this."""
        raise NotImplementedError(
            "Attention method is not implemented for the current quant type.")

    @staticmethod
    def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
                              packed_modules_mapping: Dict[str, Any]):
        """Look up the quant type recorded for the layer named ``prefix``.

        Args:
            quant_description: Mapping of ``"<layer>.weight"`` to quant type.
            prefix: Dot-separated layer name.
            packed_modules_mapping: Maps a packed projection name (the last
                component of ``prefix``) to the names of its shards.

        Returns:
            The quant type of ``prefix``, or the common quant type of its
            shards when the last component is a packed projection.

        Raises:
            KeyError: If a required ``.weight`` entry is missing.
            ValueError: If the shards of a packed projection disagree.
        """
        proj_name = prefix.split(".")[-1]
        if proj_name not in packed_modules_mapping:
            return quant_description[prefix + '.weight']

        # Swap only the trailing component; str.replace would also rewrite an
        # earlier component that happens to contain the projection name.
        stem = prefix[:len(prefix) - len(proj_name)]
        quant_type = None
        for shard_proj_name in packed_modules_mapping[proj_name]:
            shard_prefix = stem + shard_proj_name
            shard_quant_type = quant_description[shard_prefix + '.weight']

            if quant_type is None:
                quant_type = shard_quant_type
            elif shard_quant_type != quant_type:
                raise ValueError(
                    f"Not all shards of {prefix} are quantized with same "
                    f"quant type. Shard {shard_prefix} uses "
                    f"{shard_quant_type}, but another shard uses "
                    f"{quant_type}. Please check quantization config.")
        return quant_type

    @classmethod
    def get_quantizer(cls,
                      quant_description: Dict[str, Any],
                      prefix: str,
                      packed_modules_mapping: Optional[Dict[str, Any]] = None):
        """Return the cached quantizer instance matching ``prefix``.

        Args:
            quant_description: Mapping of weight names to quant types; may
                also carry ``fa_quant_type`` for attention layers.
            prefix: Dot-separated layer name.
            packed_modules_mapping: Optional packed-projection mapping.

        Returns:
            The instance of the matching class in
            ``SUPPORT_ASCEND_QUANTIZER_TYPE``, created on first use.

        Raises:
            NotImplementedError: If the quant type is not supported.
        """
        if packed_modules_mapping is None:
            packed_modules_mapping = {}
        # Attention
        if '.attn' in prefix and 'fa_quant_type' in quant_description:
            quant_type = quant_description['fa_quant_type']
        # Linear
        else:
            quant_type = cls.get_linear_quant_type(quant_description, prefix,
                                                   packed_modules_mapping)

        quantizer_cls = SUPPORT_ASCEND_QUANTIZER_TYPE.get(quant_type)
        if quantizer_cls is None:
            raise NotImplementedError(
                "Currently, vLLM Ascend only supports following quant types: "
                f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}")
        if quantizer_cls._instance is None:
            quantizer_cls._instance = quantizer_cls(quant_description)
        return quantizer_cls._instance


class W8A8Quantizer(VLLMAscendQuantizer):
    """Quantizer registered for the ``W8A8`` quant type."""

    @staticmethod
    def build_linear_method():
        return AscendW8A8LinearMethod()


class W8A8DYNAMICQuantizer(VLLMAscendQuantizer):
    """Quantizer registered for the ``W8A8_DYNAMIC`` quant type."""

    @staticmethod
    def build_linear_method():
        return AscendW8A8DynamicLinearMethod()

    @staticmethod
    def build_moe_method():
        return AscendW8A8DynamicFusedMoEMethod()


# Quant type string (as found in the quant description) -> quantizer class.
SUPPORT_ASCEND_QUANTIZER_TYPE = {
    "W8A8": W8A8Quantizer,
    "W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
}
def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor,
                     input_offset: torch.Tensor):
    """Quantize ``in_tensor`` into a new int8 tensor of the same shape.

    The output buffer is allocated here and filled by
    ``torch_npu._npu_quantize_per_tensor`` using ``input_scale`` and
    ``input_offset``.
    """
    quantized = torch.empty_like(in_tensor, dtype=torch.int8)
    torch_npu._npu_quantize_per_tensor(in_tensor, input_scale, input_offset,
                                       quantized)
    return quantized


class AscendW8A8LinearMethod:
    """Linear method for Ascend W8A8.

    The ``get_*`` helpers return dictionaries of uninitialised tensors for
    the quantized weight and its quantization parameters; ``apply`` runs the
    quantized matmul on a layer that carries those parameters.
    """

    def __init__(self) -> None:
        # aclnn quant matmul requires to transpose matrix B, set to true by default.
        self.transpose_weight = True

    @staticmethod
    def get_weight(
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        """Return the int8 weight placeholder, shaped (output, input)."""
        weight = torch.empty(output_size, input_size, dtype=torch.int8)
        return {"weight": weight}

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return one-element placeholders for the input scale and offset."""
        return {
            "input_scale": torch.empty(1, dtype=params_dtype),
            "input_offset": torch.empty(1, dtype=torch.int8),
        }

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Return per-output-channel placeholders.

        ``deq_scale`` is created only for bfloat16 (as float32) and float16
        (as int64) ``params_dtype``; other dtypes get no ``deq_scale`` entry.
        """
        deq_scale_dtypes = {
            torch.bfloat16: torch.float32,
            torch.float16: torch.int64,
        }

        params: Dict[str, Any] = {
            "quant_bias": torch.empty(output_size, dtype=torch.int32),
        }
        deq_dtype = deq_scale_dtypes.get(params_dtype)
        if deq_dtype is not None:
            params["deq_scale"] = torch.empty(output_size, dtype=deq_dtype)
        for key in ("weight_scale", "weight_offset"):
            params[key] = torch.empty(output_size, 1, dtype=params_dtype)
        return params

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        """Run the quantized matmul of ``x`` against ``layer.weight``.

        ``x`` is quantized per tensor first unless it is already int8. The
        layer's ``quant_bias`` is passed only when ``tp_rank`` is 0; the
        ``bias`` argument is not used. The result has the dtype ``x`` had on
        entry.
        """
        out_dtype = x.dtype
        if out_dtype == torch.int8:
            quantized = x
        else:
            quantized = quant_per_tensor(x, layer.input_scale,
                                         layer.input_offset)
        rank_bias = None
        if tp_rank == 0:
            rank_bias = layer.quant_bias
        return torch_npu.npu_quant_matmul(
            quantized,
            layer.weight,
            layer.deq_scale,
            bias=rank_bias,
            output_dtype=out_dtype,
        )

    def process_weights_after_loading(self, layer):
        """Transpose the weight (if enabled) and flatten scale and offset."""
        if self.transpose_weight:
            transposed = layer.weight.data.transpose(0, 1)
            layer.weight.data = transposed.contiguous()
        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_offset.data = layer.weight_offset.data.flatten()
def fused_experts(hidden_states: torch.Tensor,
                  w1: torch.Tensor,
                  w1_scale: torch.Tensor,
                  w2: torch.Tensor,
                  w2_scale: torch.Tensor,
                  topk_weights: torch.Tensor,
                  topk_ids: torch.Tensor,
                  top_k: int,
                  expert_map: Optional[torch.Tensor] = None):
    """Run the dynamically quantized expert MLPs and combine their outputs.

    Args:
        hidden_states: 2-D input, or 3-D input that is flattened over its
            leading dimensions and restored on return.
        w1: Gate/up expert weights; ``w1.shape[0]`` is the number of local
            experts.
        w1_scale: Scales for ``w1``; its dtype selects the matmul output
            dtype (bfloat16 if bfloat16, float16 otherwise).
        w2: Down-projection expert weights.
        w2_scale: Scales for ``w2``.
        topk_weights: Routing weights of the selected experts per token.
        topk_ids: Ids of the selected experts per token.
        top_k: Number of experts selected per token.
        expert_map: Optional lookup from global expert id to local expert id,
            with ``-1`` for experts that are not local. When given, routing
            is done with plain torch ops instead of the NPU routing kernels.

    Returns:
        A tensor with the same shape as the input ``hidden_states``.
    """
    original_shape = hidden_states.shape
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    num_tokens, hidden_size = hidden_states.shape
    num_experts = w1.shape[0]
    dtype = hidden_states.dtype
    device = hidden_states.device

    if expert_map is not None:
        # Generate token indices and flatten
        token_indices = (torch.arange(num_tokens,
                                      device=device,
                                      dtype=torch.int64).unsqueeze(1).expand(
                                          -1, top_k).reshape(-1))

        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.view(-1)
        experts_flat = topk_ids.view(-1)
        local_experts_flat = expert_map[experts_flat]

        # Filter valid token-expert pairs: non-local pairs get weight 0 and
        # the out-of-range expert id ``num_experts`` so they sort last.
        mask = local_experts_flat != -1
        filtered_weights = torch.where(
            mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
        filtered_experts = torch.where(
            mask, local_experts_flat,
            torch.full_like(local_experts_flat,
                            num_experts)).to(topk_ids.dtype)

        # Sort by local expert IDs
        sort_indices = torch.argsort(filtered_experts)
        sorted_token_indices = token_indices[sort_indices]
        sorted_weights = filtered_weights[sort_indices]

        # Compute token counts with minlength of num_experts
        # This is equivalent to but faster than:
        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
        token_counts = torch.zeros(num_experts + 1,
                                   device=device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
        token_counts = token_counts[:num_experts]
        expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)

        # Rearrange hidden_states
        sorted_hidden_states = hidden_states[sorted_token_indices]
    else:
        row_idx_len = num_tokens * top_k
        row_idx = torch.arange(0,
                               row_idx_len,
                               dtype=torch.int32,
                               device=topk_weights.device).view(
                                   top_k, -1).permute(1, 0).contiguous()
        sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
            hidden_states,
            row_idx=row_idx,
            expert_idx=topk_ids,
            active_num=num_tokens)
        del hidden_states

        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
            expanded_expert_idx, num_experts)
        expert_tokens = expert_tokens.to(torch.int64)

    quant_x, x_dynamic_scale = torch_npu.npu_dynamic_quant(
        sorted_hidden_states)
    del sorted_hidden_states
    output_dtype = torch.bfloat16 if w1_scale.dtype == torch.bfloat16 else torch.float16

    gate_up_out_list = torch_npu.npu_grouped_matmul(
        x=[quant_x],
        weight=[w1],
        scale=[w1_scale],
        per_token_scale=[x_dynamic_scale],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=output_dtype)
    del quant_x

    gate_up_out_list = gate_up_out_list[0] if len(
        gate_up_out_list) == 1 else torch.cat(gate_up_out_list, dim=0)
    gate_up_out_list = torch_npu.npu_swiglu(gate_up_out_list)

    quant_gate_up_out_list, gate_up_out_dynamic_scale = torch_npu.npu_dynamic_quant(
        gate_up_out_list)
    del gate_up_out_list

    down_out_list = torch_npu.npu_grouped_matmul(
        x=[quant_gate_up_out_list],
        weight=[w2],
        scale=[w2_scale],
        per_token_scale=[gate_up_out_dynamic_scale],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=output_dtype)
    del quant_gate_up_out_list

    down_out_list = down_out_list[0] if len(down_out_list) == 1 else torch.cat(
        down_out_list, dim=0)

    if expert_map is not None:
        weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)

        # Accumulate into a flattened (num_tokens, hidden) buffer:
        # sorted_token_indices index flattened tokens and the source is 2-D,
        # so a buffer with the original 3-D shape would not match. The
        # original shape is restored below.
        final_hidden_states = torch.zeros(num_tokens,
                                          hidden_size,
                                          device=device,
                                          dtype=dtype)
        final_hidden_states.index_add_(0, sorted_token_indices,
                                       weighted_down_out)
        # TODO: This should not happen! Look into it!
        # fill nan with 0.0
        final_hidden_states[torch.isnan(final_hidden_states)] = 0.0
    else:
        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            down_out_list,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights,
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids)
    del down_out_list
    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)
    return final_hidden_states


class AscendW8A8DynamicLinearMethod:
    """Linear method for Ascend W8A8_DYNAMIC.

    Activations are quantized at call time in ``apply``, so there are no
    per-tensor parameters.
    """

    def __init__(self):
        # process_weights_after_loading transposes the weight when set.
        self.transpose_weight = True

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the int8 weight placeholder, shaped (output, input)."""
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return an empty dict: this method has no per-tensor parameters."""
        return {}

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Return (output_size, 1) placeholders for weight scale and offset."""
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        return params_dict

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        """Dynamically quantize ``x`` and run the quantized matmul.

        The result has the dtype of ``x``; ``tp_rank`` is not used.
        """
        original_dtype = x.dtype
        # use ATB quantize
        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
        return torch_npu.npu_quant_matmul(
            quant_out,
            layer.weight,
            layer.weight_scale,
            pertoken_scale=dynamic_scale,
            bias=bias,
            output_dtype=original_dtype,
        )

    def process_weights_after_loading(self, layer):
        """Transpose the weight (if enabled) and flatten scale and offset."""
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_offset.data = layer.weight_offset.data.flatten()


class AscendW8A8DynamicFusedMoEMethod:
    """FusedMoe method for Ascend W8A8_DYNAMIC.
    """

    def __init__(self):
        # process_weights_after_loading transposes the expert weights when set.
        self.transpose_weight = True

    @staticmethod
    def get_weight(num_experts: int, intermediate_size_per_partition: int,
                   hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return int8 placeholders for the w13 and w2 expert weights."""
        param_dict = {}
        param_dict["w13_weight"] = torch.empty(num_experts,
                                               2 *
                                               intermediate_size_per_partition,
                                               hidden_sizes,
                                               dtype=torch.int8)
        param_dict["w2_weight"] = torch.empty(num_experts,
                                              hidden_sizes,
                                              intermediate_size_per_partition,
                                              dtype=torch.int8)
        return param_dict

    @staticmethod
    def get_dynamic_quant_param(num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return per-expert scale and offset placeholders for w13 and w2."""
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=params_dtype)
        param_dict["w13_weight_offset"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=params_dtype)
        param_dict["w2_weight_scale"] = torch.empty(num_experts,
                                                    hidden_sizes,
                                                    1,
                                                    dtype=params_dtype)
        param_dict["w2_weight_offset"] = torch.empty(num_experts,
                                                     hidden_sizes,
                                                     1,
                                                     dtype=params_dtype)
        return param_dict

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run them through ``fused_experts``.

        ``router_logits.shape[1]`` must equal ``global_num_experts``. Extra
        keyword arguments are accepted and ignored.
        """
        assert router_logits.shape[
            1] == global_num_experts, "Number of global experts mismatch"

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        return fused_experts(hidden_states=x,
                             w1=layer.w13_weight,
                             w1_scale=layer.w13_weight_scale,
                             w2=layer.w2_weight,
                             w2_scale=layer.w2_weight_scale,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             top_k=top_k,
                             expert_map=expert_map)

    def process_weights_after_loading(self, layer):
        """Transpose expert weights (if enabled) and drop the trailing
        singleton dimension of every scale and offset."""
        if self.transpose_weight:
            layer.w13_weight.data = layer.w13_weight.data.transpose(
                1, 2).contiguous()
            layer.w2_weight.data = layer.w2_weight.data.transpose(
                1, 2).contiguous()
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
            layer.w13_weight_scale.data.shape[0], -1)
        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
            layer.w2_weight_scale.data.shape[0], -1)
        layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
            layer.w2_weight_offset.data.shape[0], -1)