diff --git a/.github/workflows/misc/model_list.json b/.github/workflows/misc/model_list.json
index e7dfa038..a045d308 100644
--- a/.github/workflows/misc/model_list.json
+++ b/.github/workflows/misc/model_list.json
@@ -204,6 +204,7 @@
     "vllm-ascend/Qwen3-30B-A3B-Puring",
     "vllm-ascend/Qwen3-30B-A3B-W8A8",
     "vllm-ascend/Qwen3-30B-A3B-W8A8-Pruning",
+    "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w8a8",
     "vllm-ascend/Qwen3-32B-W4A4",
     "vllm-ascend/Qwen3-32B-W8A8",
     "vllm-ascend/Qwen3-8B",
diff --git a/docs/source/user_guide/feature_guide/quantization.md b/docs/source/user_guide/feature_guide/quantization.md
index 3bdfacc7..90a2cf90 100644
--- a/docs/source/user_guide/feature_guide/quantization.md
+++ b/docs/source/user_guide/feature_guide/quantization.md
@@ -72,7 +72,11 @@ pip install llmcompressor

 #### Model Quantization

-`LLM-Compressor` provides various quantization scheme examples. To generate W8A8 dynamic quantized weights:
+`LLM-Compressor` provides various quantization scheme examples.
+
+##### Dense Quantization
+
+An example of generating W8A8 dynamic quantized weights for a dense model:

 ```bash
 # Navigate to LLM-Compressor examples directory
@@ -82,6 +86,18 @@ cd examples/quantization/llm-compressor
 python3 w8a8_int8_dynamic.py
 ```

+##### MoE Quantization
+
+An example of generating W8A8 dynamic quantized weights for an MoE model:
+
+```bash
+# Navigate to LLM-Compressor examples directory
+cd examples/quantization/llm-compressor
+
+# Run quantization script
+python3 w8a8_int8_dynamic_moe.py
+```
+
 For more content, refer to the [official examples](https://github.com/vllm-project/llm-compressor/tree/main/examples).

 Currently supported quantization types by LLM-Compressor: `W8A8` and `W8A8_DYNAMIC`.
diff --git a/examples/quantization/llm-compressor/w8a8_int8_dynamic_moe.py b/examples/quantization/llm-compressor/w8a8_int8_dynamic_moe.py
new file mode 100644
index 00000000..d5f5d887
--- /dev/null
+++ b/examples/quantization/llm-compressor/w8a8_int8_dynamic_moe.py
@@ -0,0 +1,29 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+MODEL_ID = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, dtype=torch.bfloat16, trust_remote_code=True
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+recipe = QuantizationModifier(
+    targets="Linear",
+    scheme="INT8",
+    ignore=["lm_head", "re:.*mlp.gate$"],
+)
+
+oneshot(
+    model=model,
+    recipe=recipe,
+    trust_remote_code_model=True,
+)
+
+# Save to disk in compressed-tensors format.
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-INT8_W8A8"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/tests/e2e/multicard/2-cards/test_quantization.py b/tests/e2e/multicard/2-cards/test_quantization.py
index c9b87a61..1a3f11ad 100644
--- a/tests/e2e/multicard/2-cards/test_quantization.py
+++ b/tests/e2e/multicard/2-cards/test_quantization.py
@@ -42,3 +42,26 @@ def test_qwen2_5_w8a8_external_quantized_tp2():
     for i in range(len(vllm_output)):
         assert golden_results[i] == vllm_output[i][1]
         print(f"Generated text: {vllm_output[i][1]!r}")
+
+
+def test_qwen3_moe_w8a8_dynamic_llm_compressor():
+    example_prompts = [
+        "The president of the United States is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+            snapshot_download(
+                "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w8a8"),
+            tensor_parallel_size=2,
+            max_model_len=4096,
+            gpu_memory_utilization=0.8,
+    ) as vllm_model:
+        vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    golden_results = [
+        'The president of the United States is the head of state and',
+    ]
+
+    for i in range(len(vllm_output)):
+        assert golden_results[i] == vllm_output[i][1]
+        print(f"Generated text: {vllm_output[i][1]!r}")
diff --git a/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py b/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py
index 774bb006..6266eea4 100644
--- a/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py
@@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Any, Optional, cast

 import torch
 from compressed_tensors.quantization import (QuantizationArgs,
-                                             QuantizationStrategy)
+                                             QuantizationStrategy,
+                                             QuantizationType)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase,
@@ -23,7 +24,8 @@ from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod,
                                                    AscendQuantConfig)
 from vllm_ascend.quantization.w4a16 import AscendW4A16FusedMoEMethod
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
-from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+from vllm_ascend.quantization.w8a8_dynamic import (
+    AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod)
 from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD

 if TYPE_CHECKING:
@@ -75,6 +77,18 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return []

+    def _add_fused_moe_to_target_scheme_map(self):
+        """
+        Helper function to update target_scheme_map:
+        since linear layers get fused into FusedMoE,
+        targeting 'Linear' needs to also match
+        FusedMoE modules.
+        """
+        if ("Linear" not in self.target_scheme_map
+                or "FusedMoE" in self.target_scheme_map):
+            return
+        self.target_scheme_map["FusedMoE"] = self.target_scheme_map["Linear"]
+
     @classmethod
     def from_config(cls,
                     config: dict[str, Any]) -> "AscendCompressedTensorsConfig":
@@ -155,20 +169,48 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
                 None, layer)
             return quant_method
         if isinstance(layer, FusedMoE):
-            layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
-            # collect schemes
-            quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            self._add_fused_moe_to_target_scheme_map()
+            unfused_names = [
+                prefix + proj_name for proj_name in
+                [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+            ]
+            # TODO: refactor this to use expert_mapping and check all layer numbers
+            all_scheme_dicts = [
+                self.get_scheme_dict(layer, name) for name in unfused_names
+            ]
+            scheme_dict = all_scheme_dicts.pop()

-            # choose quantization method
-            quant_method = AscendUnquantizedFusedMoEMethod(layer.moe_config)
-            if quant_scheme is not None:
-                layer.scheme = quant_scheme
-                ascend_quant_config = AscendQuantConfig(self.quant_description
-                                                        or {})
-                quant_method = AscendFusedMoEMethod(
-                    ascend_quant_config, prefix,
-                    ascend_quant_config.packed_modules_mapping, layer)
-            return quant_method
+            # multiple schemes found
+            if not all(
+                [cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+                raise ValueError("All MoE projections need to have same "
+                                 "quantization scheme but found multiple")
+
+            if scheme_dict is None:
+                return AscendUnquantizedFusedMoEMethod(layer.moe_config)
+
+            weight_quant = scheme_dict.get("weights")
+            input_quant = scheme_dict.get("input_activations")
+
+            quant_scheme = None
+            act_quant_format = is_activation_quantization_format(self.quant_format)
+            if act_quant_format:
+                if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+                    quant_scheme = AscendW8A8DynamicFusedMoEMethod()
+            else:
+                if self._is_w4a16(weight_quant, input_quant):
+                    quant_scheme = AscendW4A16FusedMoEMethod()
+            if quant_scheme is None:
+                raise RuntimeError(
+                    f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
+                )
+            layer.scheme = quant_scheme
+            layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
+
+            ascend_quant_config = AscendQuantConfig(self.quant_description
+                                                    or {})
+            return AscendFusedMoEMethod(ascend_quant_config, prefix,
+                                        self.packed_modules_mapping, layer)
         return None

     def get_scheme(self,
@@ -188,26 +230,14 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
        to select the CompressedTensorsScheme used for inference.
        """

-        # Find the "target" in the compressed-tensors config
-        # that our layer conforms to.
-        if should_ignore_layer(layer_name,
-                               ignore=self.ignore,
-                               fused_mapping=self.packed_modules_mapping):
-            return None
-
-        # Will be empty for models with only sparsity
-        weight_quant = input_quant = None
-        if self.target_scheme_map:
-            matched_target = find_matched_target(
-                layer_name=layer_name,
-                module=layer,
-                targets=self.target_scheme_map.keys(),
-                fused_mapping=self.packed_modules_mapping,
-            )
-
-            scheme_dict = self.target_scheme_map[matched_target]
+        scheme_dict = self.get_scheme_dict(layer, layer_name)
+        weight_quant = None
+        input_quant = None
+        format = None
+        if scheme_dict:
             weight_quant = scheme_dict.get("weights")
             input_quant = scheme_dict.get("input_activations")
+            format = scheme_dict.get("format")

         if weight_quant is None:
             logger.warning_once("Acceleration for non-quantized schemes is "
@@ -220,13 +250,54 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
         scheme = self._get_scheme_from_parts(
             weight_quant=weight_quant,
             input_quant=input_quant,
+            format=format,
         )
         return scheme

+    def get_scheme_dict(
+        self,
+        layer: torch.nn.Module,
+        layer_name: str | None = None
+    ) -> dict[str, QuantizationArgs | str | None] | None:
+        """
+        Extract the QuantizationArgs for a given layer.
+
+        Returns:
+            dict with {
+                "weights": QuantizationArgs,
+                "input_activations": QuantizationArgs | None,
+                "format": str | None
+            } | None
+        """
+        if should_ignore_layer(layer_name,
+                               ignore=self.ignore,
+                               fused_mapping=self.packed_modules_mapping):
+            return None
+
+        if self.target_scheme_map:
+            matched_target = find_matched_target(
+                layer_name=layer_name,
+                module=layer,
+                targets=self.target_scheme_map.keys(),
+                fused_mapping=self.packed_modules_mapping,
+            )
+            scheme_dict = self.target_scheme_map[matched_target]
+            if scheme_dict.get("format") is None:
+                scheme_dict["format"] = self.quant_format
+            return scheme_dict
+
+        return None
+
     def _get_scheme_from_parts(
-            self, weight_quant: QuantizationArgs,
-            input_quant: QuantizationArgs) -> "CompressedTensorsScheme":
-        act_quant_format = is_activation_quantization_format(self.quant_format)
+            self,
+            weight_quant: QuantizationArgs,
+            input_quant: QuantizationArgs,
+            format: str | None = None,
+    ) -> "CompressedTensorsScheme":
+        # use the per-layer format if defined, otherwise, use global format
+        format = format if format is not None else self.quant_format
+
+        act_quant_format = is_activation_quantization_format(format)
         if act_quant_format and input_quant is not None:
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
                 return AscendW8A8LinearMethod()
@@ -234,10 +305,6 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                 return AscendW8A8DynamicLinearMethod()

-        if weight_quant is not None:
-            if self._is_w4a16(weight_quant):
-                return AscendW4A16FusedMoEMethod()
-
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")

@@ -269,9 +336,22 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and is_symmetric and is_dynamic

-    def _is_w4a16(self, weight_quant: QuantizationArgs) -> bool:
+    def _is_w4a16(self, weight_quant: QuantizationArgs,
+                  input_quant: QuantizationArgs) -> bool:
+        # Confirm weights quantized.
+        if weight_quant is None:
+            return False
+
+        # Confirm the weight type is integer.
+        if weight_quant.type != QuantizationType.INT:
+            return False
+
+        input_quant_none = input_quant is None
         is_4_bits = weight_quant.num_bits == 4
-        return is_4_bits
+        is_group = (weight_quant.strategy == QuantizationStrategy.GROUP.value)
+        is_static = not weight_quant.dynamic
+
+        return input_quant_none and is_4_bits and is_group and is_static

     def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
         self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
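For reference, a minimal sketch of smoke-testing the checkpoint produced by `w8a8_int8_dynamic_moe.py` with vLLM offline inference, mirroring the new e2e test above. The local `MODEL_PATH`, tensor-parallel degree, and memory settings are illustrative assumptions, not part of this change:

```python
from vllm import LLM, SamplingParams

# Path written by the example script (MODEL_ID basename + "-INT8_W8A8");
# adjust if the compressed checkpoint was saved elsewhere (assumption).
MODEL_PATH = "./Qwen3-30B-A3B-Instruct-2507-INT8_W8A8"

# The compressed-tensors quantization config is read from the checkpoint
# itself, so no explicit quantization argument is passed here.
llm = LLM(
    model=MODEL_PATH,
    tensor_parallel_size=2,      # same TP degree as the e2e test
    max_model_len=4096,
    gpu_memory_utilization=0.8,
)

outputs = llm.generate(
    ["The president of the United States is"],
    SamplingParams(temperature=0.0, max_tokens=5),  # greedy, 5 new tokens
)
print(outputs[0].outputs[0].text)
```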