diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index 0623185..ed9dd44 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from tests.ut.base import PytestBase from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod +from vllm_ascend.utils import version_check def mock_rms_norm(x, weight, eps): @@ -26,6 +27,15 @@ def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset, return x_out_quant, None, residual_out_quant +def mock_add_rms_norm_quant_with_bias(x, residual, weight, quant_scale, + quant_offset, beta, epsilon): + x_out = 2 * x + residual_out = 2 * residual + x_out_quant = x_out.to(torch.int8) + residual_out_quant = residual_out.to(torch.int8) + return x_out_quant, None, residual_out_quant + + class TestAscendRMSNorm(PytestBase): @pytest.fixture(autouse=True) @@ -33,8 +43,10 @@ class TestAscendRMSNorm(PytestBase): mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) mocker.patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) + torch_npu_check = version_check() + arnq_side_effect = mock_add_rms_norm_quant_with_bias if torch_npu_check else mock_add_rms_norm_quant mocker.patch("torch_npu.npu_add_rms_norm_quant", - side_effect=mock_add_rms_norm_quant) + side_effect=arnq_side_effect) mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) @@ -70,8 +82,10 @@ class TestAscendRMSNorm(PytestBase): mock_model_instance = mocker.MagicMock() mock_forward_context.model_instance = mock_model_instance + torch_npu_check = version_check() + num_hidden_layers = 3 if torch_npu_check else 2 mock_model_instance.model.layers = [ - mocker.MagicMock() for _ in range(2) + mocker.MagicMock() for _ in range(num_hidden_layers) ] mock_layer_0 = mock_model_instance.model.layers[0] @@ -101,7 +115,7 @@ class TestAscendRMSNorm(PytestBase): mock_forward_context.addrmsnorm_quant_fusion_enabled = True mock_forward_context.prefetch_mlp_enabled = False mock_forward_context.layer_idx = 0 - mock_forward_context.num_hidden_layers = 2 + mock_forward_context.num_hidden_layers = num_hidden_layers mock_forward_context.fusion_linear = "gate_up_dense" # Ensure fusion and layer_idx increment are handled correctly @@ -121,18 +135,37 @@ class TestAscendRMSNorm(PytestBase): assert mock_forward_context.fusion_linear == "gate_up_dense" assert mock_forward_context.layer_idx == 1 + if torch_npu_check: + mock_forward_context.fusion_linear = "gate_moe" x_out, residual_out = layer.forward_oot(x, residual) assert mock_get_forward_context.call_count == 3 - assert mock_forward_context.fusion_linear == "qkv_dense" + fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense" + assert mock_forward_context.fusion_linear == fusion_linear_expected assert mock_forward_context.layer_idx == 2 x_out, residual_out = layer.forward_oot(x, residual) assert mock_get_forward_context.call_count == 4 - assert mock_forward_context.fusion_linear == "qkv_dense" + fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense" + assert mock_forward_context.fusion_linear == fusion_linear_expected assert mock_forward_context.layer_idx == 2 + if not torch_npu_check: + return + # last layer returned directly + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 5 + assert mock_forward_context.fusion_linear == "qkv_moe" + assert mock_forward_context.layer_idx == 3 + + x_out, residual_out = 
layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 6 + assert mock_forward_context.fusion_linear == "qkv_moe" + assert mock_forward_context.layer_idx == 3 + if __name__ == '__main__': unittest.main() diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index d3402fe..47cc2e9 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -11,7 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context, set_forward_context) import vllm_ascend.envs as envs_ascend -from vllm_ascend.utils import enable_sp, is_moe_model +from vllm_ascend.utils import enable_sp, is_moe_model, version_check if TYPE_CHECKING: from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod @@ -160,13 +160,18 @@ def set_ascend_forward_context( # this optim now just support dense models due to the specific operators used. # Once the necessary conditions are met, support for MOE models will also be added. from vllm_ascend.quantization.quant_config import AscendQuantConfig + model_type_scope = ["llama", "qwen2", "qwen3"] + if version_check(): + model_type_scope.append("qwen3_moe") addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \ - vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \ + vllm_config.model_config.hf_config.model_type in model_type_scope and \ forward_context.layer_idx is not None if addrmsnorm_quant_fusion_enabled: forward_context.model_instance = model_instance forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense" + if vllm_config.model_config.hf_config.model_type == "qwen3_moe": + forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe" forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled if num_tokens is None and attn_metadata is not None: diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 53c93a6..27c289b 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -33,13 +33,12 @@ from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, maybe_save_kv_layer_to_connector, - version_check, wait_for_kv_layer_from_connector) from vllm_ascend.compilation.acl_graph import (get_graph_params, update_graph_params_workspaces) from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, - nd_to_nz_2d, nd_to_nz_spec) + nd_to_nz_2d, nd_to_nz_spec, version_check) from ..utils import weak_ref_tensors diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 007b055..61befaa 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,10 +1,8 @@ -import functools from dataclasses import dataclass from typing import Any, List import torch import torch.nn.functional as F -import torch_npu from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group) @@ -142,20 +140,6 @@ def maybe_save_kv_layer_to_connector( connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) -@functools.cache -def version_check(): - import re - torch_npu_version = torch_npu.version.__version__ - date_pattern = r'dev(\d{8})' 
-
-    match = re.search(date_pattern, torch_npu_version)
-    if match:
-        full_date = match.group(1)
-        if full_date >= "20250919":
-            return True
-    return False
-
-
 def round_up(val: int, align: int) -> int:
     if align == 0:
         return 0
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 0065dd4..60c7e39 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -18,7 +18,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.platforms import current_platform
 
-from vllm_ascend.attention.utils import version_check
+from vllm_ascend.utils import version_check
 
 from ..utils import weak_ref_tensors
 
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index fbe281f..55eab21 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -18,28 +18,43 @@ from typing import Optional, Tuple, Union, cast
 
 import torch
+from vllm.config import get_current_vllm_config
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 
+from vllm_ascend.utils import version_check
+
 
 def _addrmsnorm_forward_oot(
     self,
     x: torch.Tensor,
     residual: torch.Tensor,
     layer: Optional[torch.nn.Module] = None,
+    bias: Optional[torch.nn.Parameter] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     import torch_npu
 
     from vllm_ascend.utils import is_310p
 
+    torch_npu_check = version_check()
     if layer is not None and not is_310p():
-        x, _, residual = torch_npu.npu_add_rms_norm_quant(
-            x,
-            residual,
-            self.weight,
-            layer.aclnn_input_scale,
-            layer.aclnn_input_offset,
-            epsilon=self.variance_epsilon)
+        if torch_npu_check:
+            x, _, residual = torch_npu.npu_add_rms_norm_quant(
+                x,
+                residual,
+                self.weight,
+                layer.aclnn_input_scale,
+                layer.aclnn_input_offset,
+                beta=bias,
+                epsilon=self.variance_epsilon)
+        else:
+            x, _, residual = torch_npu.npu_add_rms_norm_quant(
+                x,
+                residual,
+                self.weight,
+                layer.aclnn_input_scale,
+                layer.aclnn_input_offset,
+                epsilon=self.variance_epsilon)
     else:
         if is_310p():
             orig_dtype = residual.dtype
@@ -50,12 +65,32 @@ def _addrmsnorm_forward_oot(
         else:
             x, _, residual = torch_npu.npu_add_rms_norm(
                 x, residual, self.weight, self.variance_epsilon)
+        if torch_npu_check and bias is not None:
+            x.add_(bias)
     torch.ops.vllm.maybe_wait_prefetch_done(x)
     return x, residual
 
 
 class AscendRMSNorm(RMSNorm):
 
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
+        has_weight: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
+        vllm_config = get_current_vllm_config()
+        self.bias = None
+        self.torch_npu_check = version_check()
+        # quantization with anti_method m4 generates a non-zero norm bias
+        if self.torch_npu_check and vllm_config.quant_config is not None and \
+            any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
+            self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
+                                           requires_grad=False)
+
     def forward_oot(
         self,
         x: torch.Tensor,
@@ -66,10 +101,13 @@ class AscendRMSNorm(RMSNorm):
         if residual is not None:
             assert x.size(0) == residual.size(0)
             x, residual = _addrmsnorm_forward_oot(
-                self, x, residual, self.next_need_quant_fusion_linear)
+                self, x, residual, self.next_need_quant_fusion_linear,
+                self.bias)
             return x, residual
         x, residual = torch_npu.npu_rms_norm(x, self.weight, 
self.variance_epsilon) + if self.torch_npu_check and self.bias is not None: + x.add_(self.bias) return x @property @@ -99,6 +137,13 @@ class AscendRMSNorm(RMSNorm): # does not need to be repeated if not forward_context.prefetch_mlp_enabled: forward_context.layer_idx += 1 + elif fusion_linear == "qkv_moe": + next_linear = model_instance.model.layers[ + layer_idx].self_attn.qkv_proj + forward_context.fusion_linear = "gate_moe" + elif fusion_linear == "gate_moe": + forward_context.fusion_linear = "qkv_moe" + forward_context.layer_idx += 1 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod if next_linear is not None and \ not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod): diff --git a/vllm_ascend/ops/moe/moe_mlp.py b/vllm_ascend/ops/moe/moe_mlp.py index 05a1a2e..5ee7d70 100644 --- a/vllm_ascend/ops/moe/moe_mlp.py +++ b/vllm_ascend/ops/moe/moe_mlp.py @@ -177,7 +177,6 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_type=0, group_list=group_list, output_dtype=_output_dtype)[0] - return hidden_states diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index 36b3a18..c2548ba 100644 --- a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -7,6 +7,7 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import WeightPrefetchConfig from vllm_ascend.ops.linear import (AscendQKVParallelLinear, AscendRowParallelLinear) +from vllm_ascend.utils import version_check SUPPORTED_MODULES = ["attn", "mlp", "moe"] MOE_PREFETCH_TOKEN_THRESHOLD = 96 @@ -82,14 +83,15 @@ class WeightPrefetchMethod: if not self.moe.is_active_this_forward: return forward_context = get_forward_context() + if not version_check(): + forward_context.layer_idx += 1 weight = forward_context.model_instance.model.layers[ - forward_context.layer_idx].mlp.experts.w13_weight + forward_context.layer_idx - 1].mlp.experts.w13_weight weight_size = weight.data.element_size() * weight.data.numel( ) * self.moe.prefetch_ratio.get(prefix, 0) torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=None, max_weight_size=int(weight_size)) - forward_context.layer_idx += 1 def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor): if not self.moe.is_active_this_forward: diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 6cc903a..7bc79d8 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -546,7 +546,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): if vllm_config is not None and \ vllm_config.quant_config is not None and \ - any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): + any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()) and \ + not version_check(): REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm for name, op_cls in REGISTERED_ASCEND_OPS.items(): @@ -725,3 +726,18 @@ def calculate_dp_buffer_size() -> int: def is_hierarchical_communication_enabled(): return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1") + + +@functools.cache +def version_check(): + """check if torch_npu version >= dev20250919""" + import re + torch_npu_version = torch_npu.version.__version__ + date_pattern = r'dev(\d{8})' + + match = re.search(date_pattern, torch_npu_version) + if match: + full_date = match.group(1) + if full_date >= "20250919": + return True + return False
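
The whole patch is gated on the `version_check()` helper that moves into `vllm_ascend/utils.py`: only torch_npu dev builds dated 2025-09-19 or later enable the `beta`-aware `npu_add_rms_norm_quant` path and the qwen3_moe fusion, while builds without a `devYYYYMMDD` tag keep the legacy behavior. A minimal standalone sketch of that gate (the `_dev_build_at_least` name is illustrative only, not part of the patch):

```python
import re

def _dev_build_at_least(version: str, cutoff: str = "20250919") -> bool:
    """Mirror of the date check in version_check(): dev builds only."""
    match = re.search(r"dev(\d{8})", version)
    # zero-padded YYYYMMDD strings compare correctly as plain strings
    return bool(match) and match.group(1) >= cutoff

print(_dev_build_at_least("2.7.1.dev20250920"))  # True  -> fused path with beta
print(_dev_build_at_least("2.7.1.dev20250101"))  # False -> legacy fused call
print(_dev_build_at_least("2.7.1"))              # False -> no dev tag, gate stays off
```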
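
The updated `test_layernorm.py` assertions encode the new qwen3_moe transitions added to `AscendRMSNorm.next_need_quant_fusion_linear`: `"qkv_moe"` fuses the add-rms-norm-quant into the current layer's `qkv_proj` and hands over to `"gate_moe"`, `"gate_moe"` fuses nothing extra and advances `layer_idx`, and once `layer_idx` reaches `num_hidden_layers` the property returns without changing state. A toy model of those transitions (the `advance` helper is a simplification for illustration, not code from the patch) reproduces the sequence the test asserts:

```python
def advance(fusion_linear: str, layer_idx: int, num_hidden_layers: int):
    """Simplified qwen3_moe state machine of next_need_quant_fusion_linear."""
    if layer_idx == num_hidden_layers:
        # last layer: returned directly, state is left untouched
        return fusion_linear, layer_idx
    if fusion_linear == "qkv_moe":
        # fuse into layers[layer_idx].self_attn.qkv_proj; gate step comes next
        return "gate_moe", layer_idx
    if fusion_linear == "gate_moe":
        # no extra fusion target for the MoE gate step; move to the next layer
        return "qkv_moe", layer_idx + 1
    raise ValueError(f"unexpected fusion_linear: {fusion_linear}")

state, idx = "gate_moe", 1          # state set by the test before its third call
for _ in range(4):
    state, idx = advance(state, idx, num_hidden_layers=3)
    print(state, idx)               # qkv_moe 2, gate_moe 2, qkv_moe 3, qkv_moe 3
```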