diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index a181f2cb..3b780268 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -387,7 +387,7 @@ class AscendFusedMoE(FusedMoE):
 
     def transpose_weight(self, loaded_weight, expert_data, shard_dim):
         # Ensure training and inference weight shapes match during RL weight updates
-        if (
+        if (len(loaded_weight.shape) >= 2 and len(expert_data.shape) >= 2 and \
                 loaded_weight.shape[1] != expert_data.shape[1] and \
                 loaded_weight.shape[0] != expert_data.shape[0]
         ):
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index 844cdcbd..3b6f7100 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -277,18 +277,20 @@ class AscendRowParallelLinear(RowParallelLinear):
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
                 in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
+        bias_initialized_by_quant = ("bias" in self._parameters
+                                     and self._parameters["bias"] is not None)
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
                              "results can lead to incorrect results")
 
-        if bias:
+        if bias and not bias_initialized_by_quant:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,
                 "weight_loader": self.weight_loader,
             })
-        else:
+        elif not bias and not bias_initialized_by_quant:
             self.register_parameter("bias", None)
 
         if self.custom_op is not None:
@@ -366,7 +368,9 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
                 in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
-        if bias:
+        bias_initialized_by_quant = ("bias" in self._parameters
+                                     and self._parameters["bias"] is not None)
+        if bias and not bias_initialized_by_quant:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
                             dtype=params_dtype))
@@ -374,7 +378,7 @@
                 "output_dim": 0,
                 "weight_loader": self.weight_loader,
             })
-        else:
+        elif not bias and not bias_initialized_by_quant:
             self.register_parameter("bias", None)
 
         if self.custom_op is not None:
@@ -445,14 +449,16 @@ class AscendReplicatedLinear(ReplicatedLinear):
                                          self.params_dtype,
                                          weight_loader=self.weight_loader)
 
-        if bias:
+        bias_initialized_by_quant = ("bias" in self._parameters
+                                     and self._parameters["bias"] is not None)
+        if bias and not bias_initialized_by_quant:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,
                 "weight_loader": self.weight_loader,
             })
-        else:
+        elif not bias and not bias_initialized_by_quant:
             self.register_parameter("bias", None)
 
         if self.custom_op is not None:
diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py
index eaaaee86..be43726e 100644
--- a/vllm_ascend/quantization/utils.py
+++ b/vllm_ascend/quantization/utils.py
@@ -12,6 +12,8 @@ from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
                    AscendW8A8LinearMethod)
 from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
                            AscendW8A8DynamicLinearMethod)
+from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
+                         AscendW8A8PDMixLinearMethod)
 
 ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
     "W4A8_DYNAMIC": {
@@ -30,6 +32,10 @@ ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
         "linear": AscendW8A8DynamicLinearMethod,
         "moe": AscendW8A8DynamicFusedMoEMethod,
     },
+    "W8A8_MIX": {
+        "linear": AscendW8A8PDMixLinearMethod,
+        "moe": AscendW8A8PDMixFusedMoeMethod,
+    },
     "C8": {
         "attention": AscendC8KVCacheMethod,
     },
diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py
index 8a7bbfe7..ceb42c53 100644
--- a/vllm_ascend/quantization/w8a8.py
+++ b/vllm_ascend/quantization/w8a8.py
@@ -87,6 +87,7 @@ class AscendW8A8LinearMethod:
         params_dict["weight_offset"] = torch.empty(output_size,
                                                    1,
                                                    dtype=params_dtype)
+        params_dict["bias"] = torch.zeros(output_size, dtype=torch.float32)
         return params_dict
 
     def get_pergroup_param(self,
@@ -192,6 +193,7 @@
                 layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+        layer.bias.data = layer.bias.data.to(layer.weight_scale.data.dtype)
         if getattr(layer, "ascend_quant_method",
                    "") == COMPRESSED_TENSORS_METHOD:
             deq_scale = layer.input_scale.data * layer.weight_scale.data
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 6b7d6b08..e64814be 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -60,6 +60,7 @@ class AscendW8A8DynamicLinearMethod:
         params_dict["weight_offset"] = torch.empty(output_size,
                                                    1,
                                                    dtype=params_dtype)
+        params_dict["bias"] = torch.zeros(output_size, dtype=torch.float32)
         return params_dict
 
     def get_pergroup_param(self,
@@ -110,6 +111,7 @@
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
+        layer.bias.data = layer.bias.data.to(layer.weight_scale.data.dtype)
 
 
 class AscendW8A8DynamicFusedMoEMethod:
diff --git a/vllm_ascend/quantization/w8a8_pdmix.py b/vllm_ascend/quantization/w8a8_pdmix.py
new file mode 100644
index 00000000..0fa74f7e
--- /dev/null
+++ b/vllm_ascend/quantization/w8a8_pdmix.py
@@ -0,0 +1,70 @@
+from typing import Any, Dict, cast
+
+import torch
+from vllm.config import get_current_vllm_config
+
+from .w8a8 import AscendW8A8LinearMethod
+from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
+                           AscendW8A8DynamicLinearMethod)
+
+
+class AscendW8A8PDMixLinearMethod(AscendW8A8DynamicLinearMethod):
+
+    def __init__(self):
+        self.kv_transfer_config = get_current_vllm_config().kv_transfer_config
+        super().__init__()
+
+    @staticmethod
+    def apply(layer, x, bias=None, tp_rank=0):
+        if layer.is_kv_consumer:
+            return AscendW8A8LinearMethod.apply(layer, x, bias, tp_rank)
+        else:
+            return AscendW8A8DynamicLinearMethod.apply(layer, x, bias, tp_rank)
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        return AscendW8A8LinearMethod.get_pertensor_param(params_dtype)
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        return AscendW8A8LinearMethod.get_perchannel_param(
+            output_size, params_dtype)
+
+    def process_weights_after_loading(self, layer):
+        AscendW8A8LinearMethod.process_weights_after_loading(
+            cast(AscendW8A8LinearMethod, self), layer)
+        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
+        layer.is_kv_consumer = self.kv_transfer_config is not None and self.kv_transfer_config.is_kv_consumer
+
+
+class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod):
+
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def get_dynamic_quant_param(num_experts: int,
+                                intermediate_size_per_partition: int,
+                                hidden_sizes: int,
+                                params_dtype: torch.dtype) -> Dict[str, Any]:
+        param_dict = AscendW8A8DynamicFusedMoEMethod.get_dynamic_quant_param(
+            num_experts, intermediate_size_per_partition, hidden_sizes,
+            params_dtype)
+        param_dict["w2_deq_scale"] = torch.empty(num_experts,
+                                                 hidden_sizes,
+                                                 dtype=torch.float32)
+        param_dict["w13_deq_scale"] = torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            dtype=torch.float32)
+        param_dict["w2_input_offset"] = torch.empty(num_experts,
+                                                    1,
+                                                    dtype=torch.int8)
+        param_dict["w13_input_offset"] = torch.empty(num_experts,
+                                                     1,
+                                                     dtype=torch.int8)
+
+        return param_dict
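
Note on the new "W8A8_MIX" path: AscendW8A8PDMixLinearMethod keeps a single set of W8A8 weights but selects the kernel per prefill/decode role. On KV-consumer instances (the decode side of a PD-disaggregated deployment) it routes to the static, per-tensor AscendW8A8LinearMethod.apply; everywhere else it routes to the per-token dynamic AscendW8A8DynamicLinearMethod.apply. The sketch below is a minimal, self-contained illustration of that dispatch; static_w8a8_matmul, dynamic_w8a8_matmul, pdmix_matmul, and the dict-based "layer" are hypothetical stand-ins written for this note, not the repo's Ascend kernels or API.

import torch


def static_w8a8_matmul(x, layer):
    # Static path (stand-in for the per-tensor AscendW8A8LinearMethod):
    # the activation scale was fixed at calibration time, so no runtime
    # reduction over the activations is needed.
    x_q = torch.clamp(torch.round(x / layer["input_scale"]), -128, 127)
    y = x_q @ layer["weight_q"].to(x_q.dtype).t()
    return y * (layer["input_scale"] * layer["weight_scale"])


def dynamic_w8a8_matmul(x, layer):
    # Dynamic path (stand-in for AscendW8A8DynamicLinearMethod): a fresh
    # per-token scale is derived from the live activation range.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127)
    y = x_q @ layer["weight_q"].to(x_q.dtype).t()
    return y * (scale * layer["weight_scale"])


def pdmix_matmul(x, layer, is_kv_consumer):
    # Mirrors AscendW8A8PDMixLinearMethod.apply: KV consumers (decode
    # instances) take the static path, all other instances the dynamic one.
    if is_kv_consumer:
        return static_w8a8_matmul(x, layer)
    return dynamic_w8a8_matmul(x, layer)


# Toy layer: int8 weights, per-channel weight scales, and a calibrated
# per-tensor input scale, evaluated under both PD roles.
layer = {
    "weight_q": torch.randint(-128, 128, (8, 16), dtype=torch.int8),
    "weight_scale": torch.full((8,), 0.02),
    "input_scale": torch.tensor(0.05),
}
x = torch.randn(4, 16)
y_decode = pdmix_matmul(x, layer, is_kv_consumer=True)    # static W8A8
y_prefill = pdmix_matmul(x, layer, is_kv_consumer=False)  # dynamic W8A8

The point the sketch makes concrete: the static path reuses a precomputed activation scale and skips the per-token reduction, while the dynamic path recomputes a scale from each token's range; both consume the same int8 weights and per-channel weight scales, which is what lets one checkpoint serve both roles of a PD deployment.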