From 5a4e8cdebabd8293e9ff61b7014d758b71ebf32a Mon Sep 17 00:00:00 2001
From: InSec <158599047+InSec@users.noreply.github.com>
Date: Fri, 21 Nov 2025 10:42:56 +0800
Subject: [PATCH] [Feat][BugFix] Support the Qwen3-Next-80B-A3B-Instruct
 quantization model & fix the NZ issue (#4245)

### What this PR does / why we need it?
Support the Qwen3-Next-80B-A3B-Instruct quantization model and fix the NZ
issue. The Triton kernel does not support the NZ data format, so we skip
converting the weight to NZ for the `conv1d` layer.

- vLLM version: v0.11.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379

---------

Signed-off-by: IncSec <1790766300@qq.com>
---
 tests/e2e/multicard/test_qwen3_next.py     | 26 ++++++++++++++++++++++
 tests/ut/attention/test_mla_v1.py          |  1 -
 tests/ut/models/test_qwen2_5_vl.py         |  4 ----
 tests/ut/quantization/test_w4a8_dynamic.py |  1 -
 tests/ut/test_utils.py                     | 18 +++++----------
 tests/ut/worker/test_worker_v1.py          |  2 +-
 vllm_ascend/ops/linear.py                  |  3 ++-
 vllm_ascend/quantization/quant_config.py   |  2 ++
 vllm_ascend/utils.py                       | 11 ++-------
 vllm_ascend/worker/worker_v1.py            |  1 -
 10 files changed, 39 insertions(+), 30 deletions(-)

diff --git a/tests/e2e/multicard/test_qwen3_next.py b/tests/e2e/multicard/test_qwen3_next.py
index c17eed95..6492da75 100644
--- a/tests/e2e/multicard/test_qwen3_next.py
+++ b/tests/e2e/multicard/test_qwen3_next.py
@@ -20,6 +20,12 @@
 Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
 """
+
+import os
+from unittest.mock import patch
+
+from modelscope import snapshot_download  # type: ignore
+
 from tests.e2e.conftest import VllmRunner
@@ -106,3 +112,23 @@ def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
         print(f"spec_output: {spec_output[1]}")
 
     assert matches > int(0.66 * len(ref_outputs))
+
+
+# TODO: conduct accuracy verification after the subsequent version becomes stable
+@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
+def test_models_distributed_Qwen3_NEXT_W8A8DYNAMIC_WITH_EP():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+            snapshot_download(
+                "vllm-ascend/Qwen3-Next-80B-A3B-Instruct-W8A8-Pruning"),
+            max_model_len=4096,
+            tensor_parallel_size=2,
+            gpu_memory_utilization=0.4,
+            max_num_seqs=1,
+            enable_expert_parallel=True,
+            quantization="ascend",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 10d6528c..5c58cf7b 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -797,7 +797,6 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(q_pe.shape[1], self.impl.num_heads)
         self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
 
-    @patch('vllm_ascend.utils._ENABLE_NZ', True)
     @patch('torch_npu.npu_format_cast')
     def test_process_weights_after_loading(self, mock_format_cast):
         layer = MagicMock(spec=LinearBase)
diff --git a/tests/ut/models/test_qwen2_5_vl.py b/tests/ut/models/test_qwen2_5_vl.py
index b4f06803..7111aaed 100644
--- a/tests/ut/models/test_qwen2_5_vl.py
+++ b/tests/ut/models/test_qwen2_5_vl.py
@@ -1,5 +1,3 @@
-from unittest.mock import patch
-
 import pytest
 import torch
 import torch.nn.functional as F
@@ -367,7 +365,6 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
         res = attention.pad_qkv_bias(torch.rand((300)))
         assert res.shape[0] == 384
 
-    @patch('vllm_ascend.utils._ENABLE_NZ', True)
     def test_pad_qkv_weight(self, mocker: MockerFixture):
         attention = self.init_vision_transformer(mocker)
         mocker.patch("torch.nn.Module.__setattr__")
@@ -380,7 +377,6 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
         res = attention.pad_qkv_weight(torch.rand((300, 300)))
         assert res.shape == (384, 300)
 
-    @patch('vllm_ascend.utils._ENABLE_NZ', True)
     def test_pad_proj_weight(self, mocker: MockerFixture):
         attention = self.init_vision_transformer(mocker)
         mocker.patch("torch.nn.Module.__setattr__")
diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py
index 42c3c933..2116b0c1 100644
--- a/tests/ut/quantization/test_w4a8_dynamic.py
+++ b/tests/ut/quantization/test_w4a8_dynamic.py
@@ -260,7 +260,6 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
                                           requires_grad=False)
         return layer
 
-    @patch('vllm_ascend.utils._ENABLE_NZ', False)
     @patch('torch_npu.npu_format_cast')
     @patch('torch_npu.npu_quantize')
     @patch('torch.Tensor.npu')
diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py
index 8d34547b..147e8378 100644
--- a/tests/ut/test_utils.py
+++ b/tests/ut/test_utils.py
@@ -46,18 +46,12 @@ class TestUtils(TestBase):
         self.assertFalse(utils.is_310p())
 
     def test_is_enable_nz(self):
-        # Case when _ENABLE_NZ is already set
-        utils._ENABLE_NZ = True
-        self.assertTrue(utils.is_enable_nz())
-
-        utils._ENABLE_NZ = False
-        self.assertFalse(utils.is_enable_nz())
-
-        # Case when _ENABLE_NZ is None and vllm_config is not provided
-        utils._ENABLE_NZ = None
-        with self.assertRaises(ValueError) as context:
-            utils.is_enable_nz()
-        self.assertIn("vllm_config must be provided", str(context.exception))
+        with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
+                        1):
+            self.assertTrue(utils.is_enable_nz())
+        with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
+                        0):
+            self.assertFalse(utils.is_enable_nz())
 
     def test_sleep_mode_enabled(self):
         utils._SLEEP_MODE_ENABLED = None
diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py
index bd3192b5..48a4242a 100644
--- a/tests/ut/worker/test_worker_v1.py
+++ b/tests/ut/worker/test_worker_v1.py
@@ -281,9 +281,9 @@ class TestNPUWorker(TestBase):
 
         self.assertIn("Sleep mode is not enabled", str(cm.exception))
 
-    @patch('vllm_ascend.utils._ENABLE_NZ', False)
     @patch("vllm_ascend.worker.worker_v1.sleep_mode_enabled")
     @patch("vllm_ascend.worker.worker_v1.CaMemAllocator")
+    @patch.dict("os.environ", {"VLLM_ASCEND_ENABLE_NZ": "0"})
     def test_wake_up_mode_enabled(self, mock_allocator_class,
                                   mock_sleep_mode_enabled):
         """Test wake_up method when sleep mode is enabled"""
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index eab312d5..844cdcbd 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -45,7 +45,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
-        if (is_enable_nz() and layer.weight.data.dtype
+        if "conv1d" not in layer.prefix and (
+                is_enable_nz() and layer.weight.data.dtype
                 in [torch.float16, torch.bfloat16]):
             layer.weight.data = torch_npu.npu_format_cast(
                 layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index c0760c80..d6696304 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -222,6 +222,8 @@ packed_modules_model_mapping = {
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
         "in_proj": ["in_proj_qkvz", "in_proj_ba"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
"experts.0.up_proj", "experts.0.down_proj"] }, "qwen2_5_vl": { "qkv_proj": [ diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 38151080..7fd73826 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -59,7 +59,6 @@ _MIN_DP_BUFFER_SIZE = 50 _IS_MOE_MODEL = None _ENABLE_SP = None _HAS_LAYER_IDX = None -_ENABLE_NZ = None _SUBSCRIBED_COMPUTE_STREAMS = set() _GRAPH_PRINT_STREAM = None _GRAPH_PRINT_STREAM_LOCK = Lock() @@ -129,14 +128,8 @@ def is_310p(): return _IS_310P -def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool: - global _ENABLE_NZ - if _ENABLE_NZ is None: - if not vllm_config: - raise ValueError( - "vllm_config must be provided when _ENABLE_NZ is None") - _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next" - return _ENABLE_NZ +def is_enable_nz(): + return envs_ascend.VLLM_ASCEND_ENABLE_NZ def sleep_mode_enabled(): diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 58ac27a0..db50bceb 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -87,7 +87,6 @@ class NPUWorker(WorkerBase): # register patch for vllm from vllm_ascend.utils import adapt_patch adapt_patch() - is_enable_nz(vllm_config) # Register ops when worker init. from vllm_ascend import ops ops.register_dummy_fusion_op()