#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
RowParallelLinear)
from vllm.model_executor.layers.quantization import \
register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
oproj_tp_enable)
from .utils import get_quant_method
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
class AscendQuantConfig(QuantizationConfig):
    """Config class for Ascend quantization.

    This class is a general class that parses quantization configs
    that are supported on Ascend hardware.
    """

    def __init__(self, quant_config: Dict[str, Any]):
        super().__init__()
        # Raw per-layer quantization description loaded from
        # quant_model_description.json (layer prefix -> dtype tag).
        self.quant_description = quant_config
        # TODO(whx): remove this adaptation after adding "shared_head"
        # to prefix of DeepSeekShareHead in vLLM.
        # Mirror every key containing ".shared_head." under the same key with
        # that segment removed, so lookups by either prefix succeed.
        extra_quant_dict = {}
        for k in self.quant_description.keys():
            if "shared_head" in k:
                new_k = k.replace(".shared_head.", ".")
                extra_quant_dict[new_k] = self.quant_description[k]
        self.quant_description.update(extra_quant_dict)

    def __repr__(self) -> str:
        return "AscendQuantConfig:\n" + super().__repr__()

    @classmethod
    def get_name(cls) -> str:
        return ASCEND_QUANTIZATION_METHOD

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.int8, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # Compute capability is a CUDA concept; it has no meaning on Ascend.
        raise NotImplementedError(
            "Ascend hardware does not support \"get_min_capability\" feature.")

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quant_model_description.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
        return cls(config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        # Always take the Ascend quantization path when an NPU is present,
        # regardless of what the HF config or the user requested.
        if torch.npu.is_available():
            return ASCEND_QUANTIZATION_METHOD
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or ``None`` if the layer
        type is not handled here.

        Layers marked FLOAT in the quant description get the corresponding
        unquantized Ascend method; otherwise a quantized method matching the
        layer type (linear / attention KV cache / fused MoE / embedding) is
        returned.
        """
        vllm_config = get_current_vllm_config()
        model_type = vllm_config.model_config.hf_config.model_type
        if model_type in packed_modules_model_mapping:
            self.packed_modules_mapping = packed_modules_model_mapping[
                model_type]
        from vllm.attention.layer import Attention
        # Multi-modal wrappers prefix the text backbone's modules with
        # "language_model."; strip it so quant-description lookups match.
        if prefix.startswith("language_model"):
            prefix = prefix.split('.', 1)[-1]
        if isinstance(layer, LinearBase):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return AscendUnquantizedLinearMethod()
            return AscendLinearMethod(self, prefix,
                                      self.packed_modules_mapping)
        elif isinstance(layer, Attention) and \
            'fa_quant_type' in self.quant_description.keys() and \
            self.quant_description['fa_quant_type'] is not None:
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, Attention) and self.quant_description.get(
                'kv_quant_type') == 'C8':
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, FusedMoE):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return AscendUnquantizedFusedMoEMethod(layer.moe_config)
            return AscendFusedMoEMethod(self, prefix,
                                        self.packed_modules_mapping)
        elif isinstance(layer, VocabParallelEmbedding):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return UnquantizedEmbeddingMethod()
            return AscendEmbeddingMethod(self, prefix,
                                         self.packed_modules_mapping)
        return None

    def is_layer_skipped_ascend(
            self,
            prefix: str,
            fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
    ) -> bool:
        """Return True if the layer at ``prefix`` is unquantized (FLOAT).

        For fused modules (e.g. qkv_proj) every constituent shard must agree;
        a mix of quantized and FLOAT shards is an error.
        """
        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
        proj_name = prefix.split(".")[-1]
        if proj_name in fused_mapping:
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in fused_mapping[proj_name]
            ]
            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = self.quant_description[shard_prefix +
                                                          '.weight'] == "FLOAT"
                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "must have the same precision.")
        elif "experts" in prefix:
            # For the experts' prefix (e.g., "model.layers.3.mlp.experts")
            # Assume all experts within the same MLP use the same quantization method
            experts_quant_description = [
                self.quant_description[layer]
                for layer in self.quant_description if prefix in layer
            ]
            is_skipped = any(quantization == "FLOAT"
                             for quantization in experts_quant_description)
        else:
            is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"

        assert is_skipped is not None
        return is_skipped

    def get_scaled_act_names(self) -> List[str]:
        return []
# Shard lists shared by several model entries. Each entry below gets its own
# list copy so per-model mappings stay independently mutable.
_QKV_SHARDS = ["q_proj", "k_proj", "v_proj"]
_GATE_UP_SHARDS = ["gate_proj", "up_proj"]
_EXPERT_SHARDS = [
    "experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"
]
_MLA_QKV_A_SHARDS = ["q_a_proj", "kv_a_proj_with_mqa"]

# Maps HF model_type -> packed-modules mapping (fused module name -> the
# per-shard checkpoint weights that are merged into it).
packed_modules_model_mapping = {
    "qwen3_moe": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "deepseek_v2": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_MLA_QKV_A_SHARDS),
    },
    "deepseek_v3": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
        "fused_qkv_a_proj": list(_MLA_QKV_A_SHARDS),
    },
    "kimi_k2": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "deepseek_v32": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
    # NOTE 2.The description file generated by the current msmodelslim tool does not have
    # MTP layer info. Please manually add it and set the value to FLOAT.
    "deepseek_mtp": {
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
    "qwen3_next": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
    },
    "qwen2_5_vl": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
    },
    "glm4_moe": {
        "qkv_proj": list(_QKV_SHARDS),
        "gate_up_proj": list(_GATE_UP_SHARDS),
        "experts": list(_EXPERT_SHARDS),
    },
}
class AscendLinearMethod(LinearMethodBase):
"""Linear method for Ascend quantization.
Args:
quant_config: The Ascend quantization config.
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
packed_modules_mapping: Dict[str, Any]) -> None:
self.quant_method = get_quant_method(quant_config.quant_description,
prefix, "linear",
packed_modules_mapping)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
weight_dict = self.quant_method.get_weight(input_size_per_partition,
output_size_per_partition,
params_dtype)
[Feat][quantization] Support new version w4a8 dynamic quantization for Linear layers (#3311) ### What this PR does / why we need it? **Problem Description:** The existing implementation for the w4a8-dynamic linear method only supports the old quantization format from msmodelslim. When attempting to load models quantized with the new version, vLLM encounters errors due to mismatched tensor shapes and unprocessed quantization parameters. Relavant issues: - https://github.com/vllm-project/vllm-ascend/issues/3192 - https://github.com/vllm-project/vllm-ascend/issues/3152 **Proposed Changes:** 1. Add support for w4a8 dynamic(new format) in AscendW4A8DynamicLinearMethod and TorchairAscendW4A8DynamicLinearMethod 2. Add unit tests and e2e tests for w4a8 dynamic new and old format models <details> <summary><b>details</b></summary> 1. **Support for new w4a8-dynamic format:** * Detects quantization format by reading the "version" field in quant_description to ensure backward compatibility. * Handles the new pre-packed weight format (`2x int4` in an `int8`), which has a halved dimension. It tells the vLLM loader how to unpack it using `_packed_dim` and `_packed_factor`. * Supports the new `scale_bias` parameter, setting its shape based on the layer type, as required by msmodelslim. For api consistency and future use, the `layer_type` parameter was also added to other quantization methods. * Updates the weight processing logic: new format weights are handled with `.view(torch.int32)` since they're pre-packed, while old ones are processed with `npu_convert_weight_to_int4pack`. 2. **New unit and E2E tests:** * Added unit tests that verify the logic for both the old and new formats. * Split the distributed E2E test to confirm that both old and new format models work correctly. </details> Theoretically, these changes will provide support for all common new version w4a8(dynamic) models from msmodelslim. ### Does this PR introduce _any_ user-facing change? 
no ### How was this patch tested? I implement relevant unit tests and e2e tests and test the changes with following commands: ```bash # unit tests python -m pytest tests/ut/quantization/test_w4a8_dynamic.py tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py -v # e2e tests pytest tests/e2e/singlecard/test_quantization.py -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC -v -s ``` I also tested Hunyuan-1.8B-Instruct quantized with the new w4a8-dynamic format: ``` vllm serve ./models/Hunyuan-1.8B-Instruct-quantized --gpu-memory-utilization 0.96 --quantization ascend --max-model-len 9600 --seed 0 --max-num-batched-tokens 16384 ``` All tests mentioned passed locally. **NOTE: I use quantization model from my own repo in test_offline_inference_distributed.py**. Here is the description: [Anionex/Qwen3-1.7B-W4A8-V1](https://modelscope.cn/models/Anionex/Qwen3-1.7B-W4A8-V1/summary) (including quantization steps).This should be replaced by a model in vllm-ascend ci modelscope repo. Thanks for reading! - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Anionex <1005128408@qq.com>
2025-10-21 20:18:39 +08:00
# Extract packing information (if present)
packed_dim = weight_dict.pop("_packed_dim", None)
packed_factor = weight_dict.pop("_packed_factor", None)
for weight_name, weight_param in weight_dict.items():
param = torch.nn.Parameter(weight_param, requires_grad=False)
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
[Feat][quantization] Support new version w4a8 dynamic quantization for Linear layers (#3311) ### What this PR does / why we need it? **Problem Description:** The existing implementation for the w4a8-dynamic linear method only supports the old quantization format from msmodelslim. When attempting to load models quantized with the new version, vLLM encounters errors due to mismatched tensor shapes and unprocessed quantization parameters. Relavant issues: - https://github.com/vllm-project/vllm-ascend/issues/3192 - https://github.com/vllm-project/vllm-ascend/issues/3152 **Proposed Changes:** 1. Add support for w4a8 dynamic(new format) in AscendW4A8DynamicLinearMethod and TorchairAscendW4A8DynamicLinearMethod 2. Add unit tests and e2e tests for w4a8 dynamic new and old format models <details> <summary><b>details</b></summary> 1. **Support for new w4a8-dynamic format:** * Detects quantization format by reading the "version" field in quant_description to ensure backward compatibility. * Handles the new pre-packed weight format (`2x int4` in an `int8`), which has a halved dimension. It tells the vLLM loader how to unpack it using `_packed_dim` and `_packed_factor`. * Supports the new `scale_bias` parameter, setting its shape based on the layer type, as required by msmodelslim. For api consistency and future use, the `layer_type` parameter was also added to other quantization methods. * Updates the weight processing logic: new format weights are handled with `.view(torch.int32)` since they're pre-packed, while old ones are processed with `npu_convert_weight_to_int4pack`. 2. **New unit and E2E tests:** * Added unit tests that verify the logic for both the old and new formats. * Split the distributed E2E test to confirm that both old and new format models work correctly. </details> Theoretically, these changes will provide support for all common new version w4a8(dynamic) models from msmodelslim. ### Does this PR introduce _any_ user-facing change? 
no ### How was this patch tested? I implement relevant unit tests and e2e tests and test the changes with following commands: ```bash # unit tests python -m pytest tests/ut/quantization/test_w4a8_dynamic.py tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py -v # e2e tests pytest tests/e2e/singlecard/test_quantization.py -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC -v -s ``` I also tested Hunyuan-1.8B-Instruct quantized with the new w4a8-dynamic format: ``` vllm serve ./models/Hunyuan-1.8B-Instruct-quantized --gpu-memory-utilization 0.96 --quantization ascend --max-model-len 9600 --seed 0 --max-num-batched-tokens 16384 ``` All tests mentioned passed locally. **NOTE: I use quantization model from my own repo in test_offline_inference_distributed.py**. Here is the description: [Anionex/Qwen3-1.7B-W4A8-V1](https://modelscope.cn/models/Anionex/Qwen3-1.7B-W4A8-V1/summary) (including quantization steps).This should be replaced by a model in vllm-ascend ci modelscope repo. Thanks for reading! - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Anionex <1005128408@qq.com>
2025-10-21 20:18:39 +08:00
# Set packing attributes if the weight is packed
if packed_dim is not None and packed_factor is not None:
set_weight_attrs(param, {
"packed_dim": packed_dim,
"packed_factor": packed_factor
})
layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(data=pertensor_param,
weight_loader=weight_loader)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)
param.weight_loader = extra_weight_attrs.get("weight_loader")
perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype)
for perchannel_name, perchannel_param in perchannel_dict.items():
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)
[Feat][quantization] Support new version w4a8 dynamic quantization for Linear layers (#3311) ### What this PR does / why we need it? **Problem Description:** The existing implementation for the w4a8-dynamic linear method only supports the old quantization format from msmodelslim. When attempting to load models quantized with the new version, vLLM encounters errors due to mismatched tensor shapes and unprocessed quantization parameters. Relavant issues: - https://github.com/vllm-project/vllm-ascend/issues/3192 - https://github.com/vllm-project/vllm-ascend/issues/3152 **Proposed Changes:** 1. Add support for w4a8 dynamic(new format) in AscendW4A8DynamicLinearMethod and TorchairAscendW4A8DynamicLinearMethod 2. Add unit tests and e2e tests for w4a8 dynamic new and old format models <details> <summary><b>details</b></summary> 1. **Support for new w4a8-dynamic format:** * Detects quantization format by reading the "version" field in quant_description to ensure backward compatibility. * Handles the new pre-packed weight format (`2x int4` in an `int8`), which has a halved dimension. It tells the vLLM loader how to unpack it using `_packed_dim` and `_packed_factor`. * Supports the new `scale_bias` parameter, setting its shape based on the layer type, as required by msmodelslim. For api consistency and future use, the `layer_type` parameter was also added to other quantization methods. * Updates the weight processing logic: new format weights are handled with `.view(torch.int32)` since they're pre-packed, while old ones are processed with `npu_convert_weight_to_int4pack`. 2. **New unit and E2E tests:** * Added unit tests that verify the logic for both the old and new formats. * Split the distributed E2E test to confirm that both old and new format models work correctly. </details> Theoretically, these changes will provide support for all common new version w4a8(dynamic) models from msmodelslim. ### Does this PR introduce _any_ user-facing change? 
no ### How was this patch tested? I implement relevant unit tests and e2e tests and test the changes with following commands: ```bash # unit tests python -m pytest tests/ut/quantization/test_w4a8_dynamic.py tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py -v # e2e tests pytest tests/e2e/singlecard/test_quantization.py -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC -v -s ``` I also tested Hunyuan-1.8B-Instruct quantized with the new w4a8-dynamic format: ``` vllm serve ./models/Hunyuan-1.8B-Instruct-quantized --gpu-memory-utilization 0.96 --quantization ascend --max-model-len 9600 --seed 0 --max-num-batched-tokens 16384 ``` All tests mentioned passed locally. **NOTE: I use quantization model from my own repo in test_offline_inference_distributed.py**. Here is the description: [Anionex/Qwen3-1.7B-W4A8-V1](https://modelscope.cn/models/Anionex/Qwen3-1.7B-W4A8-V1/summary) (including quantization steps).This should be replaced by a model in vllm-ascend ci modelscope repo. Thanks for reading! - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Anionex <1005128408@qq.com>
2025-10-21 20:18:39 +08:00
# NOTE: In w4a8 quantization implementation,
# for down_proj and o_proj scale_bias shape is [output_size, 16],
# others are [output_size, 1]
layer_type = "row" if isinstance(layer,
RowParallelLinear) else "others"
pergroup_dict = self.quant_method.get_pergroup_param(
[Feat][quantization] Support new version w4a8 dynamic quantization for Linear layers (#3311) ### What this PR does / why we need it? **Problem Description:** The existing implementation for the w4a8-dynamic linear method only supports the old quantization format from msmodelslim. When attempting to load models quantized with the new version, vLLM encounters errors due to mismatched tensor shapes and unprocessed quantization parameters. Relavant issues: - https://github.com/vllm-project/vllm-ascend/issues/3192 - https://github.com/vllm-project/vllm-ascend/issues/3152 **Proposed Changes:** 1. Add support for w4a8 dynamic(new format) in AscendW4A8DynamicLinearMethod and TorchairAscendW4A8DynamicLinearMethod 2. Add unit tests and e2e tests for w4a8 dynamic new and old format models <details> <summary><b>details</b></summary> 1. **Support for new w4a8-dynamic format:** * Detects quantization format by reading the "version" field in quant_description to ensure backward compatibility. * Handles the new pre-packed weight format (`2x int4` in an `int8`), which has a halved dimension. It tells the vLLM loader how to unpack it using `_packed_dim` and `_packed_factor`. * Supports the new `scale_bias` parameter, setting its shape based on the layer type, as required by msmodelslim. For api consistency and future use, the `layer_type` parameter was also added to other quantization methods. * Updates the weight processing logic: new format weights are handled with `.view(torch.int32)` since they're pre-packed, while old ones are processed with `npu_convert_weight_to_int4pack`. 2. **New unit and E2E tests:** * Added unit tests that verify the logic for both the old and new formats. * Split the distributed E2E test to confirm that both old and new format models work correctly. </details> Theoretically, these changes will provide support for all common new version w4a8(dynamic) models from msmodelslim. ### Does this PR introduce _any_ user-facing change? 
no ### How was this patch tested? I implement relevant unit tests and e2e tests and test the changes with following commands: ```bash # unit tests python -m pytest tests/ut/quantization/test_w4a8_dynamic.py tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py -v # e2e tests pytest tests/e2e/singlecard/test_quantization.py -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC -v -s ``` I also tested Hunyuan-1.8B-Instruct quantized with the new w4a8-dynamic format: ``` vllm serve ./models/Hunyuan-1.8B-Instruct-quantized --gpu-memory-utilization 0.96 --quantization ascend --max-model-len 9600 --seed 0 --max-num-batched-tokens 16384 ``` All tests mentioned passed locally. **NOTE: I use quantization model from my own repo in test_offline_inference_distributed.py**. Here is the description: [Anionex/Qwen3-1.7B-W4A8-V1](https://modelscope.cn/models/Anionex/Qwen3-1.7B-W4A8-V1/summary) (including quantization steps).This should be replaced by a model in vllm-ascend ci modelscope repo. Thanks for reading! - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Anionex <1005128408@qq.com>
2025-10-21 20:18:39 +08:00
input_size_per_partition,
output_size_per_partition,
params_dtype,
layer_type=layer_type)
for pergroup_name, pergroup_param in pergroup_dict.items():
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(pergroup_name, param)
set_weight_attrs(param, extra_weight_attrs)
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
setattr(param, "input_dim", 1)
param.input_dim = 1
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the quantized linear forward through the wrapped quant method.

    The tensor-parallel rank forwarded to the quant method selects which
    output shard (and bias shard) this rank is responsible for.

    Args:
        layer: The linear layer being applied.
        x: Input activations.
        bias: Optional bias tensor added by the quant method.

    Returns:
        The layer output produced by the underlying quant implementation.
    """
    # NOTE: removed git-blame/commit-message text that had been pasted into
    # the middle of this method by a bad merge/extraction.
    if isinstance(layer, RowParallelLinear):
        if "o_proj" in layer.prefix and oproj_tp_enable():
            # o_proj can be sharded by its own optional TP group.
            tp_rank = get_otp_group().rank_in_group
        elif "down_proj" in layer.prefix and mlp_tp_enable():
            # down_proj can likewise use a dedicated MLP TP group.
            tp_rank = get_mlp_tp_group().rank_in_group
        else:
            tp_rank = get_tensor_model_parallel_rank()
    else:
        # Column-parallel / replicated layers do not need a row-shard rank.
        tp_rank = 0
    return self.quant_method.apply(layer, x, bias, tp_rank)
class AscendKVCacheMethod(BaseKVCacheMethod):
    """KVCache method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
        # Resolve the concrete attention quant implementation for this
        # layer prefix from the model's quantization description.
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "attention")

    def create_weights(self, layer: torch.nn.Module) -> None:
        # Unlike linear layers, vllm performs no weight processing/slicing
        # for attention, so weight creation is delegated wholesale to the
        # specific quant method.
        self.quant_method.create_weights(layer)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Invoke the optional post-load hook when the quant method has one.
        post_hook = getattr(self.quant_method,
                            "process_weights_after_loading", None)
        if post_hook is not None:
            post_hook(layer)

    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        # Pure pass-through: the quant method owns the attention compute.
        return self.quant_method.apply(layer, query, key, value, kv_cache,
                                       attn_metadata, attn_type, scale,
                                       output)
class AscendFusedMoEMethod(FusedMoEMethodBase):
    """FusedMoE method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    # NOTE: this class body previously had git-blame/commit-message text
    # interleaved with the code; that residue has been removed.

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]):
        # Resolve the concrete MoE quant implementation from the model's
        # quantization description.
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "moe",
                                             packed_modules_mapping)

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register expert weights and their quantization parameters.

        Plain weights are registered first with the caller-supplied
        attributes; dynamic-quant parameters are then registered with a
        per-channel quant_method attribute, upgraded to per-group for the
        parameters that are grouped (second-stage scales/offsets and the
        scale bias used by w4a8 weights).
        """
        weight_param = self.quant_method.get_weight(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in weight_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Dynamic-quant params default to per-channel scaling.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        # These parameters are grouped along the input dimension and must be
        # marked as GROUP-scaled instead.
        per_group_param = [
            "weight_scale_second", "weight_offset_second", "scale_bias"
        ]
        dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in dynamic_quant_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)
            if any(fields in param_key for fields in per_group_param):
                setattr(param, "quant_method",
                        FusedMoeWeightScaleSupported.GROUP.value)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num=0,
        **kwargs,
    ) -> torch.Tensor:
        """Run the fused-MoE forward; all routing/compute is delegated to
        the wrapped quant method."""
        return self.quant_method.apply(
            layer, x, router_logits, top_k, renormalize, use_grouped_topk,
            global_num_experts, expert_map, topk_group, num_expert_group,
            custom_routing_function, scoring_func, e_score_correction_bias,
            is_prefill, enable_force_load_balance, log2phy,
            global_redundant_expert_num, **kwargs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Invoke the optional post-load hook when the quant method has one.
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def get_fused_moe_quant_config(self, layer: torch.nn.Module):
        # TODO: implement this function
        pass
class AscendEmbeddingMethod(AscendLinearMethod):
    """Embedding method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]) -> None:
        # Embeddings reuse the linear quant implementations, so the quant
        # method is resolved through the "linear" registry path rather than
        # a dedicated embedding one.
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "linear",
                                             packed_modules_mapping)