v0.10.1rc1

This commit is contained in:
2025-09-09 09:40:35 +08:00
parent d6f6ef41fe
commit 9149384e03
432 changed files with 84698 additions and 1 deletions

View File

View File

@@ -0,0 +1,184 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Tuple, Union
import torch
import torch_npu
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
# func refers to vocabParallelEmbedding.__init__
def wrapper_vocab_parallel_embedding_init(func):
def init(
self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
func(
self,
num_embeddings,
embedding_dim,
params_dtype,
org_num_embeddings,
padding_size,
quant_config,
prefix,
)
# TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class.
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
return init
# func refers to RMSNorm.__init__
def wrapper_rmsnorm_init(func):
def init(self, hidden_size: int, **extra_args) -> None:
func(self, hidden_size, **extra_args)
self.ignore_anti = True
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
return init
# func refers to RMSNorm.forward_oot
def wrapper_rmsnorm_forward_oot(func):
def _rmsnorm_forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not self.ignore_anti:
if residual is not None:
residual += x
out = torch_npu._npu_quant_rms_norm(
residual,
self.weight,
self.bias,
self.input_scale,
self.input_offset,
self.variance_epsilon,
)
return out, residual
out = torch_npu._npu_quant_rms_norm(
x,
self.weight,
self.bias,
self.input_scale,
self.input_offset,
self.variance_epsilon,
)
return out
if residual is not None:
x, residual = func(self, x, residual)
return x.add_(self.bias), residual
return func(self, x).add_(self.bias)
return _rmsnorm_forward_oot
MODEL_LAYER_MAPPING = {
"LlamaModel": {
"attn": {
"layer_attr": "self_attn",
"proj_attr": "qkv_proj",
"norm_attr": "input_layernorm",
"unquantized_type": UnquantizedLinearMethod,
},
"mlp": {
"layer_attr": "mlp",
"proj_attr": "gate_up_proj",
"norm_attr": "post_attention_layernorm",
"unquantized_type": UnquantizedLinearMethod,
},
},
}
def wrapper_load_model(func):
def postprocess_loading(self) -> None:
func(self)
def process_layer(layer, idx, mapping):
def process_module(module_cfg, layer_obj):
if module_cfg is None:
return
module_obj = getattr(layer_obj, module_cfg["layer_attr"], None)
if module_obj is None:
return
proj_attr = module_cfg["proj_attr"]
if callable(proj_attr):
proj = proj_attr(module_obj, idx)
else:
proj = getattr(module_obj, proj_attr, None)
norm = getattr(layer_obj, module_cfg["norm_attr"], None)
if proj is None or norm is None:
return
norm.ignore_anti = isinstance(proj.quant_method,
module_cfg["unquantized_type"])
if not norm.ignore_anti:
for param_name in ["input_scale", "input_offset"]:
if hasattr(proj, param_name):
param = getattr(proj, param_name)
norm.register_parameter(
param_name,
torch.nn.Parameter(param.clone(),
requires_grad=False))
process_module(mapping.get("attn"), layer)
process_module(mapping.get("mlp"), layer)
model_type = self.model.model.__class__.__name__
mapping = MODEL_LAYER_MAPPING.get(model_type)
if not mapping:
logger.info(
f"Warning: Model type '{model_type}' not found in MODEL_LAYER_MAPPING. Skipping layer mapping."
)
return
for idx, layer in enumerate(self.model.model.layers):
process_layer(layer, idx, mapping)
if isinstance(self.model.model.norm, RMSNorm):
self.model.model.norm.ignore_anti = True
return postprocess_loading

View File

@@ -0,0 +1,357 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional
import torch
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import \
register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
from .quantizer import AscendQuantizer
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
class AscendQuantConfig(QuantizationConfig):
"""Config class for Ascend
This class is a general class that parse quantization configs
that are supported on ascend hardware.
"""
def __init__(self, quant_config: Dict[str, Any]):
self.quant_description = quant_config
def __repr__(self) -> str:
return "AscendQuantConfig:\n" + super().__repr__()
@classmethod
def get_name(cls) -> str:
return ASCEND_QUANTIZATION_METHOD
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.int8, torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"Ascend hardware dose not support \"get_min_capability\" feature.")
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quant_model_description.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
return cls(config)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
if torch.npu.is_available():
return ASCEND_QUANTIZATION_METHOD
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention
if isinstance(layer, LinearBase):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedLinearMethod()
return AscendLinearMethod(self, prefix,
self.packed_modules_mapping)
elif isinstance(layer, Attention) and \
'fa_quant_type' in self.quant_description.keys() and \
self.quant_description['fa_quant_type'] is not None:
return AscendKVCacheMethod(self, prefix)
elif isinstance(layer, Attention) and self.quant_description.get(
'kv_quant_type') == 'C8':
return AscendKVCacheMethod(self, prefix)
elif isinstance(layer, FusedMoE):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
return AscendFusedMoEMethod(self, prefix,
self.packed_modules_mapping)
elif isinstance(layer, VocabParallelEmbedding):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedEmbeddingMethod()
return AscendEmbeddingMethod(self, prefix,
self.packed_modules_mapping)
return None
def is_layer_skipped_ascend(
self,
prefix: str,
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})):
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
proj_name = prefix.split(".")[-1]
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = self.quant_description[shard_prefix +
'.weight'] == "FLOAT"
if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision.")
else:
is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"
assert is_skipped is not None
return is_skipped
def get_scaled_act_names(self) -> List[str]:
return []
class AscendLinearMethod(LinearMethodBase):
"""Linear method for Ascend quantization.
This class calls AscendQuantizer to search a specific quantization
implementations supported on ascend hardware for linear methods.
Args:
quant_config: The Ascend quantization config.
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
packed_modules_mapping: Dict[str, Any]) -> None:
self.quantizer = AscendQuantizer.get_quantizer(
quant_config.quant_description, prefix, packed_modules_mapping)
self.quant_method = self.quantizer.build_linear_method()
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
weight_dict = self.quant_method.get_weight(input_size_per_partition,
output_size_per_partition,
params_dtype)
for weight_name, weight_param in weight_dict.items():
param = torch.nn.Parameter(weight_param, requires_grad=False)
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(data=pertensor_param,
weight_loader=weight_loader)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)
perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype)
for perchannel_name, perchannel_param in perchannel_dict.items():
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)
pergroup_dict = self.quant_method.get_pergroup_param(
input_size_per_partition, output_size_per_partition, params_dtype)
for pergroup_name, pergroup_param in pergroup_dict.items():
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(pergroup_name, param)
set_weight_attrs(param, extra_weight_attrs)
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
setattr(param, "input_dim", 1)
param.input_dim = 1
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(layer, RowParallelLinear):
tp_rank = get_tensor_model_parallel_rank()
return self.quant_method.apply(layer, x, bias, tp_rank)
return self.quant_method.apply(layer, x, bias)
class AscendKVCacheMethod(BaseKVCacheMethod):
"""KVCache method for Ascend quantization.
This class calls AscendQuantizer to search a specific quantization
implementations supported on ascend hardware for kvcache methods.
Args:
quant_config: The Ascend quantization config.
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
self.quantizer = AscendQuantizer.get_quantizer(
quant_config.quant_description, prefix)
self.quant_method = self.quantizer.build_attention_method()
def create_weights(self, layer: torch.nn.Module) -> None:
# Different from linear method, there are no weight processing/slicing
# steps for attention in vllm. So the whole process of create weights
# is hidden into the specific quant method.
self.quant_method.create_weights(layer)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
attn_type, scale, output) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, kv_cache,
attn_metadata, attn_type, scale, output)
class AscendFusedMoEMethod(FusedMoEMethodBase):
"""FusedMoE method for Ascend quantization.
This class calls AscendQuantizer to search a specific quantization
implementations supported on ascend hardware for kvcache methods.
Args:
quant_config: The Ascend quantization config.
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
packed_modules_mapping: Dict[str, Any]):
self.quantizer = AscendQuantizer.get_quantizer(
quant_config.quant_description, prefix, packed_modules_mapping)
self.quant_method = self.quantizer.build_moe_method()
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
weight_param = self.quant_method.get_weight(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
for param_key, param_value in weight_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
per_group_param = [
"weight_scale_second", "weight_offset_second", "scale_bias"
]
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
for param_key, param_value in dynamic_quant_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs)
if any(fields in param_key for fields in per_group_param):
setattr(param, "quant_method",
FusedMoeWeightScaleSupported.GROUP.value)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
global_redundant_expert_num=0,
**kwargs,
) -> torch.Tensor:
return self.quant_method.apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, e_score_correction_bias,
is_prefill, enable_force_load_balance, log2phy,
global_redundant_expert_num, **kwargs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
class AscendEmbeddingMethod(AscendLinearMethod):
"""Embedding method for Ascend quantization.
This class calls AscendQuantizer to search a specific quantization
implementations supported on ascend hardware for Embedding methods.
Args:
quant_config: The Ascend quantization config.
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
packed_modules_mapping: Dict[str, Any]) -> None:
self.quantizer = AscendQuantizer.get_quantizer(
quant_config.quant_description, prefix, packed_modules_mapping)
self.quant_method = self.quantizer.build_linear_method()

View File

@@ -0,0 +1,311 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import sys
import types
from typing import Any, Dict, List, Optional
from vllm.logger import logger
from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
wrapper_vocab_parallel_embedding_init)
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
AscendW4A8DynamicLinearMethod)
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
CUSTOMIZED_QUANTIZER_TYPE: List[str] = []
class AscendQuantizer:
"""An interface to different quantization implementations for ascend hardwares."""
@classmethod
def get_quantizer(cls,
quant_config: Dict[str, Any],
prefix: str,
packed_modules_mapping: Optional[Dict[str,
Any]] = dict()):
# TODO: Need a param to choose quantization algorithms.
quantization_algorithm = ''
if quantization_algorithm in CUSTOMIZED_QUANTIZER_TYPE:
return
return VLLMAscendQuantizer.get_quantizer(quant_config, prefix,
packed_modules_mapping)
def build_linear_method(self):
raise NotImplementedError
def build_moe_method(self):
raise NotImplementedError
def build_attention_method(self):
raise NotImplementedError
class VLLMAscendQuantizer:
_instance: Optional[object] = None
patched = False
def __init__(self, quant_description):
if VLLMAscendQuantizer.patched:
return
for name in quant_description.keys():
if "norm.bias" in name:
VLLMAscendQuantizer.apply_patch(
"vllm.model_executor.layers.layernorm.RMSNorm", "__init__",
[wrapper_rmsnorm_init])
VLLMAscendQuantizer.apply_patch(
"vllm_ascend.ops.layernorm.AscendRMSNorm", "forward_oot",
[wrapper_rmsnorm_forward_oot])
VLLMAscendQuantizer.apply_patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding",
"__init__", [wrapper_vocab_parallel_embedding_init])
break
VLLMAscendQuantizer.patched = True
logger.info("Using the vLLM Ascend Quantizer version now!")
@staticmethod
def apply_patch(target_module, target_function, wrappers):
original_module, original_function = VLLMAscendQuantizer.parse_path(
target_module, target_function, False)
original_function_id = id(original_function)
candidate = original_function
for wrapper in wrappers:
candidate = wrapper(candidate)
if target_function is not None:
setattr(original_module, target_function, candidate)
for _, value in sys.modules.copy().items():
if target_function is None:
continue
try:
attr = getattr(value, target_function, None)
if attr is not None and id(attr) == original_function_id:
setattr(value, target_function, candidate)
except ImportError:
continue
@staticmethod
def parse_path(module_path, function_name, create_dummy):
"""
Parse module path and resolve/create modules as needed.
Args:
module_path: Dot-separated module path
function_name: Target function name (None for module only)
create_dummy: Create dummy modules/functions when missing
Returns:
Tuple of (resolved module, target function/none)
Raises:
ModuleNotFoundError: If module path is invalid and create_dummy=False
AttributeError: If function is missing and create_dummy=False
"""
from importlib.machinery import ModuleSpec
def create_dummy_module(full_path, parent=None):
"""Create and register a placeholder module"""
dummy = types.ModuleType(full_path)
dummy.__file__ = "vllm_ascend.dummy_module.py"
dummy.__spec__ = ModuleSpec(full_path, None)
sys.modules[full_path] = dummy
if parent:
setattr(parent, full_path.split(".")[-1], dummy)
return dummy
def create_placeholder_function(func_name):
"""Create dummy function that raises when called"""
def placeholder(*args, **kwargs):
raise NotImplementedError(
f"Function {func_name} is a placeholder")
placeholder.__name__ = func_name
return placeholder
modules = module_path.split(".")
current_module = None
processed_path = []
for idx, part in enumerate(modules):
current_path = ".".join(modules[:idx + 1])
parent_path = ".".join(modules[:idx]) if idx > 0 else None
try:
current_module = importlib.import_module(current_path)
except ModuleNotFoundError:
# Handle missing module
parent = importlib.import_module(
parent_path) if parent_path else None
if parent and hasattr(parent, part):
# Use existing attribute from parent
current_module = getattr(parent, part)
# Check for early function resolution
if function_name and hasattr(current_module,
function_name):
return current_module, getattr(current_module,
function_name)
if function_name and create_dummy:
ph_func = create_placeholder_function(function_name)
setattr(current_module, function_name, ph_func)
return current_module, ph_func
if function_name:
raise AttributeError(
f"Function {function_name} missing in {current_path}"
)
else:
if not create_dummy:
raise
# Create and register dummy module
current_module = create_dummy_module(
current_path,
parent=importlib.import_module(parent_path)
if parent_path else None)
processed_path.append(part)
# Final function handling
final_module = sys.modules[module_path]
if function_name is not None:
if not hasattr(final_module, function_name):
if create_dummy:
ph_func = create_placeholder_function(function_name)
setattr(final_module, function_name, ph_func)
else:
setattr(final_module, function_name, None)
return final_module, getattr(final_module, function_name)
return final_module, None
@staticmethod
def build_linear_method():
raise NotImplementedError(
"Linear method is not implemented for the current quant type.")
@staticmethod
def build_moe_method():
raise NotImplementedError(
"MoE method is not implemented for the current quant type.")
@staticmethod
def build_attention_method():
raise NotImplementedError(
"Attention method is not implemented for the current quant type.")
@staticmethod
def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
packed_modules_mapping: Dict[str, Any]):
proj_name = prefix.split(".")[-1]
if proj_name in packed_modules_mapping:
quant_type = None
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in packed_modules_mapping[proj_name]
]
for shard_prefix in shard_prefixes:
shard_quant_type = quant_description[shard_prefix + '.weight']
if quant_type is None:
quant_type = shard_quant_type
elif shard_quant_type != quant_type:
raise ValueError(
f"Not all shards of {prefix} are quantized with same quant type."
f"Shard {proj_name} uses {shard_quant_type}, but another shard"
f"use {quant_type}. Please check quantization config.")
else:
quant_type = quant_description[prefix + '.weight']
return quant_type
@classmethod
def get_quantizer(cls,
quant_description: Dict[str, Any],
prefix: str,
packed_modules_mapping: Optional[Dict[str, Any]] = None):
if packed_modules_mapping is None:
packed_modules_mapping = dict()
# Attention
if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
quant_type = quant_description['fa_quant_type']
# Use KVCache int8
elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
quant_type = quant_description['kv_quant_type']
# Linear
else:
quant_type = cls.get_linear_quant_type(quant_description, prefix,
packed_modules_mapping)
if quant_type in SUPPORT_ASCEND_QUANTIZER_TYPE.keys():
cls = SUPPORT_ASCEND_QUANTIZER_TYPE[quant_type]
if not cls._instance:
cls._instance = cls(quant_description)
return cls._instance
raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \
f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}")
class W4A8DYNAMICQuantizer(VLLMAscendQuantizer):
@staticmethod
def build_linear_method():
return AscendW4A8DynamicLinearMethod()
@staticmethod
def build_moe_method():
return AscendW4A8DynamicFusedMoEMethod()
class W8A8Quantizer(VLLMAscendQuantizer):
@staticmethod
def build_linear_method():
return AscendW8A8LinearMethod()
@staticmethod
def build_moe_method():
return AscendW8A8FusedMoEMethod()
@staticmethod
def build_attention_method():
return AscendC8KVCacheMethod()
class W8A8DYNAMICQuantizer(VLLMAscendQuantizer):
@staticmethod
def build_linear_method():
return AscendW8A8DynamicLinearMethod()
@staticmethod
def build_moe_method():
return AscendW8A8DynamicFusedMoEMethod()
SUPPORT_ASCEND_QUANTIZER_TYPE = {
"W4A8_DYNAMIC": W4A8DYNAMICQuantizer,
"W8A8": W8A8Quantizer,
"W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
"C8": W8A8Quantizer,
}

View File

@@ -0,0 +1,394 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
import numpy as np
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
from vllm_ascend.ops.layers.experts_selector import select_experts
class AscendW4A8DynamicLinearMethod:
"""Linear method for Ascend W4A8_DYNAMIC
"""
def __init__(self):
self.transpose_weight = True
try:
self.group_size = get_current_vllm_config(
).quant_config.quant_description.get("group_size", 256)
except AttributeError:
self.group_size = 256
@staticmethod
def get_weight(input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def get_perchannel_param(output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
def get_pergroup_param(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_scale_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
params_dict["weight_offset_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
return params_dict
@staticmethod
def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
per_group_scale: torch.Tensor):
k, n = weight.shape
group_num, n = per_group_scale.shape
weight_high = weight.to(torch.float32).reshape(
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
weight_high = weight_high.reshape(k, n)
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
return antiquant_scale.npu(), bias
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = None,
) -> torch.Tensor:
return torch_npu.npu_weight_quant_batchmatmul(
x,
layer.weight,
antiquant_scale=layer.weight_scale_second.to(x.dtype),
antiquant_group_size=self.group_size,
)
def process_weights_after_loading(self, layer: torch.nn.Module):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
layer.weight.data,
layer.weight_scale.data,
layer.weight_scale_second.data.transpose(0, 1).contiguous(),
)
param = torch.nn.Parameter(scale_bias, requires_grad=False)
layer.register_parameter("weight_scale_bias", param)
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
layer.weight.data.to(torch.int32))
class AscendW4A8DynamicFusedMoEMethod:
"""FusedMoe method for Ascend W4A8_DYNAMIC.
"""
def __init__(self):
self.transpose_weight = True
self.ep_group = get_ep_group()
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256)
quant_version = vllm_config.quant_config.quant_description.get(
"version", "0")
# NOTE: new quantize weights: 2 int4 pack into int8
self.new_quant_version = quant_version == "1.0.0"
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
if self.new_quant_version and self.tp_size > 16:
raise ValueError(
"The current weight does not support moe part tp>16.")
try:
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
except AttributeError:
self.moe_all_to_all_group_name = ""
def get_weight(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
if self.new_quant_version:
w13_output_size = intermediate_size_per_partition
w2_output_size = hidden_sizes // 2
else:
w13_output_size = 2 * intermediate_size_per_partition
w2_output_size = hidden_sizes
param_dict["w13_weight"] = torch.empty(num_experts,
w13_output_size,
hidden_sizes,
dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(num_experts,
w2_output_size,
intermediate_size_per_partition,
dtype=torch.int8)
return param_dict
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=params_dtype)
param_dict["w13_weight_offset"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=params_dtype)
param_dict["w13_weight_scale_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=params_dtype)
param_dict["w13_weight_offset_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=params_dtype)
param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=params_dtype)
param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=params_dtype)
param_dict["w2_weight_scale_second"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=params_dtype)
param_dict["w2_weight_offset_second"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=params_dtype)
if self.new_quant_version:
param_dict["w13_scale_bias"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
param_dict["w2_scale_bias"] = torch.empty(num_experts,
hidden_sizes,
16 // self.tp_size,
dtype=torch.float32)
return param_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
topk_weights = topk_weights.to(x.dtype)
return unified_fused_experts_eager(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_second,
w2_scale=layer.w2_weight_scale_second,
w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None),
with_quant=True)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
group_num, k, n = weight.shape
# the weight of the new version is reduced by half by pack n, so it needs to be restored
if self.new_quant_version:
n = n * 2
per_group_scale = per_group_scale.reshape(group_num, -1, n)
group_num, quantgroup_num, n = per_group_scale.shape
bias = None
if not self.new_quant_version:
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
weight_high = weight_high.reshape([group_num, k, n])
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
torch.float32)
scale_fp32_np = scale_fp32.cpu().numpy()
scale_fp32_np.dtype = np.uint32
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
dtype=np.uint32)
sscale_uint64[..., ::2] = scale_fp32_np
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
dtype=np.int64).copy()
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
group_num, quantgroup_num, n)
sscale_uint64_tensor = sscale_uint64_tensor.npu()
return sscale_uint64_tensor, bias
def update_bias(self, layer, w13_bias, w2_bias):
if self.new_quant_version:
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
else:
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.register_parameter("w13_scale_bias", w13_scale_bias)
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
layer.register_parameter("w2_scale_bias", w2_scale_bias)
def pack_to_int32(self, weight: torch.Tensor):
if self.new_quant_version:
group_num, k, n = weight.shape
assert n % 4 == 0, "the last dim of weight needs to be divided by 4"
packed_n = n // 4
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
packed_weight = torch.from_numpy(
np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32))
return packed_weight.reshape(group_num, k, packed_n).npu()
else:
return torch_npu.npu_quantize(weight.to(torch.float32),
torch.tensor([1.]).npu(), None,
torch.quint4x2, -1, False)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
1, 2).contiguous()
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose(
1, 2).contiguous()
layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale_second.data, w13_bias = self.process_scale(
layer.w13_weight, layer.w13_weight_scale.data,
layer.w13_weight_scale_second.data)
layer.w2_weight_scale_second.data, w2_bias = self.process_scale(
layer.w2_weight, layer.w2_weight_scale.data,
layer.w2_weight_scale_second.data)
self.update_bias(layer, w13_bias, w2_bias)
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)

View File

@@ -0,0 +1,647 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
import torch
import torch_npu
from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
def quant_per_tensor(in_tensor: torch.Tensor,
input_scale: torch.Tensor,
input_offset: torch.Tensor,
function=False):
return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
torch.qint8, -1, function)
class AscendW8A8LinearMethod:
"""Linear method for Ascend W8A8.
Args:
w_sym: whether the linear weight is symmetrically quantized.
"""
def __init__(self) -> None:
# aclnn quant matmul requires to transpose matrix B, set to true by default.
self.transpose_weight = not is_310p()
@staticmethod
def get_weight(
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
return params_dict
@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
params_dict = {}
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
if params_dtype == torch.bfloat16:
params_dict["deq_scale"] = torch.empty(output_size,
dtype=torch.float32)
elif params_dtype == torch.float16:
params_dict["deq_scale"] = torch.empty(output_size,
dtype=torch.int64)
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
return params_dict
def get_pergroup_param(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
if x.dtype != torch.int8:
x = quant_per_tensor(
x,
layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None
if is_310p():
# On 300I Duo platform, we need transpose again if
# using nz. This transpose can be skipped in torchair.
output = torch_npu.npu_quant_matmul(
x,
layer.weight.data.transpose(1, 0),
layer.deq_scale,
bias=quant_bias,
output_dtype=layer.params_dtype,
)
else:
output = torch_npu.npu_quant_matmul(
x,
layer.weight,
layer.deq_scale,
bias=quant_bias,
output_dtype=layer.params_dtype,
)
return output
def process_weights_after_loading(self, layer):
expanding_factor = layer.weight.data.shape[1]
layer.aclnn_input_scale = torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.aclnn_input_offset = torch.nn.Parameter(
layer.input_offset.data.repeat(expanding_factor),
requires_grad=False).to(layer.aclnn_input_scale.dtype)
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
ACL_FORMAT_FRACTAL_NZ)
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
class AscendW8A8FusedMoEMethod:
"""FusedMoe method for Ascend W8A8.
"""
def __init__(self):
self.transpose_weight = True
@staticmethod
def get_weight(num_experts: int, intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
param_dict["w13_weight"] = torch.empty(num_experts,
2 *
intermediate_size_per_partition,
hidden_sizes,
dtype=torch.int8,
requires_grad=False)
param_dict["w2_weight"] = torch.empty(num_experts,
hidden_sizes,
intermediate_size_per_partition,
dtype=torch.int8,
requires_grad=False)
return param_dict
@staticmethod
def get_dynamic_quant_param(num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
param_dict["w13_weight_offset"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float16)
param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=torch.float32)
param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=torch.float16)
param_dict["w2_deq_scale"] = torch.empty(num_experts,
hidden_sizes,
dtype=torch.float32)
param_dict["w13_deq_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32)
param_dict["w2_input_scale"] = torch.empty(num_experts,
1,
dtype=torch.float32)
param_dict["w13_input_scale"] = torch.empty(num_experts,
1,
dtype=torch.float32)
param_dict["w2_input_offset"] = torch.empty(num_experts,
1,
dtype=torch.int8)
param_dict["w13_input_offset"] = torch.empty(num_experts,
1,
dtype=torch.int8)
param_dict["quant_bias"] = torch.empty(num_experts,
hidden_sizes,
dtype=torch.int32)
return param_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
if is_310p():
return fused_experts_310p(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w1_input_scale=layer.w13_input_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
w2_input_scale=layer.w2_input_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map)
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w1_input_scale=layer.w13_input_scale,
w1_input_offset=layer.w13_input_offset,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
w2_input_scale=layer.w2_input_scale,
w2_input_offset=layer.w2_input_offset,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map)
def process_weights_after_loading(self, layer):
if not is_310p():
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1)
expanding_factor_w13 = layer.w13_weight.data.shape[1]
expanding_factor_w2 = layer.w2_weight.data.shape[1]
if is_310p():
layer.w13_input_scale.data = torch.nn.Parameter(
layer.w13_input_scale.data.max())
layer.w2_input_scale.data = torch.nn.Parameter(
layer.w2_input_scale.data.max())
else:
layer.w13_input_scale.data = torch.nn.Parameter(
layer.w13_input_scale.data.repeat(1,
expanding_factor_w13)[0:1])
layer.w2_input_scale.data = torch.nn.Parameter(
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
layer.w13_input_offset.data = torch.nn.Parameter(
layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
layer.w2_input_offset.data = torch.nn.Parameter(
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
# converting ACL_FORMAT_FRACTAL_NZ.
# npu_quant_grouped_matmul_dequant in eager mode does not accept
# ACL_FORMAT_FRACTAL_NZ.
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
class AscendC8KVCacheMethod:
def __init__(self) -> None:
self.antiquant_scale_comb = None
@staticmethod
def create_weights(layer) -> None:
param_dict = {} # num_kv_heads * head_size
param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
layer.head_size,
dtype=torch.float16,
requires_grad=False)
param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
layer.head_size,
dtype=torch.float16,
requires_grad=False)
for weight_name, weight_param in param_dict.items():
param = torch.nn.Parameter(weight_param, requires_grad=False)
layer.register_parameter(weight_name, param)
def process_weights_after_loading(self, layer):
self.antiquant_scale_comb = torch.cat(
(layer.key_antiquant_scale.data.unsqueeze(0),
layer.value_antiquant_scale.data.unsqueeze(0)),
dim=0).to(torch.float16).contiguous()
def apply(self, layer, query, key, value, kv_cache, attn_metadata,
attn_type, scale, output) -> torch.Tensor:
num_tokens = query.shape[0]
if attn_metadata is None:
return output.view(num_tokens, layer.num_heads * layer.head_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# C8
quant_key = quant_per_tensor(
key.view(-1, layer.num_kv_heads * layer.head_size),
layer.key_antiquant_scale.data.view(-1), None, True)
quant_value = quant_per_tensor(
value.view(-1, layer.num_kv_heads * layer.head_size),
layer.value_antiquant_scale.data.view(-1), None, True)
# View q k v to BSH.
query = query.view(-1, layer.num_heads, layer.head_size)
key = key.view(-1, layer.num_kv_heads, layer.head_size)
value = value.view(-1, layer.num_kv_heads, layer.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()
if kv_cache[0].numel() > 0:
# if key_cache is None:
key_cache, value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
block_size = key_cache.shape[1]
slots_indices = slots.reshape(-1, 1)
block_indices = slots_indices // block_size
slots_indices = slots_indices % block_size
indices = torch.cat((block_indices, slots_indices), dim=1)
# C8
torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)
# V0-Style scheduler situation.
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=scale,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
out=output.reshape(query.shape))
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
raise NotImplementedError("kv cache int8 are not "
"implemented for "
"PrefillCacheHit")
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly
if hasattr(attn_metadata, "decode"):
# torch_air
decode_meta = attn_metadata.decode
seq_lens = decode_meta.seq_lens_list
else:
seq_lens = attn_metadata.seq_lens
block_size = key_cache.shape[1]
query = query.view(num_tokens, 1, layer.num_heads *
layer.head_size).contiguous() # changed
# [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D]
key = key_cache
value = value_cache
output = torch_npu.npu_incre_flash_attention(
query,
key,
value,
num_key_value_heads=layer.num_kv_heads,
num_heads=layer.num_heads,
actual_seq_lengths=seq_lens,
scale_value=scale,
input_layout='BSH',
block_size=block_size,
block_table=attn_metadata.block_tables,
antiquant_scale=self.antiquant_scale_comb,
)
# Normal V1 situation.
else:
raise NotImplementedError("kv cache int8 are not "
"implemented for "
"other case")
return output
def fused_experts_310p(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w1_input_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
global_num_experts: int,
expert_map: torch.Tensor = None,
) -> torch.Tensor:
ep_size = get_ep_group().world_size
local_num_experts = global_num_experts // ep_size
local_num_group = top_k // ep_size
bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
sorted_hidden_states = hidden_states.index_select(
0, sorted_topk_ids // local_num_group)
experts_id = torch.arange(0,
local_num_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
topk_scales = topk_weights.view(-1).index_select(
0, sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
x=sorted_hidden_states,
quantized_weight=w1,
weight_scale=w1_scale,
group_list=group_list,
x_scale=w1_input_scale,
quant_mode="pertensor")
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
gate_up_out *= topk_scales
down_out = torch_npu.npu_quant_grouped_matmul_dequant(
x=gate_up_out,
quantized_weight=w2,
weight_scale=w2_scale,
group_list=group_list,
x_scale=w2_input_scale,
quant_mode="pertensor")
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
bsz, top_k // ep_size, -1).sum(1)
return final_hidden_states
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w1_input_scale: torch.Tensor,
w1_input_offset: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
w2_input_offset: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
global_num_experts: int,
expert_map: torch.Tensor = None,
) -> torch.Tensor:
"""
Fused experts with top-k routing.
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
top_k: Number of experts to select.
expert_map: Expert mapping of shape (num_experts,).
Returns:
hidden_states: Hidden states after routing.
"""
"""
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
"""
original_dtype = hidden_states.dtype
ep_size = get_ep_group().world_size
local_num_experts = global_num_experts // ep_size
w1_input_scale, _ = w1_input_scale.max(0)
quant_sorted_hidden_states = quant_per_tensor(
hidden_states,
w1_input_scale,
None,
True,
)
if expert_map is not None:
expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
quant_sorted_hidden_states,
topk_ids,
scale=None,
active_num=topk_ids.numel(),
expert_capacity=-1,
expert_num=local_num_experts,
drop_pad_mode=0,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
quant_mode=-1,
active_expert_range=[0, local_num_experts],
row_idx_type=0,
)
else:
raise NotImplementedError(
"The quantified version of MOE class models "
"currently does not support tensor parallelism")
if expanded_x.dtype != w1.dtype:
w1_input_scale, _ = w1_input_scale.max(0)
quant_sorted_hidden_states = quant_per_tensor(
expanded_x,
w1_input_scale,
None,
True,
)
else:
quant_sorted_hidden_states = expanded_x
gate_up_out = torch_npu.npu_grouped_matmul(
x=[quant_sorted_hidden_states],
weight=[w1],
scale=[w1_scale * w1_input_scale[0]],
split_item=2,
group_list_type=1,
group_type=0,
group_list=expert_token_count,
output_dtype=original_dtype,
)[0]
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
if gate_up_out.dtype != w2.dtype:
w2_input_scale, _ = w2_input_scale.max(0)
quant_gate_up_out = quant_per_tensor(
gate_up_out,
w2_input_scale,
None,
True,
)
else:
quant_gate_up_out = gate_up_out
down_out = torch_npu.npu_grouped_matmul(
x=[quant_gate_up_out],
weight=[w2],
scale=[w2_scale * w2_input_scale[0]],
split_item=2,
group_list_type=1,
group_type=0,
group_list=expert_token_count,
output_dtype=original_dtype,
)[0]
if expert_map is not None:
final_hidden_states = torch_npu.npu_moe_finalize_routing(
down_out,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights.to(down_out.dtype),
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
drop_pad_mode=2,
)
else:
raise NotImplementedError(
"The quantified version of MOE class models "
"currently does not support tensor parallelism")
return final_hidden_states

View File

@@ -0,0 +1,453 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.common_fused_moe import \
fused_experts as unified_fused_experts
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor
def apply_mlp_decode(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1) -> torch.Tensor:
"""
apply MLP: gate_up_proj -> swiglu -> down_proj
Args:
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
w1: expert weights1 with shape
(num_experts, hidden_size, intermediate_size * 2)
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
w2: expert weights2 with shape
(num_experts, intermediate_size, hidden_size)
w2_scale: weights2 scale with shape (num_experts, hidden_size)
group_list: number of tokens for each expert, follow cumsum mode, and
with shape (num_experts).
transpose_weight:
w1: (num_experts, intermediate_size * 2, hidden_size) ->
(num_experts, hidden_size, intermediate_size * 2)
w2: (num_experts, hidden_size, intermediate_size) ->
(num_experts, intermediate_size, hidden_size)
Returns:
hidden_states: output hidden states after MLP.
"""
if dynamic_scale is None:
unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# Dispose the original unquantized hidden states
# to save npu memory because they're no longer used.
dispose_tensor(unquantized_hidden_states)
else:
pertoken_scale = dynamic_scale
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]
return hidden_states
def apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
"""
apply MLP: gate_up_proj -> swiglu -> down_proj
Args:
hidden_states: input hidden states with shape (num_tokens, hidden_size).
w1: expert weights1 with shape
(num_experts, hidden_size, intermediate_size * 2)
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
w2: expert weights2 with shape
(num_experts, intermediate_size, hidden_size)
w2_scale: weights2 scale with shape (num_experts, hidden_size)
group_list: number of tokens for each expert, follow cumsum mode, and
with shape (num_experts).
transpose_weight:
w1: (num_experts, intermediate_size * 2, hidden_size) ->
(num_experts, hidden_size, intermediate_size * 2)
w2: (num_experts, hidden_size, intermediate_size) ->
(num_experts, intermediate_size, hidden_size)
Returns:
hidden_states: output hidden states after MLP.
"""
if dynamic_scale is None:
unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# Dispose the original unquantized hidden states
# to save npu memory because they're no longer used.
dispose_tensor(unquantized_hidden_states)
else:
pertoken_scale = dynamic_scale
bias1, bias2 = None, None
_output_dtype = w2_scale.dtype
if w1_scale_bias is not None:
if group_list_type == 0:
group_list = torch.cat(
[group_list[:1], torch.diff(group_list, dim=0)])
group_list_type = 1
bias1 = [w1_scale_bias]
bias2 = [w2_scale_bias]
# TODO w4a8 scene: dynamic acquisition of dtype in the future
_output_dtype = torch.bfloat16
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
scale=[w1_scale],
bias=bias1,
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
bias=bias2,
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
return hidden_states
class AscendW8A8DynamicLinearMethod:
"""Linear method for Ascend W8A8_DYNAMIC.
"""
def __init__(self):
self.transpose_weight = True
@staticmethod
def get_weight(input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
return params_dict
def get_pergroup_param(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def apply(
layer: torch.nn.Module,
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
config = getattr(layer, "_ascend_quant_config", {})
if not isinstance(x, tuple):
output_dtype = config.get("output_dtype", x.dtype)
quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
else:
assert "output_dtype" in config.keys(), (
f"DynamicLinearMethod needs explicitly specified `output_dtype`"
f"for pre-quantized input, got config [{config}]")
output_dtype = config["output_dtype"]
quantized_x, dynamic_scale = x
pertoken_scale = (dynamic_scale
if config.get("pertoken_scale", True) else None)
output = torch_npu.npu_quant_matmul(
quantized_x,
layer.weight,
layer.weight_scale,
pertoken_scale=pertoken_scale,
bias=bias,
output_dtype=output_dtype,
)
return ((output, dynamic_scale)
if config.get("return_scale", False) else output)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
# cast quantized weight tensors in NZ format (29) for higher inference speed
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
class AscendW8A8DynamicFusedMoEMethod:
"""FusedMoe method for Ascend W8A8_DYNAMIC.
"""
def __init__(self):
self.transpose_weight = True
self.ep_group = get_ep_group()
vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
self.use_aclgraph = (
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager
and not ascend_config.torchair_graph_config.enabled)
try:
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
except AttributeError:
self.moe_all_to_all_group_name = ""
@staticmethod
def get_weight(num_experts: int, intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
param_dict["w13_weight"] = torch.empty(num_experts,
2 *
intermediate_size_per_partition,
hidden_sizes,
dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(num_experts,
hidden_sizes,
intermediate_size_per_partition,
dtype=torch.int8)
return param_dict
@staticmethod
def get_dynamic_quant_param(num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=params_dtype)
param_dict["w13_weight_offset"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=params_dtype)
param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=params_dtype)
param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=params_dtype)
return param_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
if self.use_aclgraph:
return unified_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
expert_map=expert_map,
)
fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
topk_weights = topk_weights.to(x.dtype)
return unified_fused_experts_eager(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None),
with_quant=True)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
torch.float32)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1)