[refactor] refactor weight trans nz and transpose (#4878)

### What this PR does / why we need it?

Now `VLLM_ASCEND_ENABLE_NZ` will have three options:
0: disable nz;
1: only quant case enable nz;
2: enable nz as long as possible;

And `VLLM_ASCEND_ENABLE_NZ`=1 by default.

All cases are shown in the table below:

|  | W4A4 | W4A8 | W8A8 | fp16/bf16 | fp32 |
|---|---|---|---|---|---|
| trans nz | can't support nz | trans nz by default | trans nz by
default | trans nz when VLLM_ASCEND_ENABLE_NZ is 2 | can't support nz |
| transpose | only support not transpose case | only support transpose
case | only support transpose case | linear: only support not transpose
case<br>gmm: only support transpose case | same to fp16/bf16 |

Some exceptional cases:
1. MLAPO op need to do some additional processing on the weights,
including trans nz. If use MLAPO op, some weight will be transformed to
nz forcely;
2. MLA/SFA's weight `W_UV` will be used by op
`torch.ops._C_ascend.batch_matmul_transpose`, and this op can't support
nz currently;

### Does this PR introduce _any_ user-facing change?
Now fp16/bf16 weight will not trans nz by default.

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-12-19 14:27:24 +08:00
committed by GitHub
parent ea8f544ce7
commit cc23067f1e
19 changed files with 156 additions and 255 deletions

View File

@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
from vllm_ascend.utils import maybe_trans_nz
class AscendW4A8DynamicLinearMethod:
@@ -35,8 +35,6 @@ class AscendW4A8DynamicLinearMethod:
"""
def __init__(self):
self.transpose_weight = True
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256)
@@ -170,8 +168,8 @@ class AscendW4A8DynamicLinearMethod:
)
def process_weights_after_loading(self, layer: torch.nn.Module):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = maybe_trans_nz(layer.weight.data)
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -214,8 +212,6 @@ class AscendW4A8DynamicFusedMoEMethod:
"""
def __init__(self):
self.transpose_weight = True
self.ep_group = get_ep_group()
vllm_config = get_current_vllm_config()
@@ -462,11 +458,10 @@ class AscendW4A8DynamicFusedMoEMethod:
torch.quint4x2, -1, False)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
layer, "w13_weight_scale_second") else None
@@ -487,10 +482,7 @@ class AscendW4A8DynamicFusedMoEMethod:
self.update_bias(layer, w13_bias, w2_bias)
if is_enable_nz():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)