### What this PR does / why we need it?

**Problem Description:**
The existing implementation of the w4a8-dynamic linear method only supports the old quantization format from msmodelslim. When attempting to load models quantized with the new version, vLLM encounters errors due to mismatched tensor shapes and unprocessed quantization parameters.

Relevant issues:
- https://github.com/vllm-project/vllm-ascend/issues/3192
- https://github.com/vllm-project/vllm-ascend/issues/3152

**Proposed Changes:**
1. Add support for w4a8 dynamic (new format) in `AscendW4A8DynamicLinearMethod` and `TorchairAscendW4A8DynamicLinearMethod`.
2. Add unit tests and e2e tests for w4a8 dynamic new- and old-format models.

<details>
<summary><b>details</b></summary>

1. **Support for the new w4a8-dynamic format:**
    * Detects the quantization format by reading the "version" field in `quant_description`, ensuring backward compatibility.
    * Handles the new pre-packed weight format (`2x int4` in an `int8`), which has a halved dimension. It tells the vLLM loader how to unpack it using `_packed_dim` and `_packed_factor`.
    * Supports the new `scale_bias` parameter, setting its shape based on the layer type, as required by msmodelslim. For API consistency and future use, the `layer_type` parameter was also added to the other quantization methods.
    * Updates the weight processing logic: new-format weights are handled with `.view(torch.int32)` since they are pre-packed, while old ones are processed with `npu_convert_weight_to_int4pack`.
2. **New unit and E2E tests:**
    * Added unit tests that verify the logic for both the old and new formats.
    * Split the distributed E2E test to confirm that both old- and new-format models work correctly.

</details>

In principle, these changes should support all common new-version w4a8 (dynamic) models produced by msmodelslim.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?
I implemented the relevant unit and e2e tests and tested the changes with the following commands:

```bash
# unit tests
python -m pytest tests/ut/quantization/test_w4a8_dynamic.py tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py -v
# e2e tests
pytest tests/e2e/singlecard/test_quantization.py -v -s
pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version -v -s
pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version -v -s
pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC -v -s
```

I also tested Hunyuan-1.8B-Instruct quantized with the new w4a8-dynamic format:

```
vllm serve ./models/Hunyuan-1.8B-Instruct-quantized --gpu-memory-utilization 0.96 --quantization ascend --max-model-len 9600 --seed 0 --max-num-batched-tokens 16384
```

All of the tests mentioned above passed locally.

**NOTE: I use a quantized model from my own repo in test_offline_inference_distributed.py**. Here is its description: [Anionex/Qwen3-1.7B-W4A8-V1](https://modelscope.cn/models/Anionex/Qwen3-1.7B-W4A8-V1/summary) (including the quantization steps). This should be replaced by a model in the vllm-ascend CI ModelScope repo.

Thanks for reading!

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Anionex <1005128408@qq.com>
229 lines
6.8 KiB
Python
229 lines
6.8 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
|
|
#
|
|
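"""Distributed offline-inference smoke tests for vLLM on Ascend.

Every test starts a ``VllmRunner`` with ``tensor_parallel_size=2`` and
generates a few tokens from a short prompt; a test passes when generation
completes without raising (outputs are not compared against a reference).

Run `pytest tests/e2e/multicard/test_offline_inference_distributed.py`.
"""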
|
|
import os
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from modelscope import snapshot_download # type: ignore
|
|
from vllm import SamplingParams
|
|
|
|
from tests.e2e.conftest import VllmRunner
|
|
|
|
# Process-wide settings applied at import time, before any test builds a
# runner. Changes to os.environ are inherited by child processes started
# afterwards.
os.environ.update({
    "PYTORCH_NPU_ALLOC_CONF": "max_split_size_mb:256",
    "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
})

# Dense Qwen checkpoints quantized to W8A8; parametrizes the FlashComm-v1
# and MLP-prefetch tests.
QWEN_DENSE_MODELS = [
    "vllm-ascend/Qwen3-8B-W8A8",
    "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8",
]

# W4A8-dynamic Qwen3 checkpoints in the older msmodelslim weight format.
QWEN_W4A8_OLD_VERSION_MODELS = [
    "vllm-ascend/Qwen3-8B-W4A8",
]

# W4A8-dynamic Qwen3 checkpoints in the newer msmodelslim format
# (int4 weights pre-packed two-per-int8).
QWEN_W4A8_NEW_VERSION_MODELS = [
    "vllm-ascend/Qwen3-1.7B-W4A8-V1",
]

# Pruned W4A8 DeepSeek checkpoints. NOTE(review): "Pruing"/"puring" look
# like typos but are presumably the published repo IDs — confirm on
# ModelScope before "correcting" them.
DEEPSEEK_W4A8_MODELS = [
    "vllm-ascend/DeepSeek-V3-W4A8-Pruing",
    "vllm-ascend/DeepSeek-V3.1-W4A8-puring",
]
|
|
|
|
|
|
def test_models_distributed_QwQ():
    """Smoke-test QwQ-32B with two-way tensor parallelism on the mp backend.

    Loads the model in half precision with ``enforce_eager=False`` and
    generates five greedy tokens; the test passes if nothing raises.
    """
    with VllmRunner(
            "Qwen/QwQ-32B",
            dtype="half",
            tensor_parallel_size=2,
            distributed_executor_backend="mp",
            enforce_eager=False,
    ) as runner:
        runner.generate_greedy(["Hello, my name is"], 5)
|
|
|
|
|
|
def test_models_distributed_DeepSeek_multistream_moe():
    """Smoke-test the pruned DeepSeek-V3 model with multistream MoE enabled.

    Runs with two-way tensor parallelism on the mp backend in half
    precision; ``additional_config`` turns on the torchair graph config,
    multistream MoE and the Ascend scheduler. Five greedy tokens are
    generated and the test passes if nothing raises.
    """
    ascend_config = {
        "torchair_graph_config": {"enabled": True},
        "enable_multistream_moe": True,
        "ascend_scheduler_config": {"enabled": True},
        "refresh": True,
    }
    with VllmRunner(
            "vllm-ascend/DeepSeek-V3-Pruning",
            dtype="half",
            tensor_parallel_size=2,
            distributed_executor_backend="mp",
            additional_config=ascend_config,
    ) as runner:
        runner.generate_greedy(["Hello, my name is"], 5)
|
|
|
|
|
|
def test_models_distributed_Qwen3_W8A8():
    """Smoke-test the W8A8-quantized Qwen3-8B checkpoint with TP=2.

    The checkpoint is fetched with ``snapshot_download`` and loaded with
    ``quantization="ascend"``; five greedy tokens are generated.
    """
    checkpoint = snapshot_download("vllm-ascend/Qwen3-8B-W8A8")
    runner_kwargs = dict(
        max_model_len=8192,
        dtype="auto",
        tensor_parallel_size=2,
        quantization="ascend",
    )
    with VllmRunner(checkpoint, **runner_kwargs) as runner:
        runner.generate_greedy(["Hello, my name is"], 5)
|
|
|
|
|
|
def _run_qwen3_w4a8_dynamic(model):
    """Generate five greedy tokens from a W4A8-dynamic Qwen3 checkpoint.

    This is the shared body of the old- and new-format tests below. Both
    msmodelslim formats are therefore always exercised with identical
    runner settings: TP=2, ``max_model_len=8192``, ``dtype="auto"`` and
    ``quantization="ascend"``.

    Args:
        model: ModelScope repo id; downloaded with ``snapshot_download``.
    """
    with VllmRunner(
            snapshot_download(model),
            max_model_len=8192,
            dtype="auto",
            tensor_parallel_size=2,
            quantization="ascend",
    ) as vllm_model:
        vllm_model.generate_greedy(["Hello, my name is"], 5)


@pytest.mark.parametrize("model", QWEN_W4A8_OLD_VERSION_MODELS)
def test_models_distributed_Qwen3_W4A8DYNAMIC_old_version(model):
    """Load and run W4A8-dynamic checkpoints in the old msmodelslim format."""
    _run_qwen3_w4a8_dynamic(model)


@pytest.mark.parametrize("model", QWEN_W4A8_NEW_VERSION_MODELS)
def test_models_distributed_Qwen3_W4A8DYNAMIC_new_version(model):
    """Load and run W4A8-dynamic checkpoints in the new msmodelslim format.

    In the new format the int4 weights are pre-packed two per int8.
    """
    _run_qwen3_w4a8_dynamic(model)
|
|
|
|
|
|
@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
def test_models_distributed_DeepSeek_W4A8DYNAMIC(model):
    """Smoke-test W4A8-dynamic DeepSeek checkpoints with expert parallelism.

    ``VLLM_ASCEND_MLA_PA=1`` is patched into the environment only while the
    test runs. The model runs eagerly with TP=2 and expert parallelism
    enabled. The torchair graph config is disabled and the Ascend scheduler
    is enabled via ``additional_config``.
    """
    checkpoint = snapshot_download(model)
    ascend_config = {
        "torchair_graph_config": {"enabled": False},
        "ascend_scheduler_config": {"enabled": True},
    }
    with VllmRunner(
            checkpoint,
            dtype="auto",
            tensor_parallel_size=2,
            quantization="ascend",
            enforce_eager=True,
            enable_expert_parallel=True,
            additional_config=ascend_config,
    ) as runner:
        runner.generate_greedy(["Hello, my name is"], 5)
|
|
|
|
|
|
def test_sp_for_qwen3_moe() -> None:
    """Smoke-test sequence parallelism on Qwen3-30B-A3B with TP=2.

    Sequence parallelism is switched on through the compilation pass config,
    with expert parallelism and eager execution on the mp backend. Five
    tokens are generated at temperature 0.
    """
    params = SamplingParams(
        max_tokens=5,
        temperature=0.0,
        top_k=50,
        top_p=0.9,
    )
    pass_settings = {"enable_sequence_parallelism": True}
    with VllmRunner(
            snapshot_download("Qwen/Qwen3-30B-A3B"),
            dtype="auto",
            tensor_parallel_size=2,
            distributed_executor_backend="mp",
            compilation_config={"pass_config": pass_settings},
            enable_expert_parallel=True,
            enforce_eager=True,
    ) as runner:
        runner.generate(["Hello, my name is"], params)
|
|
|
|
|
|
def _run_qwen_dense_quantized(model):
    """Generate five greedy tokens from a quantized dense Qwen checkpoint.

    This is the shared body of the FlashComm-v1 and MLP-prefetch tests
    below, which differ only in the environment flags their ``patch.dict``
    decorators set. Runs with TP=2, ``enforce_eager=False``,
    ``max_model_len=8192``, ``dtype="auto"`` and ``quantization="ascend"``.

    Args:
        model: ModelScope repo id; downloaded with ``snapshot_download``.
    """
    with VllmRunner(
            snapshot_download(model),
            max_model_len=8192,
            enforce_eager=False,
            dtype="auto",
            tensor_parallel_size=2,
            quantization="ascend",
    ) as vllm_model:
        vllm_model.generate_greedy(["Hello, my name is"], 5)


@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model):
    """Run dense quantized Qwen models with the FlashComm-v1 flags set.

    Both VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and VLLM_ASCEND_ENABLE_FLASHCOMM1
    are patched to "1" for the duration of the test.
    """
    _run_qwen_dense_quantized(model)


@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"})
def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight(model):
    """Run dense quantized Qwen models with the MLP-prefetch flags set.

    Both VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and
    VLLM_ASCEND_ENABLE_PREFETCH_MLP are patched to "1" for the duration of
    the test.
    """
    _run_qwen_dense_quantized(model)
|