diff --git a/.github/workflows/misc/model_list.json b/.github/workflows/misc/model_list.json
index 001c831d..ee709ac2 100644
--- a/.github/workflows/misc/model_list.json
+++ b/.github/workflows/misc/model_list.json
@@ -206,6 +206,7 @@
     "vllm-ascend/Qwen3-30B-A3B-W8A8",
     "vllm-ascend/Qwen3-30B-A3B-W8A8-Pruning",
     "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w8a8",
+    "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w4a8",
     "vllm-ascend/Qwen3-32B-W4A4",
     "vllm-ascend/Qwen3-32B-W8A8",
     "vllm-ascend/Qwen3-8B",
diff --git a/examples/quantization/llm-compressor/w4a8_dynamic_moe.py b/examples/quantization/llm-compressor/w4a8_dynamic_moe.py
new file mode 100644
index 00000000..04ff8657
--- /dev/null
+++ b/examples/quantization/llm-compressor/w4a8_dynamic_moe.py
@@ -0,0 +1,63 @@
+from llmcompressor import oneshot
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+MODEL_ID = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+
+# Load model.
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
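+# group_0 quantizes the attention projections to W8A8 (per-channel weights,
+# dynamic per-token activations); group_1 quantizes the MoE expert
+# projections to W4A8 with group size 128. The router ("mlp.gate", anchored
+# with "$" so the experts' gate_proj stays quantized) and lm_head keep the
+# original precision.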
+recipe = """
+quant_stage:
+    quant_modifiers:
+        QuantizationModifier:
+            ignore: ["lm_head", "re:.*mlp.gate$"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 8
+                        type: int
+                        strategy: channel
+                        dynamic: false
+                        symmetric: true
+                    input_activations:
+                        num_bits: 8
+                        type: int
+                        strategy: token
+                        dynamic: true
+                        symmetric: true
+                    targets: ["re:.*self_attn.k_proj.*", "re:.*self_attn.o_proj.*",
+                              "re:.*self_attn.q_proj.*", "re:.*self_attn.v_proj.*"]
+                group_1:
+                    weights:
+                        num_bits: 4
+                        type: int
+                        strategy: group
+                        group_size: 128
+                        dynamic: false
+                        symmetric: true
+                    input_activations:
+                        num_bits: 8
+                        type: int
+                        strategy: token
+                        dynamic: true
+                        symmetric: true
+                    targets: ["re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"]
+"""
+
+# Apply quantization.
+oneshot(
+    model=model,
+    recipe=recipe,
+    trust_remote_code_model=True,
+)
+
+# Save to disk in compressed-tensors format.
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A8"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/tests/e2e/multicard/2-cards/test_quantization.py b/tests/e2e/multicard/2-cards/test_quantization.py
index b356ba3b..da45628b 100644
--- a/tests/e2e/multicard/2-cards/test_quantization.py
+++ b/tests/e2e/multicard/2-cards/test_quantization.py
@@ -64,3 +64,24 @@ def test_qwen3_moe_w8a8_dynamic_llm_compressor():
     for i in range(len(vllm_output)):
         assert golden_results[i] == vllm_output[i][1]
         print(f"Generated text: {vllm_output[i][1]!r}")
+
+def test_qwen3_moe_w4a8_dynamic_llm_compressor():
+    example_prompts = [
+        "The president of the United States is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+            "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w4a8",
+            tensor_parallel_size=2,
+            max_model_len=4096,
+            gpu_memory_utilization=0.8,
+    ) as vllm_model:
+        vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    golden_results = [
+        'The president of the United States is the head of state and',
+    ]
+
+    for i in range(len(vllm_output)):
+        assert golden_results[i] == vllm_output[i][1]
+        print(f"Generated text: {vllm_output[i][1]!r}")
diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py
index 30834e70..7896d1b4 100644
--- a/vllm_ascend/quantization/compressed_tensors_config.py
+++ b/vllm_ascend/quantization/compressed_tensors_config.py
@@ -187,8 +187,11 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
             AscendUnquantizedFusedMoEMethod
 
         layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
+        # The recipe targets per-expert linear names, so probe the scheme
+        # with the first expert's gate_proj instead of the fused MoE prefix.
+        layer_name = prefix + ".0.gate_proj"
         # Get the scheme for this layer
-        moe_scheme = self._get_moe_scheme(layer=layer, layer_name=prefix)
+        moe_scheme = self._get_moe_scheme(layer=layer, layer_name=layer_name)
 
         # Return unquantized method if no scheme found
         if moe_scheme is None:
@@ -382,6 +385,9 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
             return "W8A8_DYNAMIC"
 
+        if self._is_dynamic_token_w4a8(weight_quant, input_quant):
+            return "W4A8_DYNAMIC"
+
         if self._is_w4a16(weight_quant, input_quant):
             return "W4A16"
 
@@ -415,6 +421,37 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
         # Only symmetric input quantization supported.
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and is_symmetric and is_dynamic
+
+    def _is_dynamic_token_w4a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
+        is_4_bits = weight_quant.num_bits == 4
+        is_8_bits = input_quant.num_bits == 8
+        weight_strategy = weight_quant.strategy in (
+            QuantizationStrategy.CHANNEL.value,
+            QuantizationStrategy.GROUP.value)
+        is_token = (weight_strategy and input_quant.strategy
+                    == QuantizationStrategy.TOKEN.value)
+        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
+        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+
+        # Adapt for AscendW4A8DynamicFusedMoEMethod
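+        # NOTE: as a side effect, stash the group size and weight strategy
+        # in quant_description so that AscendW4A8DynamicFusedMoEMethod can
+        # pick them up when it builds the fused MoE weights.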
+        assert self.quant_description is not None, \
+            "quant_description should not be None"
+        if weight_strategy:
+            self.quant_description["group_size"] = (
+                weight_quant.group_size if weight_quant.group_size else 0)
+
+        self.quant_description["version"] = "0"
+        self.quant_description["ascend_quant_method"] = COMPRESSED_TENSORS_METHOD
+        self.quant_description["weight_strategy"] = str(weight_quant.strategy)
+
+        # Only symmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return (is_4_bits and is_8_bits and is_token and is_symmetric
+                and is_dynamic)
 
     def _is_w4a16(self, weight_quant: "QuantizationArgs",
                   input_quant: Optional["QuantizationArgs"]) -> bool:
diff --git a/vllm_ascend/quantization/methods/w4a8.py b/vllm_ascend/quantization/methods/w4a8.py
index 4510b44a..a5fc3afa 100644
--- a/vllm_ascend/quantization/methods/w4a8.py
+++ b/vllm_ascend/quantization/methods/w4a8.py
@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
-from vllm_ascend.utils import maybe_trans_nz
+from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
 
 from .base import AscendLinearScheme, AscendMoEScheme, QuantType
 from .registry import register_scheme
@@ -217,6 +217,13 @@
             "version", "0")  # NOTE: new quantize weights: 2 int4 pack into int8
         self.new_quant_version = quant_version == "1.0.0"
+
+        self.quant_method = vllm_config.quant_config.quant_description.get(
+            "ascend_quant_method", "")
+        if self.quant_method == COMPRESSED_TENSORS_METHOD:
+            self.weight_strategy = vllm_config.quant_config.quant_description.get(
+                "weight_strategy", "group")
+
         self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
         self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
         if self.new_quant_version and self.tp_size > 16:
@@ -236,6 +243,33 @@
     def get_weight(self, num_experts: int,
                    intermediate_size_per_partition: int, hidden_sizes: int,
                    params_dtype: torch.dtype) -> Dict[str, Any]:
+        if self.quant_method == COMPRESSED_TENSORS_METHOD:
+            return self.get_weight_compressed_tensors(
+                num_experts, intermediate_size_per_partition,
+                hidden_sizes, params_dtype)
+        else:
+            return self.get_weight_modelslim(
+                num_experts, intermediate_size_per_partition,
+                hidden_sizes, params_dtype)
+
+    def get_weight_compressed_tensors(
+            self, num_experts: int, intermediate_size_per_partition: int,
+            hidden_sizes: int, params_dtype: torch.dtype) -> Dict[str, Any]:
+        # Hold one int4 value per int8 element here; the values are packed
+        # to int32 in process_weights_after_loading.
+        param_dict = {}
+        E = num_experts
+        H = hidden_sizes
+        IN = intermediate_size_per_partition
+
+        param_dict["w13_weight"] = torch.empty(E, 2 * IN, H, dtype=torch.int8)
+        param_dict["w2_weight"] = torch.empty(E, H, IN, dtype=torch.int8)
+        return param_dict
+
+    def get_weight_modelslim(self, num_experts: int,
+                             intermediate_size_per_partition: int,
+                             hidden_sizes: int,
+                             params_dtype: torch.dtype) -> Dict[str, Any]:
         param_dict = {}
         if self.new_quant_version:
             w13_output_size = intermediate_size_per_partition
@@ -258,6 +292,45 @@
                                 intermediate_size_per_partition: int,
                                 hidden_sizes: int,
                                 params_dtype: torch.dtype) -> Dict[str, Any]:
+        if self.quant_method == COMPRESSED_TENSORS_METHOD:
+            return self.get_dynamic_quant_param_compressed_tensors(
+                num_experts, intermediate_size_per_partition,
+                hidden_sizes, params_dtype)
+        else:
+            return self.get_dynamic_quant_param_modelslim(
+                num_experts, intermediate_size_per_partition,
+                hidden_sizes, params_dtype)
+
+    def get_dynamic_quant_param_compressed_tensors(
+            self, num_experts: int, intermediate_size_per_partition: int,
+            hidden_sizes: int, params_dtype: torch.dtype) -> Dict[str, Any]:
+        param_dict = {}
+
+        E = num_experts
+        H = hidden_sizes
+        IN = intermediate_size_per_partition
+        g = self.group_size
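+        # One bfloat16 scale per group of g input values (one scale per
+        # channel when g == 0); converted to the int64 representation in
+        # process_weights_after_loading.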
+
+        # Number of scale columns per output row: one for channel-wise
+        # quantization, in_features // group_size for group-wise.
+        def _n_scale_cols(in_features: int) -> int:
+            return 1 if g <= 0 else (in_features // g)
+
+        param_dict["w13_weight_scale"] = torch.empty(
+            E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16)
+
+        param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN),
+                                                    dtype=torch.bfloat16)
+
+        return param_dict
+
+    def get_dynamic_quant_param_modelslim(
+            self, num_experts: int,
+            intermediate_size_per_partition: int, hidden_sizes: int,
+            params_dtype: torch.dtype) -> Dict[str, Any]:
         param_dict = {}
         param_dict["w13_weight_scale"] = torch.empty(
             num_experts,
@@ -374,8 +447,8 @@
             w2=[layer.w2_weight],
             w1_scale=[layer.w13_weight_scale],
             w2_scale=[layer.w2_weight_scale],
-            w1_scale_bias=layer.w13_scale_bias,
-            w2_scale_bias=layer.w2_scale_bias,
+            w1_scale_bias=getattr(layer, "w13_scale_bias", None),
+            w2_scale_bias=getattr(layer, "w2_scale_bias", None),
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             use_int4_w4a8=True,
@@ -445,6 +518,77 @@
             torch.quint4x2, -1, False)
 
     def process_weights_after_loading(self, layer):
+        if self.quant_method == COMPRESSED_TENSORS_METHOD:
+            self.process_weights_after_loading_compressed_tensors(layer)
+        else:
+            self.process_weights_after_loading_modelslim(layer)
+
+    def process_weights_after_loading_compressed_tensors(self, layer):
+        layer.w13_weight.data = layer.w13_weight.data.transpose(
+            1, 2).contiguous()
+        layer.w2_weight.data = layer.w2_weight.data.transpose(1,
+                                                              2).contiguous()
+
+        # Bit-cast the float32 scales to uint32 and widen to int64, the
+        # integer representation the downstream Ascend ops consume.
+        def process_scale_compressed_tensors(scale: torch.Tensor):
+            scale = scale.transpose(1, 2).to(torch.float32).contiguous()
+            scale_np = scale.cpu().numpy()
+            scale_np.dtype = np.uint32
+            scale_uint64_tensor = torch.from_numpy(scale_np.astype(
+                np.int64)).npu()
+            return scale_uint64_tensor
+
+        def update_bias_compressed_tensors(weight: torch.Tensor,
+                                           scale: torch.Tensor,
+                                           strategy: str):
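+            # Precompute the per-output-channel term 8 * sum(w * s); the
+            # fused MoE kernel consumes it as scale_bias. The factor 8
+            # appears to compensate for the +8 shift applied when the int4
+            # weights are packed.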
+            group_num, k, n = weight.shape
+            scale = scale.transpose(1, 2).contiguous()
+            scale = scale.reshape(group_num, -1, n)
+            group_num, quantgroup_num, n = scale.shape
+
+            bias = None
+            if strategy == "group":
+                tmp = weight.to(torch.float32).reshape(
+                    [group_num, quantgroup_num, -1, n]) * scale.reshape(
+                        [group_num, quantgroup_num, 1, n])
+                tmp = tmp.reshape([group_num, k, n])
+                bias = 8 * tmp.sum(dim=1)
+            elif strategy == "channel":
+                bias = 8 * (weight.to(torch.float32) * scale).sum(dim=1)
+            else:
+                raise ValueError(f"Unsupported weight strategy: {strategy}")
+            return bias
+
+        w13_bias = update_bias_compressed_tensors(layer.w13_weight.data,
+                                                  layer.w13_weight_scale.data,
+                                                  self.weight_strategy)
+        w2_bias = update_bias_compressed_tensors(layer.w2_weight.data,
+                                                 layer.w2_weight_scale.data,
+                                                 self.weight_strategy)
+
+        layer.w13_weight_scale.data = process_scale_compressed_tensors(
+            layer.w13_weight_scale.data)
+        layer.w2_weight_scale.data = process_scale_compressed_tensors(
+            layer.w2_weight_scale.data)
+
+        w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
+        layer.register_parameter("w13_scale_bias", w13_scale_bias)
+        w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
+        layer.register_parameter("w2_scale_bias", w2_scale_bias)
+
+        # NOTE: converting these weights to NZ format causes accuracy issues,
+        # so maybe_trans_nz is intentionally skipped here.
+        # layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
+        # layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
+        layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
+        layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
+
+    def process_weights_after_loading_modelslim(self, layer):
         layer.w13_weight.data = layer.w13_weight.data.transpose(
             1, 2).contiguous()
         layer.w2_weight.data = layer.w2_weight.data.transpose(1,