[300I][Bugfix] fix unquant model weight nd2nz error (#6851)
### What this PR does / why we need it?
- This PR fixes an issue with weight format conversion for unquantized
models running on Ascend 310P devices.
- The changes refactor the logic for converting weights to the
FRACTAL_NZ format. Previously, this was handled in a 310P-specific
linear layer implementation (`AscendUnquantizedLinearMethod310`). This
implementation has been removed, and the logic is now centralized in the
`maybe_trans_nz` utility function. This function now checks if the
device is a 310P and applies the NZ format cast accordingly for
`float16`/`bfloat16` weights.
- This refactoring simplifies the code by removing platform-specific
duplication and ensures correct weight handling for unquantized models
on 310P.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests and local testing.
- vLLM version: v0.15.0
- vLLM main:
83b47f67b1
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
@@ -134,22 +134,35 @@ def _unregister_print_streams_on_exit():
|
||||
atexit.register(_unregister_print_streams_on_exit)
|
||||
|
||||
|
||||
def maybe_trans_nz(weight: torch.Tensor):
|
||||
def _should_trans_nz(weight: torch.Tensor) -> bool:
    """Decide whether *weight* should be cast to the FRACTAL_NZ format.

    Policy:
      - FP32 never converts (NZ layout is unsupported for it).
      - On 310P devices every other dtype converts.
      - Elsewhere, conversion follows VLLM_ASCEND_ENABLE_NZ: a falsy value
        disables it entirely, BF16/FP16 require the value 2, and all
        remaining dtypes (e.g. quantized weights) convert whenever the
        switch is non-zero.
    """
    dtype = weight.dtype
    if dtype == torch.float32:
        # FP32 cannot use the NZ layout at all.
        return False
    if is_310p():
        # 310P unconditionally casts every supported dtype to NZ.
        return True
    enable_nz = envs_ascend.VLLM_ASCEND_ENABLE_NZ
    if not enable_nz:
        # NZ is switched off on non-310P devices.
        return False
    # Half-precision floats are gated behind the stricter level 2;
    # everything else converts by default once NZ is enabled.
    return enable_nz == 2 if dtype in (torch.bfloat16, torch.float16) else True
|
||||
|
||||
|
||||
# NZ conversion policy:
# - 310P: always convert supported weights to FRACTAL_NZ
# - non-310P: follow VLLM_ASCEND_ENABLE_NZ
# - FP32: never convert
def maybe_trans_nz(weight: torch.Tensor) -> torch.Tensor:
    """Cast *weight* to the ACL FRACTAL_NZ format when the policy allows it.

    The decision is delegated entirely to _should_trans_nz(); weights that
    must keep their current format are returned unchanged.

    NOTE: the previous body re-checked dtype and VLLM_ASCEND_ENABLE_NZ after
    the guard, which left the trailing cast unreachable and skipped the cast
    for bf16/fp16 on 310P unless VLLM_ASCEND_ENABLE_NZ == 2 — contradicting
    the "310P always converts" rule the guard implements.
    """
    if not _should_trans_nz(weight):
        return weight
    return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
|
||||
def _round_up(x: int, align: int):
|
||||
@@ -631,6 +644,10 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
|
||||
from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
|
||||
from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310
|
||||
from vllm_ascend._310p.ops.rotary_embedding import AscendRotaryEmbedding310
|
||||
from vllm_ascend._310p.ops.vocab_parallel_embedding import (
|
||||
AscendParallelLMHead310,
|
||||
AscendVocabParallelEmbedding310,
|
||||
)
|
||||
|
||||
REGISTERED_ASCEND_OPS.update(
|
||||
{
|
||||
@@ -640,6 +657,8 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
|
||||
"GemmaRMSNorm": AscendGemmaRMSNorm310,
|
||||
"FusedMoE": AscendFusedMoE310,
|
||||
"SharedFusedMoE": AscendSharedFusedMoE310,
|
||||
"ParallelLMHead": AscendParallelLMHead310,
|
||||
"VocabParallelEmbedding": AscendVocabParallelEmbedding310,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user