[Feat] Unquantized linear nz support (#2619)

### What this PR does / why we need it?
Currently, when executing to the Linear layer of the model in
vLLM-Ascend, the weights input format is ND in unquantized case and
skipped ascend case, which is slower than FRACTAL_NZ.
This PR supplements the execution logic for Linear layer. When
VLLM_ASCEND_ENABLE_MLP_OPTIMIZE=1 and CANN version is 8.3, the weights
of the Linear layer will be converted to FRACTAL_NZ, in both unquantized
case and skipped ascend case.

- vLLM version: main
- vLLM main:
267c80d31f

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
anon189Ty
2025-09-11 11:40:00 +08:00
committed by GitHub
parent 5691104249
commit 7b2ecc1e9a
4 changed files with 111 additions and 10 deletions

View File

@@ -36,12 +36,36 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
mlp_tp_enable, oproj_tp_enable)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dense_optim_enable,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable)
_HCOMM_INFO = None
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
"""Linear method without quantization."""
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
if torch.version.cann.startswith("8.3"):
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if torch.version.cann.startswith("8.3"):
if bias is None:
return torch.matmul(x, layer.weight)
else:
return torch.matmul(x, layer.weight) + bias
else:
return torch.nn.functional.linear(x, layer.weight, bias)
class AscendColumnParallelLinear(ColumnParallelLinear):
"""Linear layer with column parallelism.
@@ -617,7 +641,7 @@ class AscendLinearBase(LinearBase):
self.prefix = prefix
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
QuantizeMethodBase] = AscendUnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)

View File

@@ -23,8 +23,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
RowParallelLinear,
UnquantizedLinearMethod)
RowParallelLinear)
from vllm.model_executor.layers.quantization import \
register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
@@ -38,6 +37,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
oproj_tp_enable)
@@ -92,7 +92,7 @@ class AscendQuantConfig(QuantizationConfig):
if isinstance(layer, LinearBase):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedLinearMethod()
return AscendUnquantizedLinearMethod()
return AscendLinearMethod(self, prefix,
self.packed_modules_mapping)
elif isinstance(layer, Attention) and \