[refactor] refactor weight trans nz and transpose (#4878)
### What this PR does / why we need it?
Now `VLLM_ASCEND_ENABLE_NZ` will have three options:
0: disable nz;
1: only quant case enable nz;
2: enable nz as long as possible;
And `VLLM_ASCEND_ENABLE_NZ`=1 by default.
All cases are shown in the table below:
| | W4A4 | W4A8 | W8A8 | fp16/bf16 | fp32 |
|---|---|---|---|---|---|
| trans nz | can't support nz | trans nz by default | trans nz by
default | trans nz when VLLM_ASCEND_ENABLE_NZ is 2 | can't support nz |
| transpose | only support not transpose case | only support transpose
case | only support transpose case | linear: only support not transpose
case<br>gmm: only support transpose case | same to fp16/bf16 |
Some exceptional cases:
1. MLAPO op need to do some additional processing on the weights,
including trans nz. If use MLAPO op, some weight will be transformed to
nz forcely;
2. MLA/SFA's weight `W_UV` will be used by op
`torch.ops._C_ascend.batch_matmul_transpose`, and this op can't support
nz currently;
### Does this PR introduce _any_ user-facing change?
Now fp16/bf16 weight will not trans nz by default.
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear,
|
||||
from vllm.model_executor.layers.linear import (ReplicatedLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
@@ -29,9 +29,8 @@ from vllm_ascend.ops.shared_weight_layer import (
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
_round_up, dispose_layer, enable_sp,
|
||||
is_enable_nz, replace_layer)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
|
||||
enable_sp, maybe_trans_nz, replace_layer)
|
||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -404,40 +403,11 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self.cp_size = 1
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
try:
|
||||
return getattr(layer, attr)
|
||||
except AttributeError:
|
||||
pass
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has no recognized weight attribute:"
|
||||
f" {WEIGHT_NAMES}.")
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(layer.input_size_per_partition,
|
||||
dtype=act_dtype,
|
||||
device=get_layer_weight(layer).device)
|
||||
dequant_weights = layer.quant_method.apply(layer,
|
||||
eye,
|
||||
bias=None)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
# Weight will be reshaped next. To be on the safe side, the format
|
||||
# of the weight should be reverted to FRACTAL_AND.
|
||||
layer.weight.data = torch_npu.npu_format_cast(
|
||||
layer.weight.data, ACL_FORMAT_FRACTAL_ND)
|
||||
return layer.weight
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
# NOTE: We currently do not support quant kv_b_proj.
|
||||
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
|
||||
# NOTE: Weight will be reshaped next, we need to revert and transpose it.
|
||||
kv_b_proj_weight = torch_npu.npu_format_cast(
|
||||
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank, self.local_num_heads *
|
||||
(self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
@@ -460,15 +430,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
|
||||
|
||||
# Function `get_and_maybe_dequant_weights` will cast the weights to
|
||||
# FRACTAL_AND. So we need to cast to FRACTAL_NZ again.
|
||||
if is_enable_nz():
|
||||
self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
|
||||
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
# TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
|
||||
# self.W_UV = maybe_trans_nz(self.W_UV)
|
||||
|
||||
# Waiting for BMM NZ support
|
||||
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
|
||||
dispose_layer(self.kv_b_proj)
|
||||
|
||||
@@ -502,6 +466,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
logger.warning_once(msg)
|
||||
else:
|
||||
self._process_weights_for_fused_mlapo(act_dtype)
|
||||
if not self.enable_mlapo:
|
||||
# if mlapo, W_UK_T can't trans nz
|
||||
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
forward_context = get_forward_context()
|
||||
|
||||
Reference in New Issue
Block a user