diff --git a/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py b/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py
index 204cf95e..455b57fa 100644
--- a/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py
+++ b/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py
@@ -35,20 +35,6 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
 
     def __init__(self):
         self.transpose_weight = True
-        self.rotation_type = None
-
-    def set_rotation_config(self, prefix: str, metadata: dict) -> str | None:
-        """Set rotation config based on prefix and metadata."""
-        layer_idx = prefix.split(".")[2]
-        if prefix.endswith("o_proj"):
-            layers = metadata["quarot"]["heads_rotation"]["layers"]
-            if layer_idx in layers:
-                return "heads_rotation"
-        if prefix.endswith("down_proj"):
-            layers = metadata["quarot"]["kronecker_rotation"]["layers"]
-            if layer_idx in layers:
-                return "kronecker_rotation"
-        return None
 
     def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
         params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
@@ -58,32 +44,8 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
         params_dict = {}
         params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=torch.float32)
         params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=torch.float32)
-        if self.rotation_type == "heads_rotation":
-            params_dict["heads_rotation"] = torch.zeros((64, 64), dtype=torch.float32)
-        if self.rotation_type == "kronecker_rotation":
-            params_dict["kronecker_rotation_n"] = torch.zeros((160, 160), dtype=torch.float32)
-            params_dict["kronecker_rotation_m"] = torch.zeros((160, 160), dtype=torch.float32)
         return params_dict
 
-    def apply_rotation(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
-        """Apply rotation transformation to input tensor."""
-        init_shape = x.shape
-        dtype = x.dtype
-        if self.rotation_type == "heads_rotation":
-            Q1 = layer.heads_rotation
-            scaled_x = x.reshape(-1, Q1.shape[1], 128)
-            scaled_x = torch.matmul(Q1.T, scaled_x).reshape(init_shape)
-            return scaled_x.to(dtype)
-        if self.rotation_type == "kronecker_rotation":
-            Q1 = layer.kronecker_rotation_m
-            Q2 = layer.kronecker_rotation_n
-            scaled_x = x.reshape(-1, Q1.shape[0], Q2.shape[0])
-            scaled_x = torch.matmul(scaled_x, Q2)
-            scaled_x = torch.matmul(Q1.T, scaled_x)
-            scaled_x = scaled_x.reshape(init_shape)
-            return scaled_x.to(dtype)
-        return x
-
     def apply(
         self,
         layer: torch.nn.Module,
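
For reviewers, a minimal standalone sketch (not part of the patch) of what the removed `kronecker_rotation` branch computed: each activation block is reshaped to `(m, n)` and rotated as `Q1.T @ X @ Q2`, which is equivalent to multiplying the flattened block by `torch.kron(Q1, Q2)`. The small sizes and random orthogonal factors below are stand-ins for the `(160, 160)` matrices the patch removes:

```python
import torch

# Removed "kronecker_rotation" transform: reshape each block to (m, n)
# and rotate as Q1.T @ X @ Q2. With row-major flattening this equals
# right-multiplying the flat block by kron(Q1, Q2).
m, n = 4, 3  # small stand-ins for the (160, 160) factors in the patch
Q1 = torch.linalg.qr(torch.randn(m, m)).Q  # random orthogonal factors
Q2 = torch.linalg.qr(torch.randn(n, n)).Q
x = torch.randn(2, m * n)

# Reshape-and-matmul form, as in the removed apply_rotation.
rotated = (Q1.T @ x.reshape(-1, m, n) @ Q2).reshape(-1, m * n)

# Explicit Kronecker-product form for comparison.
expected = x @ torch.kron(Q1, Q2)

torch.testing.assert_close(rotated, expected)
```

The removed `heads_rotation` branch was the same idea restricted to one factor: a `(64, 64)` orthogonal matrix applied across 64 groups of width 128, i.e. `Q1.T @ x.reshape(-1, 64, 128)`.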