[Feat][BugFix]Support the Qwen3-Next-80B-A3B-Instruct quantization model&Fix the NZ issue (#4245)
### What this PR does / why we need it?
Support the Qwen3-Next-80B-A3B-Instruct quantization model and Fix the
NZ issue. Triton kernel doesn't support data format nz, thus we skip
converting weight to nz on layer `conv1d`
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: IncSec <1790766300@qq.com>
This commit is contained in:
@@ -20,6 +20,12 @@
|
|||||||
|
|
||||||
Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
|
Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from modelscope import snapshot_download # type: ignore
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
|
|
||||||
|
|
||||||
@@ -106,3 +112,23 @@ def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
|
|||||||
print(f"spec_output: {spec_output[1]}")
|
print(f"spec_output: {spec_output[1]}")
|
||||||
|
|
||||||
assert matches > int(0.66 * len(ref_outputs))
|
assert matches > int(0.66 * len(ref_outputs))
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: will conduct accuracy verification after the subsequent version becomes stable
|
||||||
|
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
|
||||||
|
def test_models_distributed_Qwen3_NEXT_W8A8DYNAMIC_WITH_EP():
|
||||||
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
]
|
||||||
|
max_tokens = 5
|
||||||
|
with VllmRunner(
|
||||||
|
snapshot_download(
|
||||||
|
"vllm-ascend/Qwen3-Next-80B-A3B-Instruct-W8A8-Pruning"),
|
||||||
|
max_model_len=4096,
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
gpu_memory_utilization=0.4,
|
||||||
|
max_num_seqs=1,
|
||||||
|
enable_expert_parallel=True,
|
||||||
|
quantization="ascend",
|
||||||
|
) as vllm_model:
|
||||||
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
|||||||
@@ -797,7 +797,6 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
|
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
|
||||||
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
|
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
|
||||||
|
|
||||||
@patch('vllm_ascend.utils._ENABLE_NZ', True)
|
|
||||||
@patch('torch_npu.npu_format_cast')
|
@patch('torch_npu.npu_format_cast')
|
||||||
def test_process_weights_after_loading(self, mock_format_cast):
|
def test_process_weights_after_loading(self, mock_format_cast):
|
||||||
layer = MagicMock(spec=LinearBase)
|
layer = MagicMock(spec=LinearBase)
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -367,7 +365,6 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
|
|||||||
res = attention.pad_qkv_bias(torch.rand((300)))
|
res = attention.pad_qkv_bias(torch.rand((300)))
|
||||||
assert res.shape[0] == 384
|
assert res.shape[0] == 384
|
||||||
|
|
||||||
@patch('vllm_ascend.utils._ENABLE_NZ', True)
|
|
||||||
def test_pad_qkv_weight(self, mocker: MockerFixture):
|
def test_pad_qkv_weight(self, mocker: MockerFixture):
|
||||||
attention = self.init_vision_transformer(mocker)
|
attention = self.init_vision_transformer(mocker)
|
||||||
mocker.patch("torch.nn.Module.__setattr__")
|
mocker.patch("torch.nn.Module.__setattr__")
|
||||||
@@ -380,7 +377,6 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
|
|||||||
res = attention.pad_qkv_weight(torch.rand((300, 300)))
|
res = attention.pad_qkv_weight(torch.rand((300, 300)))
|
||||||
assert res.shape == (384, 300)
|
assert res.shape == (384, 300)
|
||||||
|
|
||||||
@patch('vllm_ascend.utils._ENABLE_NZ', True)
|
|
||||||
def test_pad_proj_weight(self, mocker: MockerFixture):
|
def test_pad_proj_weight(self, mocker: MockerFixture):
|
||||||
attention = self.init_vision_transformer(mocker)
|
attention = self.init_vision_transformer(mocker)
|
||||||
mocker.patch("torch.nn.Module.__setattr__")
|
mocker.patch("torch.nn.Module.__setattr__")
|
||||||
|
|||||||
@@ -260,7 +260,6 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
@patch('vllm_ascend.utils._ENABLE_NZ', False)
|
|
||||||
@patch('torch_npu.npu_format_cast')
|
@patch('torch_npu.npu_format_cast')
|
||||||
@patch('torch_npu.npu_quantize')
|
@patch('torch_npu.npu_quantize')
|
||||||
@patch('torch.Tensor.npu')
|
@patch('torch.Tensor.npu')
|
||||||
|
|||||||
@@ -46,19 +46,13 @@ class TestUtils(TestBase):
|
|||||||
self.assertFalse(utils.is_310p())
|
self.assertFalse(utils.is_310p())
|
||||||
|
|
||||||
def test_is_enable_nz(self):
|
def test_is_enable_nz(self):
|
||||||
# Case when _ENABLE_NZ is already set
|
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
|
||||||
utils._ENABLE_NZ = True
|
1):
|
||||||
self.assertTrue(utils.is_enable_nz())
|
self.assertTrue(utils.is_enable_nz())
|
||||||
|
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
|
||||||
utils._ENABLE_NZ = False
|
0):
|
||||||
self.assertFalse(utils.is_enable_nz())
|
self.assertFalse(utils.is_enable_nz())
|
||||||
|
|
||||||
# Case when _ENABLE_NZ is None and vllm_config is not provided
|
|
||||||
utils._ENABLE_NZ = None
|
|
||||||
with self.assertRaises(ValueError) as context:
|
|
||||||
utils.is_enable_nz()
|
|
||||||
self.assertIn("vllm_config must be provided", str(context.exception))
|
|
||||||
|
|
||||||
def test_sleep_mode_enabled(self):
|
def test_sleep_mode_enabled(self):
|
||||||
utils._SLEEP_MODE_ENABLED = None
|
utils._SLEEP_MODE_ENABLED = None
|
||||||
with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",
|
with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",
|
||||||
|
|||||||
@@ -281,9 +281,9 @@ class TestNPUWorker(TestBase):
|
|||||||
|
|
||||||
self.assertIn("Sleep mode is not enabled", str(cm.exception))
|
self.assertIn("Sleep mode is not enabled", str(cm.exception))
|
||||||
|
|
||||||
@patch('vllm_ascend.utils._ENABLE_NZ', False)
|
|
||||||
@patch("vllm_ascend.worker.worker_v1.sleep_mode_enabled")
|
@patch("vllm_ascend.worker.worker_v1.sleep_mode_enabled")
|
||||||
@patch("vllm_ascend.worker.worker_v1.CaMemAllocator")
|
@patch("vllm_ascend.worker.worker_v1.CaMemAllocator")
|
||||||
|
@patch.dict("os.environ", {"VLLM_ASCEND_ENABLE_NZ": "0"})
|
||||||
def test_wake_up_mode_enabled(self, mock_allocator_class,
|
def test_wake_up_mode_enabled(self, mock_allocator_class,
|
||||||
mock_sleep_mode_enabled):
|
mock_sleep_mode_enabled):
|
||||||
"""Test wake_up method when sleep mode is enabled"""
|
"""Test wake_up method when sleep mode is enabled"""
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
|
|||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
super().process_weights_after_loading(layer)
|
super().process_weights_after_loading(layer)
|
||||||
if (is_enable_nz() and layer.weight.data.dtype
|
if "conv1d" not in layer.prefix and (
|
||||||
|
is_enable_nz() and layer.weight.data.dtype
|
||||||
in [torch.float16, torch.bfloat16]):
|
in [torch.float16, torch.bfloat16]):
|
||||||
layer.weight.data = torch_npu.npu_format_cast(
|
layer.weight.data = torch_npu.npu_format_cast(
|
||||||
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
|||||||
@@ -222,6 +222,8 @@ packed_modules_model_mapping = {
|
|||||||
],
|
],
|
||||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
"in_proj": ["in_proj_qkvz", "in_proj_ba"],
|
"in_proj": ["in_proj_qkvz", "in_proj_ba"],
|
||||||
|
"experts":
|
||||||
|
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||||
},
|
},
|
||||||
"qwen2_5_vl": {
|
"qwen2_5_vl": {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
|
|||||||
@@ -59,7 +59,6 @@ _MIN_DP_BUFFER_SIZE = 50
|
|||||||
_IS_MOE_MODEL = None
|
_IS_MOE_MODEL = None
|
||||||
_ENABLE_SP = None
|
_ENABLE_SP = None
|
||||||
_HAS_LAYER_IDX = None
|
_HAS_LAYER_IDX = None
|
||||||
_ENABLE_NZ = None
|
|
||||||
_SUBSCRIBED_COMPUTE_STREAMS = set()
|
_SUBSCRIBED_COMPUTE_STREAMS = set()
|
||||||
_GRAPH_PRINT_STREAM = None
|
_GRAPH_PRINT_STREAM = None
|
||||||
_GRAPH_PRINT_STREAM_LOCK = Lock()
|
_GRAPH_PRINT_STREAM_LOCK = Lock()
|
||||||
@@ -129,14 +128,8 @@ def is_310p():
|
|||||||
return _IS_310P
|
return _IS_310P
|
||||||
|
|
||||||
|
|
||||||
def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool:
|
def is_enable_nz():
|
||||||
global _ENABLE_NZ
|
return envs_ascend.VLLM_ASCEND_ENABLE_NZ
|
||||||
if _ENABLE_NZ is None:
|
|
||||||
if not vllm_config:
|
|
||||||
raise ValueError(
|
|
||||||
"vllm_config must be provided when _ENABLE_NZ is None")
|
|
||||||
_ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next"
|
|
||||||
return _ENABLE_NZ
|
|
||||||
|
|
||||||
|
|
||||||
def sleep_mode_enabled():
|
def sleep_mode_enabled():
|
||||||
|
|||||||
@@ -87,7 +87,6 @@ class NPUWorker(WorkerBase):
|
|||||||
# register patch for vllm
|
# register patch for vllm
|
||||||
from vllm_ascend.utils import adapt_patch
|
from vllm_ascend.utils import adapt_patch
|
||||||
adapt_patch()
|
adapt_patch()
|
||||||
is_enable_nz(vllm_config)
|
|
||||||
# Register ops when worker init.
|
# Register ops when worker init.
|
||||||
from vllm_ascend import ops
|
from vllm_ascend import ops
|
||||||
ops.register_dummy_fusion_op()
|
ops.register_dummy_fusion_op()
|
||||||
|
|||||||
Reference in New Issue
Block a user