From 71acc8ddebec5c4e22095c4fca0ee3b61c3af930 Mon Sep 17 00:00:00 2001
From: henryxuxu0716 <59076975+henryxuxu0716@users.noreply.github.com>
Date: Fri, 28 Nov 2025 17:32:25 +0800
Subject: [PATCH] For nz unset in bf16&fp16 (#4495)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
### What this PR does / why we need it?
disable NZ for float weight case. This is only a quick fix for dev
branch.
For main branch, we'll consider more case to make it more common.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
qwen2.5 32B
---------
Signed-off-by: 刘哲续
Co-authored-by: 刘哲续
---
vllm_ascend/attention/mla_v1.py | 2 +-
vllm_ascend/models/qwen2_5_vl.py | 4 ++--
vllm_ascend/models/qwen2_vl.py | 4 ++--
vllm_ascend/ops/common_fused_moe.py | 2 +-
vllm_ascend/ops/linear.py | 3 +--
vllm_ascend/torchair/torchair_sfa.py | 4 ++--
vllm_ascend/torchair/utils.py | 2 +-
vllm_ascend/utils.py | 5 ++++-
vllm_ascend/worker/model_runner_v1.py | 2 +-
vllm_ascend/worker/worker_v1.py | 2 +-
10 files changed, 16 insertions(+), 14 deletions(-)
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index c0e175c..5dcac6c 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -652,7 +652,7 @@ class AscendMLAImpl(MLAAttentionImpl):
# Function `get_and_maybe_dequant_weights` will cast the weights to
# FRACTAL_AND. So we need to cast to FRACTAL_NZ again.
- if is_enable_nz():
+ if is_enable_nz(self.kv_b_proj.weight.data.dtype):
self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py
index 35ac58d..ec39b96 100644
--- a/vllm_ascend/models/qwen2_5_vl.py
+++ b/vllm_ascend/models/qwen2_5_vl.py
@@ -284,7 +284,7 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
dim=2)
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
- if is_enable_nz():
+ if is_enable_nz(qkv_weight_final.dtype):
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
qkv_weight_final)
qkv_weight_final_copy = torch_npu.npu_format_cast(
@@ -300,7 +300,7 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
self.hidden_size, -1)
- if is_enable_nz():
+ if is_enable_nz(out_weight.dtype):
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
out_weight_copy = torch_npu.npu_format_cast(
out_weight_copy, ACL_FORMAT_FRACTAL_ND)
diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py
index ccd4616..bd48283 100644
--- a/vllm_ascend/models/qwen2_vl.py
+++ b/vllm_ascend/models/qwen2_vl.py
@@ -268,7 +268,7 @@ class AscendQwen2VisionTransformer(Qwen2VisionTransformer):
dim=2)
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
- if is_enable_nz():
+ if is_enable_nz(qkv_weight_final.dtype):
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
qkv_weight_final)
qkv_weight_final_copy = torch_npu.npu_format_cast(
@@ -284,7 +284,7 @@ class AscendQwen2VisionTransformer(Qwen2VisionTransformer):
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
self.hidden_size, -1)
- if is_enable_nz():
+ if is_enable_nz(out_weight.dtype):
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
out_weight_copy = torch_npu.npu_format_cast(
out_weight_copy, ACL_FORMAT_FRACTAL_ND)
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index cae6b89..cc2a377 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -76,7 +76,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
- if not is_310p() and is_enable_nz():
+ if not is_310p() and is_enable_nz(layer.w13_weight.data.dtype):
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index eab312d..69889b7 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -45,8 +45,7 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
- if (is_enable_nz() and layer.weight.data.dtype
- in [torch.float16, torch.bfloat16]):
+ if (is_enable_nz(layer.weight.data.dtype)):
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py
index 36c3224..751cd2f 100644
--- a/vllm_ascend/torchair/torchair_sfa.py
+++ b/vllm_ascend/torchair/torchair_sfa.py
@@ -842,7 +842,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
- if is_enable_nz():
+ if is_enable_nz(wd_qkv.dtype):
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
@@ -876,7 +876,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
-1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
- if is_enable_nz():
+ if is_enable_nz(wu_q.dtype):
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
qb_deq_scl = self.q_proj.deq_scale.data.clone()
diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py
index 1936703..7acabd8 100644
--- a/vllm_ascend/torchair/utils.py
+++ b/vllm_ascend/torchair/utils.py
@@ -146,7 +146,7 @@ def converting_weight_acl_format(model, format):
if torch_npu.get_npu_format(module.w13_weight.data) == format:
return
if format == ACL_FORMAT_FRACTAL_NZ \
- and not is_enable_nz():
+ and not is_enable_nz(module.w13_weight.data.dtype):
return
module.w13_weight.data = torch_npu.npu_format_cast(
module.w13_weight.data, format)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 55b5136..d19453e 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -71,13 +71,16 @@ def is_310p():
return _IS_310P
-def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool:
+def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8,
+ vllm_config: Optional[VllmConfig] = None) -> bool:
global _ENABLE_NZ
if _ENABLE_NZ is None:
if not vllm_config:
raise ValueError(
"vllm_config must be provided when _ENABLE_NZ is None")
_ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next"
+ if dtype in [torch.float16, torch.bfloat16]:
+ return False
return _ENABLE_NZ
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ebbd608..c336c1c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2676,7 +2676,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _convert_torch_format(self, tensor):
if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
- and not is_enable_nz():
+ and not is_enable_nz(tensor.dtype):
return tensor
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
return tensor
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index b6b6008..0816f26 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -81,7 +81,7 @@ class NPUWorker(WorkerBase):
# register patch for vllm
from vllm_ascend.utils import adapt_patch
adapt_patch()
- is_enable_nz(vllm_config)
+ is_enable_nz(vllm_config=vllm_config)
# Register ops when worker init.
from vllm_ascend import ops
ops.register_dummy_fusion_op()