[refactor] refactor weight trans nz and transpose (#4878)
### What this PR does / why we need it?
`VLLM_ASCEND_ENABLE_NZ` now has three options:
- `0`: disable NZ;
- `1`: enable NZ only for quantized weights;
- `2`: enable NZ whenever possible.
The default is `VLLM_ASCEND_ENABLE_NZ=1`; the gating logic is sketched below.
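For illustration, the switch can be read as a simple three-way gate. The helper below is only a sketch of the semantics described above, not the actual `vllm_ascend` code; the function name `nz_enabled_for` and its argument are hypothetical:

```python
import os

def nz_enabled_for(is_quantized_weight: bool) -> bool:
    """Hypothetical sketch: interpret VLLM_ASCEND_ENABLE_NZ (0/1/2)."""
    level = int(os.environ.get("VLLM_ASCEND_ENABLE_NZ", "1"))  # default: 1
    if level == 0:
        return False                # 0: NZ disabled everywhere
    if level == 1:
        return is_quantized_weight  # 1: NZ only for quantized weights
    return True                     # 2: NZ whenever the op supports it
```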
All cases are shown in the table below:
|  | W4A4 | W4A8 | W8A8 | fp16/bf16 | fp32 |
|---|---|---|---|---|---|
| trans NZ | NZ not supported | trans NZ by default | trans NZ by default | trans NZ only when `VLLM_ASCEND_ENABLE_NZ=2` | NZ not supported |
| transpose | only the non-transposed case is supported | only the transposed case is supported | only the transposed case is supported | linear: only the non-transposed case<br>GMM: only the transposed case | same as fp16/bf16 |
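The diff replaces the repeated `is_enable_nz()` / `npu_format_cast` pattern with a single helper, `maybe_trans_nz`. Its body is not shown in this excerpt; a minimal sketch, assuming it simply wraps the old pattern from `vllm_ascend.utils`:

```python
import torch
import torch_npu

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz

def maybe_trans_nz(tensor: torch.Tensor) -> torch.Tensor:
    """Sketch only: cast a weight tensor to FRACTAL_NZ when NZ is enabled.

    The real helper may take extra arguments (e.g. quant vs. float weights)
    to implement the per-dtype rules in the table above.
    """
    if is_enable_nz():
        return torch_npu.npu_format_cast(tensor, ACL_FORMAT_FRACTAL_NZ)
    return tensor
```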
Some exceptional cases:
1. The MLAPO op needs to do some additional processing on the weights, including trans NZ. When the MLAPO op is used, some weights are therefore forcibly converted to NZ.
2. MLA/SFA's weight `W_UV` is consumed by the op `torch.ops._C_ascend.batch_matmul_transpose`, which currently cannot handle NZ.
### Does this PR introduce _any_ user-facing change?
Yes. fp16/bf16 weights are no longer converted to NZ by default; setting `VLLM_ASCEND_ENABLE_NZ=2` restores the previous behavior.
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: zzzzwwjj <1183291235@qq.com>
```diff
@@ -86,7 +86,6 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
     input_size = 0
 
     def __init__(self):
-        self.transpose_weight = False
         self.sym = True
 
     @staticmethod
@@ -176,9 +175,8 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
         return output
 
     def process_weights_after_loading(self, layer):
+        # NOTE: Currently, w4a4 can't support weight nz
         weight_packed = pack_int4_weights(layer.weight.data)
-        if self.transpose_weight:
-            weight_packed = weight_packed.transpose(0, 1).contiguous()
         layer.register_parameter(
             'weight_packed',
             torch.nn.Parameter(weight_packed, requires_grad=False))
@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
+from vllm_ascend.utils import maybe_trans_nz
 
 
 class AscendW4A8DynamicLinearMethod:
@@ -35,8 +35,6 @@ class AscendW4A8DynamicLinearMethod:
     """
 
     def __init__(self):
-        self.transpose_weight = True
-
         vllm_config = get_current_vllm_config()
         self.group_size = vllm_config.quant_config.quant_description.get(
             "group_size", 256)
@@ -170,8 +168,8 @@ class AscendW4A8DynamicLinearMethod:
         )
 
     def process_weights_after_loading(self, layer: torch.nn.Module):
-        if self.transpose_weight:
-            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight.data = maybe_trans_nz(layer.weight.data)
         layer.weight_scale.data = layer.weight_scale.data.flatten().to(
             torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -214,8 +212,6 @@ class AscendW4A8DynamicFusedMoEMethod:
     """
 
     def __init__(self):
-        self.transpose_weight = True
-
         self.ep_group = get_ep_group()
 
         vllm_config = get_current_vllm_config()
@@ -462,11 +458,10 @@ class AscendW4A8DynamicFusedMoEMethod:
             torch.quint4x2, -1, False)
 
     def process_weights_after_loading(self, layer):
-        if self.transpose_weight:
-            layer.w13_weight.data = layer.w13_weight.data.transpose(
-                1, 2).contiguous()
-            layer.w2_weight.data = layer.w2_weight.data.transpose(
-                1, 2).contiguous()
+        layer.w13_weight.data = layer.w13_weight.data.transpose(
+            1, 2).contiguous()
+        layer.w2_weight.data = layer.w2_weight.data.transpose(1,
+                                                              2).contiguous()
 
         w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
             layer, "w13_weight_scale_second") else None
@@ -487,10 +482,7 @@ class AscendW4A8DynamicFusedMoEMethod:
 
         self.update_bias(layer, w13_bias, w2_bias)
 
-        if is_enable_nz():
-            layer.w13_weight.data = torch_npu.npu_format_cast(
-                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
-            layer.w2_weight.data = torch_npu.npu_format_cast(
-                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
+        layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
+        layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
         layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
         layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
@@ -21,9 +21,8 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
-                               COMPRESSED_TENSORS_METHOD, AscendDeviceType,
-                               get_ascend_device_type, is_enable_nz)
+from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
+                               get_ascend_device_type, maybe_trans_nz)
 
 
 def quant_per_tensor(in_tensor: torch.Tensor,
@@ -42,9 +41,7 @@ class AscendW8A8LinearMethod:
     """
 
     def __init__(self) -> None:
-        # aclnn quant matmul requires to transpose matrix B, set to true by default.
-        self.transpose_weight = get_ascend_device_type(
-        ) != AscendDeviceType._310P
+        pass
 
     @staticmethod
     def get_weight(
@@ -189,11 +186,9 @@ class AscendW8A8LinearMethod:
         layer.aclnn_input_offset = torch.nn.Parameter(
             layer.input_offset.data.repeat(expanding_factor),
             requires_grad=False).to(layer.aclnn_input_scale.dtype)
-        if self.transpose_weight:
+        if get_ascend_device_type() != AscendDeviceType._310P:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
-        if is_enable_nz():
-            layer.weight.data = torch_npu.npu_format_cast(
-                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
+        layer.weight.data = maybe_trans_nz(layer.weight.data)
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
         ascend_quant_method = getattr(layer, "ascend_quant_method", "")
@@ -29,7 +29,7 @@ from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.flash_common3_context import get_flash_common3_context
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
+from vllm_ascend.utils import maybe_trans_nz
 
 
 class AscendW8A8DynamicLinearMethod:
@@ -37,7 +37,7 @@ class AscendW8A8DynamicLinearMethod:
     """
 
     def __init__(self):
-        self.transpose_weight = True
+        pass
 
     @staticmethod
     def get_weight(input_size: int, output_size: int,
@@ -91,12 +91,9 @@ class AscendW8A8DynamicLinearMethod:
         return output
 
     def process_weights_after_loading(self, layer):
-        if self.transpose_weight:
-            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         # cast quantized weight tensors in NZ format for higher inference speed
-        if is_enable_nz():
-            layer.weight.data = torch_npu.npu_format_cast(
-                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
+        layer.weight.data = maybe_trans_nz(layer.weight.data)
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -107,8 +104,6 @@ class AscendW8A8DynamicFusedMoEMethod:
     """
 
    def __init__(self):
-        self.transpose_weight = True
-
         self.ep_group = get_ep_group()
 
         vllm_config = get_current_vllm_config()
@@ -270,14 +265,12 @@ class AscendW8A8DynamicFusedMoEMethod:
             mc2_mask=kwargs.get("mc2_mask", None))
 
     def process_weights_after_loading(self, layer):
-        if self.transpose_weight:
-            layer.w13_weight.data = layer.w13_weight.data.transpose(
-                1, 2).contiguous()
-            layer.w2_weight.data = layer.w2_weight.data.transpose(
-                1, 2).contiguous()
-        if is_enable_nz():
-            torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
-            torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
+        layer.w13_weight.data = layer.w13_weight.data.transpose(
+            1, 2).contiguous()
+        layer.w2_weight.data = layer.w2_weight.data.transpose(1,
+                                                              2).contiguous()
+        layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
+        layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
         layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
```