#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
|
|
|
|
|
from types import MappingProxyType
|
2025-04-07 10:56:12 +08:00
|
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional
|
2025-02-21 17:07:37 +08:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from vllm.distributed import get_tensor_model_parallel_rank
|
2025-04-07 10:56:12 +08:00
|
|
|
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
|
|
|
|
FusedMoeWeightScaleSupported)
|
2025-02-21 17:07:37 +08:00
|
|
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
|
|
|
|
RowParallelLinear,
|
|
|
|
|
UnquantizedLinearMethod)
|
|
|
|
|
from vllm.model_executor.layers.quantization import \
|
|
|
|
|
register_quantization_config
|
|
|
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
|
|
|
QuantizationConfig, QuantizeMethodBase)
|
|
|
|
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
2025-07-29 18:51:57 +08:00
|
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
|
|
|
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
|
2025-04-19 17:38:18 +08:00
|
|
|
from vllm.model_executor.parameter import PerTensorScaleParameter
|
2025-04-07 10:56:12 +08:00
|
|
|
from vllm.model_executor.utils import set_weight_attrs
|
2025-02-21 17:07:37 +08:00
|
|
|
|
2025-04-21 19:25:51 +08:00
|
|
|
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
|
2025-05-17 17:36:04 +08:00
|
|
|
from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD
|
2025-04-21 19:25:51 +08:00
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
from .quantizer import AscendQuantizer
|
|
|
|
|
|
|
|
|
|
|
2025-05-17 17:36:04 +08:00
|
|
|
@register_quantization_config(ASCEND_QUATIZATION_METHOD)
class AscendQuantConfig(QuantizationConfig):
    """Config class for Ascend.

    This class is a general class that parses quantization configs
    that are supported on Ascend hardware.

    Args:
        quant_config: The quantization description. The code in this class
            reads the keys ``"<layer prefix>.weight"`` (value ``"FLOAT"``
            marks a layer that is not quantized), ``"fa_quant_type"`` and
            ``"kv_quant_type"`` from it.
    """

    def __init__(self, quant_config: Dict[str, Any]):
        # Initialize base-class state first: ``get_quant_method`` reads
        # ``self.packed_modules_mapping``, which this class never assigns
        # itself.
        super().__init__()
        self.quant_description = quant_config

    def __repr__(self) -> str:
        return "AscendQuantConfig:\n" + super().__repr__()

    @classmethod
    def get_name(cls) -> str:
        """Return the name this config is registered under."""
        return ASCEND_QUATIZATION_METHOD

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this config."""
        return [torch.int8, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Not supported on Ascend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Ascend hardware dose not support \"get_min_capability\" feature.")

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names searched for the quantization description."""
        return ["quant_model_description.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
        """Build a config instance directly from the parsed description."""
        return cls(config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Select the Ascend method whenever an NPU is available.

        Both arguments are ignored; the decision depends only on
        ``torch.npu.is_available()``.

        Returns:
            The Ascend quantization method name, or ``None`` when no NPU is
            available.
        """
        if torch.npu.is_available():
            return ASCEND_QUATIZATION_METHOD
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method object for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Dotted name of the layer; used to look the layer up in
                the quantization description.

        Returns:
            A linear / KV-cache / fused-MoE / embedding method for the
            matching layer types, the corresponding unquantized method when
            the description marks the layer as ``"FLOAT"``, or ``None`` for
            any other layer.
        """
        # Imported at call time, as in the original implementation.
        from vllm.attention.layer import Attention

        # KV-cache quantization is requested either through a non-None
        # 'fa_quant_type' entry or through 'kv_quant_type' == 'C8'.
        kv_cache_quantized = (
            self.quant_description.get('fa_quant_type') is not None
            or self.quant_description.get('kv_quant_type') == 'C8')

        if isinstance(layer, LinearBase):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return AscendLinearMethod(self, prefix,
                                      self.packed_modules_mapping)
        elif isinstance(layer, Attention) and kv_cache_quantized:
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, FusedMoE):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return AscendUnquantizedFusedMoEMethod(layer.moe)
            return AscendFusedMoEMethod(self, prefix,
                                        self.packed_modules_mapping)
        elif isinstance(layer, VocabParallelEmbedding):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return UnquantizedEmbeddingMethod()
            return AscendEmbeddingMethod(self, prefix,
                                         self.packed_modules_mapping)
        return None

    def is_layer_skipped_ascend(
        self,
        prefix: str,
        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
    ) -> bool:
        """Tell whether the layer at ``prefix`` is left unquantized.

        A layer counts as skipped when the description maps
        ``"<prefix>.weight"`` to ``"FLOAT"``. For a fused layer (its last
        name component is a key of ``fused_mapping``) every shard listed in
        the mapping is checked instead, and all shards must agree.

        Args:
            prefix: Dotted name of the layer.
            fused_mapping: Maps a fused projection name to the names of the
                shards it is built from.

        Returns:
            ``True`` if the layer (or all of its shards) is ``"FLOAT"``.

        Raises:
            ValueError: If only some shards of a fused layer are quantized.
            KeyError: If a looked-up ``"<...>.weight"`` key is missing from
                the description.
        """
        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
        proj_name = prefix.split(".")[-1]
        if proj_name in fused_mapping:
            # Swap only the trailing component. ``str.replace`` would also
            # rewrite any earlier occurrence of ``proj_name`` in the prefix
            # and produce a key that does not exist in the description.
            parent = prefix[:len(prefix) - len(proj_name)]
            shard_prefixes = [
                parent + shard_proj_name
                for shard_proj_name in fused_mapping[proj_name]
            ]

            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = self.quant_description[shard_prefix +
                                                          '.weight'] == "FLOAT"

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "to have the same precision.")
        else:
            is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"

        # Only reachable with None when a fused entry lists no shards.
        assert is_skipped is not None
        return is_skipped

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for Ascend)."""
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AscendLinearMethod(LinearMethodBase):
    """Linear method for Ascend quantization.

    This class calls AscendQuantizer to search a specific quantization
    implementations supported on ascend hardware for linear methods.

    Args:
        quant_config: The Ascend quantization config.
        prefix: Dotted name of the layer, forwarded to the quantizer lookup.
        packed_modules_mapping: Mapping of fused module names, forwarded to
            the quantizer lookup.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]) -> None:
        self.quantizer = AscendQuantizer.get_quantizer(
            quant_config.quant_description, prefix, packed_modules_mapping)
        self.quant_method = self.quantizer.build_linear_method()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register every parameter the quant method needs on ``layer``.

        Four groups of tensors are requested from the underlying quant
        method (weights, per-tensor, per-channel and per-group parameters)
        and each entry is registered on ``layer`` under the name the quant
        method returned.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input size of this partition.
            output_partition_sizes: Output sizes of the partitions; their
                sum is used as the output size of this partition.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: dtype forwarded to the quant method.
            **extra_weight_attrs: Attributes attached to the weight,
                per-channel and per-group parameters. ``weight_loader`` is
                additionally passed to the per-tensor parameters.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        weight_dict = self.quant_method.get_weight(input_size_per_partition,
                                                   output_size_per_partition,
                                                   params_dtype)
        for weight_name, weight_param in weight_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter(weight_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
        for pertensor_name, pertensor_param in pertensor_dict.items():
            param = PerTensorScaleParameter(data=pertensor_param,
                                            weight_loader=weight_loader)
            # disable warning
            param.ignore_warning = True
            layer.register_parameter(pertensor_name, param)

        perchannel_dict = self.quant_method.get_perchannel_param(
            output_size_per_partition, params_dtype)
        for perchannel_name, perchannel_param in perchannel_dict.items():
            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(perchannel_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pergroup_dict = self.quant_method.get_pergroup_param(
            input_size_per_partition, output_size_per_partition, params_dtype)
        for pergroup_name, pergroup_param in pergroup_dict.items():
            param = torch.nn.Parameter(pergroup_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(pergroup_name, param)
            set_weight_attrs(param, extra_weight_attrs)
            if ("weight_scale_second" in pergroup_name
                    or "weight_offset_second" in pergroup_name):
                # Second-level scale/offset parameters also carry an input
                # dimension. Assigned once, directly on the parameter.
                param.input_dim = 1

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Run the quant method's post-load hook if it defines one."""
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the quantized linear computation to ``x``.

        For ``RowParallelLinear`` layers the tensor-parallel rank is passed
        to the quant method as an extra argument; other layers are called
        without it.

        Args:
            layer: The linear layer holding the registered parameters.
            x: Input tensor.
            bias: Optional bias forwarded to the quant method.

        Returns:
            Whatever the underlying quant method's ``apply`` returns.
        """
        if isinstance(layer, RowParallelLinear):
            tp_rank = get_tensor_model_parallel_rank()
            return self.quant_method.apply(layer, x, bias, tp_rank)
        return self.quant_method.apply(layer, x, bias)
|
|
|
|
|
|
|
|
|
|
|
2025-03-06 15:17:25 +08:00
|
|
|
class AscendKVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantization method for Ascend.

    The constructor asks AscendQuantizer for the quantizer matching the
    layer and keeps the attention method it builds; every other method
    delegates to that object.

    Args:
        quant_config: The Ascend quantization config.
        prefix: Dotted name of the attention layer, forwarded to the
            quantizer lookup.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
        description = quant_config.quant_description
        self.quantizer = AscendQuantizer.get_quantizer(description, prefix)
        self.quant_method = self.quantizer.build_attention_method()

    def create_weights(self, layer: torch.nn.Module) -> None:
        """Let the quant method create all weights on ``layer``."""
        # Unlike the linear path, vllm performs no weight processing or
        # slicing for attention, so weight creation is left entirely to the
        # concrete quant method.
        self.quant_method.create_weights(layer)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Run the quant method's post-load hook if it defines one."""
        if not hasattr(self.quant_method, "process_weights_after_loading"):
            return
        self.quant_method.process_weights_after_loading(layer)

    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        """Forward all arguments, in order, to the quant method's ``apply``."""
        forwarded = (layer, query, key, value, kv_cache, attn_metadata,
                     attn_type, scale, output)
        return self.quant_method.apply(*forwarded)
|
2025-04-07 10:56:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class AscendFusedMoEMethod(FusedMoEMethodBase):
    """Fused-MoE quantization method for Ascend.

    The constructor asks AscendQuantizer for the quantizer matching the
    layer and keeps the MoE method it builds; weight creation, the forward
    computation and the post-load hook are delegated to that object.

    Args:
        quant_config: The Ascend quantization config.
        prefix: Dotted name of the layer, forwarded to the quantizer lookup.
        packed_modules_mapping: Mapping of fused module names, forwarded to
            the quantizer lookup.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]):
        description = quant_config.quant_description
        self.quantizer = AscendQuantizer.get_quantizer(
            description, prefix, packed_modules_mapping)
        self.quant_method = self.quantizer.build_moe_method()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register the expert weights and quantization parameters.

        Weights are registered with ``extra_weight_attrs`` as given. The
        dynamic quantization parameters are registered afterwards with
        ``quant_method`` set to the CHANNEL value, except for second-level
        scale/offset entries, which are switched to the GROUP value.
        """
        # Both quant-method queries take the same four arguments.
        size_args = (num_experts, intermediate_size_per_partition,
                     hidden_size, params_dtype)

        expert_weights = self.quant_method.get_weight(*size_args)
        for name, tensor in expert_weights.items():
            weight = torch.nn.Parameter(tensor, requires_grad=False)
            layer.register_parameter(name, weight)
            set_weight_attrs(weight, extra_weight_attrs)

        # From here on every parameter is tagged as channel-wise.
        extra_weight_attrs[
            "quant_method"] = FusedMoeWeightScaleSupported.CHANNEL.value
        second_level_markers = ("weight_scale_second", "weight_offset_second")

        quant_params = self.quant_method.get_dynamic_quant_param(*size_args)
        for name, tensor in quant_params.items():
            quant_param = torch.nn.Parameter(tensor, requires_grad=False)
            layer.register_parameter(name, quant_param)
            set_weight_attrs(quant_param, extra_weight_attrs)
            if any(marker in name for marker in second_level_markers):
                quant_param.quant_method = \
                    FusedMoeWeightScaleSupported.GROUP.value

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num=0,
        **kwargs,
    ) -> torch.Tensor:
        """Forward every argument to the quant method's ``apply``.

        The named arguments are passed positionally in declaration order;
        ``**kwargs`` is passed through unchanged.
        """
        routing_args = (top_k, renormalize, use_grouped_topk,
                        global_num_experts, expert_map, topk_group,
                        num_expert_group, custom_routing_function,
                        scoring_func, e_score_correction_bias)
        runtime_args = (is_prefill, enable_force_load_balance, log2phy,
                        global_redundant_expert_num)
        return self.quant_method.apply(layer, x, router_logits, *routing_args,
                                       *runtime_args, **kwargs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Run the quant method's post-load hook if it defines one."""
        if not hasattr(self.quant_method, "process_weights_after_loading"):
            return
        self.quant_method.process_weights_after_loading(layer)
|
2025-07-29 18:51:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class AscendEmbeddingMethod(AscendLinearMethod):
    """Embedding quantization method for Ascend.

    Looks up the quantizer for the layer through AscendQuantizer and keeps
    the linear method it builds; everything else is inherited from
    AscendLinearMethod.

    Args:
        quant_config: The Ascend quantization config.
        prefix: Dotted name of the layer, forwarded to the quantizer lookup.
        packed_modules_mapping: Mapping of fused module names, forwarded to
            the quantizer lookup.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]) -> None:
        description = quant_config.quant_description
        self.quantizer = AscendQuantizer.get_quantizer(
            description, prefix, packed_modules_mapping)
        # Embeddings reuse the quantizer's linear method.
        self.quant_method = self.quantizer.build_linear_method()
|