[Bugfix] Remove ModelSlim-"M4 Quantization". (#4589)
The M4 quantization method in ModelSlim adds bias to model weights that originally do not have a linear bias. PR #4235 supported PD-MIX quantization and M4 quantization, adding bias to `w8a8.py` and `w8a8_dynamic.py`, and implementing adaptations in `ops/linear.py` to prevent it from being reset to `None` by `self.register_parameter("bias", None)`. However, this modification introduced an issue where the bias was still being reset to `None` in certain scenarios, causing errors during service startup. Therefore, support for M4 quantization is temporarily being reverted in this PR. ___ - vLLM version: v0.11.2 Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
@@ -277,20 +277,18 @@ class AscendRowParallelLinear(RowParallelLinear):
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
||||
bias_initialized_by_quant = ("bias" in self._parameters
|
||||
and self._parameters["bias"] is not None)
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
if bias and not bias_initialized_by_quant:
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size, dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
elif not bias and not bias_initialized_by_quant:
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
if self.custom_op is not None:
|
||||
@@ -368,9 +366,7 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
||||
bias_initialized_by_quant = ("bias" in self._parameters
|
||||
and self._parameters["bias"] is not None)
|
||||
if bias and not bias_initialized_by_quant:
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
dtype=params_dtype))
|
||||
@@ -378,7 +374,7 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
elif not bias and not bias_initialized_by_quant:
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
if self.custom_op is not None:
|
||||
@@ -449,16 +445,14 @@ class AscendReplicatedLinear(ReplicatedLinear):
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
bias_initialized_by_quant = ("bias" in self._parameters
|
||||
and self._parameters["bias"] is not None)
|
||||
if bias and not bias_initialized_by_quant:
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size, dtype=self.params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
elif not bias and not bias_initialized_by_quant:
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
if self.custom_op is not None:
|
||||
|
||||
@@ -87,7 +87,6 @@ class AscendW8A8LinearMethod:
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["bias"] = torch.zeros(output_size, dtype=torch.float32)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self,
|
||||
@@ -199,13 +198,7 @@ class AscendW8A8LinearMethod:
|
||||
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||
layer.bias.data = layer.bias.data.to(layer.weight_scale.data.dtype)
|
||||
|
||||
try:
|
||||
ascend_quant_method = getattr(layer, "ascend_quant_method")
|
||||
except AttributeError:
|
||||
ascend_quant_method = ""
|
||||
|
||||
ascend_quant_method = getattr(layer, "ascend_quant_method", "")
|
||||
if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
|
||||
deq_scale = layer.input_scale.data * layer.weight_scale.data
|
||||
layer.deq_scale = torch.nn.Parameter(deq_scale,
|
||||
|
||||
@@ -60,7 +60,6 @@ class AscendW8A8DynamicLinearMethod:
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["bias"] = torch.zeros(output_size, dtype=torch.float32)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self,
|
||||
@@ -98,7 +97,6 @@ class AscendW8A8DynamicLinearMethod:
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
layer.bias.data = layer.bias.data.to(layer.weight_scale.data.dtype)
|
||||
|
||||
|
||||
class AscendW8A8DynamicFusedMoEMethod:
|
||||
|
||||
Reference in New Issue
Block a user