[Quantization] Support compressed tensors w8a8 static and w8a8 dynamic weight (#4036)

### What this PR does / why we need it?

The LLM Compressor quantization tool from the vLLM community generates
weights in the compressed-tensors format, so the vLLM Ascend engine needs
to be adapted to support this quantization format.

1. Add AscendCompressedTensorsConfig to replace CompressedTensorsConfig
in vllm.
2. Support CompressedTensorsW8A8 static weight.
- weight: per-channel, int8, symmetric; activation: per-tensor, int8,
symmetric.
3. Support CompressedTensorsW8A8Dynamic weight.
- weight: per-channel, int8, symmetric; activation: per-token, int8,
symmetric, dynamic.
4. Modify the override_quantization_method in AscendQuantConfig.

Co-authored-by: taoqun110 <taoqun@huawei.com>
Co-authored-by: chenxi-hh <chen464822955@163.com>

- vLLM version: v0.11.2

---------

Signed-off-by: LHXuuu <scut_xlh@163.com>
Signed-off-by: chenxi-hh <chen464822955@163.com>
Signed-off-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
Co-authored-by: chenxi-hh <chen464822955@163.com>
Co-authored-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
LHXuuu
2025-11-28 14:09:39 +08:00
committed by GitHub
parent ab37a7d5ae
commit bdc66972db
18 changed files with 707 additions and 32 deletions

View File

@@ -179,6 +179,7 @@ jobs:
VLLM_USE_MODELSCOPE: True
if: ${{ inputs.type == 'full' }}
run: |
pytest -sv tests/e2e/multicard/test_quantization.py
pytest -sv tests/e2e/multicard/test_aclgraph_capture_replay.py
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
pytest -sv tests/e2e/multicard/test_full_graph_mode.py

View File

@@ -7,6 +7,7 @@ This section provides a detailed usage guide of vLLM Ascend features.
:maxdepth: 1
graph_mode
quantization
quantization-llm-compressor
sleep_mode
structured_output
lora

View File

@@ -0,0 +1,65 @@
# llm-compressor Quantization Guide
Model quantization is a technique that reduces the size and computational requirements of a model by lowering the precision of its weights and activation values, thereby saving memory and improving inference speed.
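For intuition, the sketch below (plain PyTorch, illustration only, not part of llm-compressor) shows the arithmetic behind a symmetric per-channel int8 weight scale, which is the weight scheme used by the W8A8 types listed in the next section:
```python
import torch

# Illustrative only: symmetric per-channel int8 quantization of a weight matrix.
w = torch.randn(4, 8)                               # [out_channels, in_channels]
scale = w.abs().amax(dim=1, keepdim=True) / 127.0   # one scale per output channel
w_q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
w_dq = w_q.to(torch.float32) * scale                # dequantize to check the error
print("max abs error:", (w - w_dq).abs().max().item())
```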
## Supported llm-compressor Quantization Types
- CompressedTensorsW8A8 (static): weight: per-channel, int8, symmetric; activation: per-tensor, int8, symmetric.
- CompressedTensorsW8A8Dynamic: weight: per-channel, int8, symmetric; activation: per-token, int8, symmetric, dynamic.
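A checkpoint produced with one of these schemes carries a `quantization_config` section in its config.json. The sketch below (written as a Python dict; the exact field values depend on the llm-compressor/compressed-tensors version) illustrates the static W8A8 case that vLLM Ascend parses:
```python
# Illustrative shape of the compressed-tensors quantization_config (static W8A8);
# field values are indicative and may differ between llm-compressor releases.
quantization_config = {
    "quant_method": "compressed-tensors",
    "format": "int-quantized",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 8, "type": "int", "strategy": "channel",
                        "symmetric": True, "dynamic": False},
            "input_activations": {"num_bits": 8, "type": "int", "strategy": "tensor",
                                  "symmetric": True, "dynamic": False},
        }
    },
}
```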
## Install llm-compressor
To quantize a model, install [llm-compressor](https://github.com/vllm-project/llm-compressor/blob/main/README.md), a unified library for creating compressed models for faster inference with vLLM:
```bash
pip install llmcompressor
```
### Generate the W8A8 weights
```bash
cd examples/quantization/llm-compressor
python3 w8a8_int8_dynamic.py
```
For more details, see the [official examples](https://github.com/vllm-project/llm-compressor/tree/main/examples).
## Run the model
Now, you can run the quantized model with vLLM Ascend. Examples for online and offline inference are provided as follows:
### Offline inference
```python
import torch
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=40)
llm = LLM(model="{quantized_model_save_path}",
          max_model_len=2048,
          trust_remote_code=True)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
### Online inference
Start the quantized model using vLLM Ascend; no modifications to the startup command are required.
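For example, once the server is up you can query it with the OpenAI-compatible client (a minimal sketch; it assumes the model was started with `vllm serve {quantized_model_save_path}` on the default port 8000):
```python
# Assumes a running server, e.g.: vllm serve {quantized_model_save_path} --max-model-len 2048
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="{quantized_model_save_path}",  # must match the served model name/path
    prompt="Hello, my name is",
    max_tokens=32,
    temperature=0.6,
)
print(completion.choices[0].text)
```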

View File

@@ -0,0 +1,160 @@
import os

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, \
    AutoTokenizer, AutoProcessor, AutoConfig, AutoImageProcessor
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme, QuantizationType, QuantizationStrategy

W8A8_W_cha_A_ten_static_symmetric = {
    "group_0": QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=8,
            type=QuantizationType.INT,
            strategy=QuantizationStrategy.CHANNEL,
            symmetric=True,
            dynamic=False
        ),
        input_activations=QuantizationArgs(
            num_bits=8,
            type=QuantizationType.INT,
            strategy=QuantizationStrategy.TENSOR,
            symmetric=True,
            dynamic=False
        ),
    ),
}

# supported modifiers
MODIFIER_DICT = {
    "PTQ": QuantizationModifier,
    "AWQ": AWQModifier,
    "GPTQ": GPTQModifier,
}

# supported schemes
SCHEMES_DICT = {
    "W8A8_W_cha_A_ten_static_symmetric": W8A8_W_cha_A_ten_static_symmetric,
}

MODEL_DICT = {
    "qwen3": AutoModelForCausalLM,
}

TOKENIZER_DICT = {
    "qwen3": AutoTokenizer,
}


def load_environment_variables():
    env_vars = {
        'model_path': "Qwen/Qwen3-32B",
        'export_path': "/llm-compressor/export/GPTQ/W8A8_W_cha_A_ten_static_symmetric",
        'modifier': "GPTQ",
        'schemes': "W8A8_W_cha_A_ten_static_symmetric",
        'calib_prompt_path': "HuggingFaceH4/ultrachat_200k"
    }
    # verify export model path
    if env_vars['export_path'] is None:
        env_vars['export_path'] = env_vars['model_path'].rstrip("/") + "-" + env_vars['modifier']
        if env_vars['schemes'] is not None:
            env_vars['export_path'] += "-" + env_vars['schemes']
    os.makedirs(env_vars['export_path'], exist_ok=True)
    return env_vars


def load_calibration_text_dataset(calib_prompt_path, tokenizer):
    # Load dataset
    for f in os.listdir(calib_prompt_path):
        print(f)
    if any(f.lower().endswith('.jsonl') for f in os.listdir(calib_prompt_path)):
        ds = load_dataset('json', data_dir=calib_prompt_path, split='validation')
    elif any(f.lower().endswith('.parquet') for f in os.listdir(calib_prompt_path)):
        ds = load_dataset("parquet", data_dir=calib_prompt_path, split="train[:512]")
    else:
        raise ValueError("Unsupported calibration file format: {}".format(
            calib_prompt_path.split('.')[-1]))

    # Preprocess dataset
    def preprocess(example):
        if tokenizer.chat_template is not None:
            return {"text": tokenizer.apply_chat_template(
                example["messages"], tokenize=False)}
        else:
            return {"text": example["messages"]}

    # Tokenize inputs
    def tokenize(sample):
        return tokenizer(
            sample["text"],
            add_special_tokens=False,
        )

    ds = ds.map(preprocess)
    ds = ds.map(tokenize, remove_columns=ds.column_names)
    return ds


# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    return {
        key: torch.tensor(value, dtype=torch.bfloat16 if key == "pixel_values" else torch.long)
        for key, value in batch[0].items()
    }


def quantize_model(model, env_vars, dataset_dict=None):
    # since the MoE gate layers are sensitive to quantization, we add them to the ignore
    # list so they remain at full precision
    ignore = ["lm_head", "re:.*mlp.down_proj"]

    # define a llmcompressor recipe
    recipe = [
        MODIFIER_DICT[env_vars['modifier']](
            config_groups=SCHEMES_DICT[env_vars['schemes']],
            ignore=ignore,
        ),
    ]

    # quantize the model
    oneshot(
        model=model,
        dataset=dataset_dict,
        recipe=recipe,
        trust_remote_code_model=True,
    )


def save_quantized_model(model, tokenizer, save_path, save_compressed=False):
    model.save_pretrained(save_path, save_compressed=save_compressed)
    tokenizer.save_pretrained(save_path)


if __name__ == '__main__':
    # get environment variables
    env_vars = load_environment_variables()
    # support model type list
    config = AutoConfig.from_pretrained(env_vars['model_path'], trust_remote_code=True)
    model_type = config.model_type
    model = MODEL_DICT[model_type].from_pretrained(
        env_vars['model_path'], torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = TOKENIZER_DICT[model_type].from_pretrained(env_vars['model_path'], trust_remote_code=True)
    ds = load_calibration_text_dataset(env_vars["calib_prompt_path"], tokenizer)
    # Quantize the model
    quantize_model(model, env_vars, ds)
    # save the quantized model
    save_quantized_model(model, tokenizer, env_vars['export_path'], True)

View File

@@ -0,0 +1,83 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure algorithms. In this case, we:
# * apply SmoothQuant to make the activations easier to quantize
# * quantize the weights to int8 with GPTQ (static per channel)
# * quantize the activations to int8 (dynamic per token)
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Apply algorithms and save to output_dir
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("npu")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

View File

@@ -15,6 +15,15 @@ ignore_missing_imports = True
[mypy-lm_eval.*]
ignore_missing_imports = True
[mypy-compressed_tensors.*]
ignore_missing_imports = True
[mypy-datasets.*]
ignore_missing_imports = True
[mypy-llmcompressor.*]
ignore_missing_imports = True
[mypy-msprobe.*]
ignore_missing_imports = True
allow_untyped_imports = True

View File

@@ -23,6 +23,7 @@ requires = [
"quart",
"numba",
"opencv-python-headless<=4.11.0.86", # Required to avoid numpy version conflict with vllm
"compressed_tensors>=0.11.0"
]
build-backend = "setuptools.build_meta"

View File

@@ -16,6 +16,7 @@ torchvision
wheel
pandas-stubs
opencv-python-headless<=4.11.0.86 # Required to avoid numpy version conflict with vllm
compressed_tensors>=0.11.0
# requirements for disaggregated prefill
msgpack

View File

@@ -0,0 +1,46 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/e2e/multicard/test_quantization.py`.
"""
from modelscope import snapshot_download # type: ignore
from tests.e2e.conftest import VllmRunner

def test_models_distributed_quantized_W8A8():
    example_prompts = [
        "The president of the United States is",
    ]
    max_tokens = 5

    with VllmRunner(snapshot_download("neuralmagic/Qwen2.5-3B-quantized.w8a8"),
                    tensor_parallel_size=2,
                    max_model_len=4096,
                    gpu_memory_utilization=0.8,
                    enforce_eager=False) as vllm_model:
        vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens)

    golden_results = [
        'The president of the United States is the head of state and',
    ]
    for i in range(len(vllm_output)):
        assert golden_results[i] == vllm_output[i][1]
        print(f"Generated text: {vllm_output[i][1]!r}")

View File

@@ -65,7 +65,7 @@ class TestAscendQuantConfig(TestBase):
# Test when NPU is available
mock_is_available.return_value = True
result = AscendQuantConfig.override_quantization_method(None, None)
self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)
self.assertIsNone(result)
# Test when NPU is not available
mock_is_available.return_value = False
@@ -93,7 +93,7 @@ class TestAscendQuantConfig(TestBase):
self.assertIs(method, mock_ascend_linear.return_value)
mock_ascend_linear.assert_called_once_with(
self.ascend_config, ".attn",
self.ascend_config.packed_modules_mapping)
self.ascend_config.packed_modules_mapping, linear_layer)
def test_get_quant_method_for_attention(self):
attention_layer = MagicMock(spec=Attention)

View File

@@ -9,7 +9,8 @@ from vllm.platforms import PlatformEnum
from tests.ut.base import TestBase
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, AscendDeviceType
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD,
COMPRESSED_TENSORS_METHOD, AscendDeviceType)
class TestNPUPlatform(TestBase):
@@ -47,8 +48,9 @@ class TestNPUPlatform(TestBase):
self.assertEqual(NPUPlatform.device_control_env_var,
"ASCEND_RT_VISIBLE_DEVICES")
self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1")
self.assertEqual(NPUPlatform.supported_quantization,
[ASCEND_QUANTIZATION_METHOD])
self.assertEqual(
NPUPlatform.supported_quantization,
[ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD])
def test_is_sleep_mode_available(self):
self.assertTrue(self.platform.is_sleep_mode_available())

View File

@@ -30,12 +30,13 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
init_ascend_config)
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
delete_torchair_cache_file)
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, AscendDeviceType,
enable_sp, get_ascend_device_type, is_vl_model,
prefill_context_parallel_enable,
update_aclgraph_sizes,
update_cudagraph_capture_sizes,
update_default_aclgraph_sizes)
# isort: off
from vllm_ascend.utils import (
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType,
enable_sp, get_ascend_device_type, is_vl_model,
prefill_context_parallel_enable, update_aclgraph_sizes,
update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
@@ -56,7 +57,9 @@ class NPUPlatform(Platform):
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
dispatch_key: str = "PrivateUse1"
supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD]
supported_quantization: list[str] = [
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
]
def is_sleep_mode_available(self) -> bool:
return True
@@ -79,6 +82,8 @@ class NPUPlatform(Platform):
if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)
from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \
AscendCompressedTensorsConfig # noqa: F401
from vllm_ascend.quantization.quant_config import \
AscendQuantConfig # noqa: F401

View File

@@ -0,0 +1,252 @@
from typing import TYPE_CHECKING, Any, Optional, cast

import torch
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \
    CompressedTensorsScheme
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    find_matched_target, is_activation_quantization_format,
    should_ignore_layer)

from vllm_ascend.quantization.quant_config import (AscendLinearMethod,
                                                   AscendQuantConfig)
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)

QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]


def remove_quantization_method():
    if COMPRESSED_TENSORS_METHOD in QUANTIZATION_METHODS:
        QUANTIZATION_METHODS.remove(COMPRESSED_TENSORS_METHOD)


remove_quantization_method()


@register_quantization_config(COMPRESSED_TENSORS_METHOD)
class AscendCompressedTensorsConfig(QuantizationConfig):

    def __init__(
        self,
        target_scheme_map: dict[str, Any],
        ignore: list[str],
        quant_format: str,
        config: Optional[dict[str, Any]] = None,
    ):
        super().__init__()
        self.ignore = ignore
        self.quant_format = quant_format
        # Map from [target -> scheme]
        self.target_scheme_map = target_scheme_map
        self.quant_description = config

    def get_name(self) -> str:
        return "compressed-tensors"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.int8, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "Ascend hardware does not support \"get_min_capability\" feature.")

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str,
                                      Any]) -> "AscendCompressedTensorsConfig":
        ignore: list[str] = cast(list[str], config.get("ignore", []))
        quant_format = cast(str, config.get("format"))
        target_scheme_map = cls._quantization_scheme_map_from_config(
            config=config)

        return cls(
            target_scheme_map=target_scheme_map,
            ignore=ignore,
            quant_format=quant_format,
            config=config,
        )

    @classmethod
    def _quantization_scheme_map_from_config(
            cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
        """
        :param config: The `quantization_config` dictionary from config.json
        :return: A dictionary mapping target layer names to their corresponding
            quantization_args for weights and input activations
        """
        target_scheme_map: dict[str, Any] = dict()
        quant_format = cast(str, config.get("format"))

        # The quant_config has multiple config_groups, each containing
        # an input_activations key with details about how the activations are
        # quantized, a weights key indicating how the weights are quantized,
        # and a list of targets under the `targets` key, dictating which
        # layers are impacted by the quantization details. The quantization
        # details follow the structure defined by the QuantizationArgs
        # pydantic model, which is used to verify the structure of the
        # quant_config and also store the details for later use.
        config_groups = config.get("config_groups", dict())
        for _, quant_config in config_groups.items():
            targets = quant_config.get("targets")
            for target in targets:
                target_scheme_map[target] = {}
                target_scheme_map[target][
                    "weights"] = QuantizationArgs.model_validate(
                        quant_config.get("weights"))

                target_scheme_map[target]["input_activations"] = None
                target_scheme_map[target]["format"] = quant_config.get(
                    "format")
                format = target_scheme_map[target].get("format")
                # If no per-config format defined, use global format in config
                act_quant_format = (
                    is_activation_quantization_format(format)
                    if format is not None else
                    is_activation_quantization_format(quant_format))
                input_activations = quant_config.get("input_activations")
                if act_quant_format and input_activations is not None:
                    target_scheme_map[target]["input_activations"] = (
                        QuantizationArgs.model_validate(
                            quant_config.get("input_activations")))
        return target_scheme_map

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
            # collect schemes
            quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
            # choose quantization method
            quant_method: LinearMethodBase = UnquantizedLinearMethod()
            if quant_scheme is not None:
                layer.scheme = quant_scheme
                ascend_quant_config = AscendQuantConfig(self.quant_description
                                                        or {})
                quant_method = AscendLinearMethod(ascend_quant_config, prefix,
                                                  None, layer)
            return quant_method
        return None

    def get_scheme(self,
                   layer: torch.nn.Module,
                   layer_name: Optional[str] = None
                   ) -> Optional["CompressedTensorsScheme"]:
        """
        compressed-tensors supports non uniform in the following way:

        targets of config_groups: There can be N config_groups which each
        have a quantization scheme. Each config_group has a list of targets
        which can be a full layer_name, a regex for a layer_name, or
        an nn.Module name.

        Detect whether a layer_name is found in any target and
        use the quantization scheme corresponding to the matched target
        to select the CompressedTensorsScheme used for inference.
        """
        # Find the "target" in the compressed-tensors config
        # that our layer conforms to.
        if should_ignore_layer(layer_name,
                               ignore=self.ignore,
                               fused_mapping=self.packed_modules_mapping):
            return None

        # Will be empty for models with only sparsity
        weight_quant = input_quant = None
        if self.target_scheme_map:
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=self.target_scheme_map.keys(),
                fused_mapping=self.packed_modules_mapping,
            )

            scheme_dict = self.target_scheme_map[matched_target]
            weight_quant = scheme_dict.get("weights")
            input_quant = scheme_dict.get("input_activations")

        if weight_quant is None:
            logger.warning_once("Acceleration for non-quantized schemes is "
                                "not supported by Compressed Tensors. "
                                "Falling back to UnquantizedLinearMethod")
            return None
        else:
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts(
                weight_quant=weight_quant,
                input_quant=input_quant,
            )

        return scheme

    def _get_scheme_from_parts(
            self, weight_quant: QuantizationArgs,
            input_quant: QuantizationArgs) -> "CompressedTensorsScheme":
        act_quant_format = is_activation_quantization_format(self.quant_format)
        if act_quant_format and input_quant is not None:
            if self._is_static_tensor_w8a8(weight_quant, input_quant):
                return AscendW8A8LinearMethod()
            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                return AscendW8A8DynamicLinearMethod()

        raise NotImplementedError(
            "No compressed-tensors compatible scheme was found.")

    def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
                               input_quant: QuantizationArgs) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
        is_tensor = (weight_strategy and input_quant.strategy
                     == QuantizationStrategy.TENSOR.value)
        is_static = not weight_quant.dynamic and not input_quant.dynamic
        is_symmetric = weight_quant.symmetric and input_quant.symmetric

        # Only symmetric input quantization supported.
        # Only symmetric weight quantization supported.
        return is_8_bits and is_tensor and is_symmetric and is_static

    def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
                               input_quant: QuantizationArgs) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
        is_token = (weight_strategy and input_quant.strategy
                    == QuantizationStrategy.TOKEN.value)
        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
        is_symmetric = weight_quant.symmetric and input_quant.symmetric

        # Only symmetric input quantization supported.
        # Only symmetric weight quantization supported.
        return is_8_bits and is_token and is_symmetric and is_dynamic

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
            self.target_scheme_map)
        self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)

View File

@@ -94,8 +94,10 @@ class AscendQuantConfig(QuantizationConfig):
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
if torch.npu.is_available():
return ASCEND_QUANTIZATION_METHOD
if hf_quant_cfg is not None:
quant_method = hf_quant_cfg.get("quant_method", None)
if quant_method is None and torch.npu.is_available():
return ASCEND_QUANTIZATION_METHOD
return None
def get_quant_method(self, layer: torch.nn.Module,
@@ -113,7 +115,7 @@ class AscendQuantConfig(QuantizationConfig):
self.packed_modules_mapping):
return AscendUnquantizedLinearMethod()
return AscendLinearMethod(self, prefix,
self.packed_modules_mapping)
self.packed_modules_mapping, layer)
elif isinstance(layer, Attention) and \
'fa_quant_type' in self.quant_description.keys() and \
self.quant_description['fa_quant_type'] is not None:
@@ -126,13 +128,13 @@ class AscendQuantConfig(QuantizationConfig):
self.packed_modules_mapping):
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
return AscendFusedMoEMethod(self, prefix,
self.packed_modules_mapping)
self.packed_modules_mapping, layer)
elif isinstance(layer, VocabParallelEmbedding):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedEmbeddingMethod()
return AscendEmbeddingMethod(self, prefix,
self.packed_modules_mapping)
self.packed_modules_mapping, layer)
return None
def is_layer_skipped_ascend(
@@ -259,11 +261,16 @@ class AscendLinearMethod(LinearMethodBase):
quant_config: The Ascend quantization config.
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
packed_modules_mapping: Dict[str, Any]) -> None:
def __init__(self,
quant_config: AscendQuantConfig,
prefix: str,
packed_modules_mapping: Dict[str, Any] | None,
layer: torch.nn.Module = None) -> None:
self.quant_method = get_quant_method(quant_config.quant_description,
prefix, "linear",
packed_modules_mapping)
prefix,
"linear",
packed_modules_mapping,
layer=layer)
def create_weights(
self,
@@ -401,11 +408,16 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
quant_config: The Ascend quantization config.
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
packed_modules_mapping: Dict[str, Any]):
def __init__(self,
quant_config: AscendQuantConfig,
prefix: str,
packed_modules_mapping: Dict[str, Any],
layer: torch.nn.Module = None):
self.quant_method = get_quant_method(quant_config.quant_description,
prefix, "moe",
packed_modules_mapping)
prefix,
"moe",
packed_modules_mapping,
layer=layer)
def create_weights(
self,
@@ -485,7 +497,10 @@ class AscendEmbeddingMethod(AscendLinearMethod):
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
packed_modules_mapping: Dict[str, Any]) -> None:
packed_modules_mapping: Dict[str, Any],
layer: torch.nn.Module) -> None:
self.quant_method = get_quant_method(quant_config.quant_description,
prefix, "linear",
packed_modules_mapping)
prefix,
"linear",
packed_modules_mapping,
layer=layer)

View File

@@ -1,7 +1,10 @@
from typing import Any, Dict, Optional, Type
import torch
from vllm.logger import logger
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
AscendW4A8DynamicLinearMethod)
@@ -60,8 +63,28 @@ def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
def get_quant_method(quant_description: Dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: Optional[Dict[str, Any]] = None):
logger.info_once("Using the vLLM Ascend Quantization now!")
packed_modules_mapping: Optional[Dict[str, Any]] = None,
layer: torch.nn.Module = None):
if quant_description.get("quant_method") == COMPRESSED_TENSORS_METHOD:
return get_quant_method_llmcompressor(layer)
return get_quant_method_modelslim(quant_description, prefix, layer_type,
packed_modules_mapping)
def get_quant_method_llmcompressor(layer: torch.nn.Module):
logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!")
if layer.scheme is None:
raise ValueError("A scheme must be defined for each layer")
return layer.scheme
def get_quant_method_modelslim(
quant_description: Dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: Optional[Dict[str, Any]] = None):
logger.info_once("Using the vLLM Ascend modelslim Quantization now!")
if packed_modules_mapping is None:
packed_modules_mapping = dict()
# Attention

View File

@@ -25,7 +25,8 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
COMPRESSED_TENSORS_METHOD, AscendDeviceType,
get_ascend_device_type, is_enable_nz)
@@ -149,6 +150,10 @@ class AscendW8A8LinearMethod:
)
quant_bias = layer.quant_bias if tp_rank == 0 else None
if getattr(layer, "ascend_quant_method",
"") == COMPRESSED_TENSORS_METHOD:
quant_bias = bias
if get_ascend_device_type() == AscendDeviceType._310P:
# On 300I Duo platform, we need transpose again if
# using nz. This transpose can be skipped in torchair.
@@ -187,6 +192,11 @@ class AscendW8A8LinearMethod:
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
if getattr(layer, "ascend_quant_method",
"") == COMPRESSED_TENSORS_METHOD:
deq_scale = layer.input_scale.data * layer.weight_scale.data
layer.deq_scale = torch.nn.Parameter(deq_scale,
requires_grad=False)
class AscendW8A8FusedMoEMethod:

View File

@@ -41,6 +41,7 @@ else:
VllmConfig = None
ASCEND_QUANTIZATION_METHOD = "ascend"
COMPRESSED_TENSORS_METHOD = "compressed-tensors"
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
REGISTERED_ASCEND_OPS = {}