From 71acc8ddebec5c4e22095c4fca0ee3b61c3af930 Mon Sep 17 00:00:00 2001 From: henryxuxu0716 <59076975+henryxuxu0716@users.noreply.github.com> Date: Fri, 28 Nov 2025 17:32:25 +0800 Subject: [PATCH] For nz unset in bf16&fp16 (#4495) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? disable NZ for float weight case. This is only a quick fix for dev branch. For main branch, we'll consider more case to make it more common. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? qwen2.5 32B image --------- Signed-off-by: 刘哲续 Co-authored-by: 刘哲续 --- vllm_ascend/attention/mla_v1.py | 2 +- vllm_ascend/models/qwen2_5_vl.py | 4 ++-- vllm_ascend/models/qwen2_vl.py | 4 ++-- vllm_ascend/ops/common_fused_moe.py | 2 +- vllm_ascend/ops/linear.py | 3 +-- vllm_ascend/torchair/torchair_sfa.py | 4 ++-- vllm_ascend/torchair/utils.py | 2 +- vllm_ascend/utils.py | 5 ++++- vllm_ascend/worker/model_runner_v1.py | 2 +- vllm_ascend/worker/worker_v1.py | 2 +- 10 files changed, 16 insertions(+), 14 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index c0e175c..5dcac6c 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -652,7 +652,7 @@ class AscendMLAImpl(MLAAttentionImpl): # Function `get_and_maybe_dequant_weights` will cast the weights to # FRACTAL_AND. So we need to cast to FRACTAL_NZ again. - if is_enable_nz(): + if is_enable_nz(self.kv_b_proj.weight.data.dtype): self.kv_b_proj.weight.data = torch_npu.npu_format_cast( self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ) diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py index 35ac58d..ec39b96 100644 --- a/vllm_ascend/models/qwen2_5_vl.py +++ b/vllm_ascend/models/qwen2_5_vl.py @@ -284,7 +284,7 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer): dim=2) qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size) - if is_enable_nz(): + if is_enable_nz(qkv_weight_final.dtype): qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_( qkv_weight_final) qkv_weight_final_copy = torch_npu.npu_format_cast( @@ -300,7 +300,7 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer): (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape( self.hidden_size, -1) - if is_enable_nz(): + if is_enable_nz(out_weight.dtype): out_weight_copy = torch.empty_like(out_weight).copy_(out_weight) out_weight_copy = torch_npu.npu_format_cast( out_weight_copy, ACL_FORMAT_FRACTAL_ND) diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py index ccd4616..bd48283 100644 --- a/vllm_ascend/models/qwen2_vl.py +++ b/vllm_ascend/models/qwen2_vl.py @@ -268,7 +268,7 @@ class AscendQwen2VisionTransformer(Qwen2VisionTransformer): dim=2) qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size) - if is_enable_nz(): + if is_enable_nz(qkv_weight_final.dtype): qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_( qkv_weight_final) qkv_weight_final_copy = torch_npu.npu_format_cast( @@ -284,7 +284,7 @@ class AscendQwen2VisionTransformer(Qwen2VisionTransformer): (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape( self.hidden_size, -1) - if is_enable_nz(): + if is_enable_nz(out_weight.dtype): out_weight_copy = torch.empty_like(out_weight).copy_(out_weight) out_weight_copy = torch_npu.npu_format_cast( out_weight_copy, ACL_FORMAT_FRACTAL_ND) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index cae6b89..cc2a377 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -76,7 +76,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): w2_data = self._maybe_pad_weight(layer.w2_weight.data) layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) - if not is_310p() and is_enable_nz(): + if not is_310p() and is_enable_nz(layer.w13_weight.data.dtype): layer.w13_weight.data = torch_npu.npu_format_cast( layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w2_weight.data = torch_npu.npu_format_cast( diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index eab312d..69889b7 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -45,8 +45,7 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - if (is_enable_nz() and layer.weight.data.dtype - in [torch.float16, torch.bfloat16]): + if (is_enable_nz(layer.weight.data.dtype)): layer.weight.data = torch_npu.npu_format_cast( layer.weight.data, ACL_FORMAT_FRACTAL_NZ) diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py index 36c3224..751cd2f 100644 --- a/vllm_ascend/torchair/torchair_sfa.py +++ b/vllm_ascend/torchair/torchair_sfa.py @@ -842,7 +842,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl): wd_qkv = wd_qkv.t().contiguous() wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous() - if is_enable_nz(): + if is_enable_nz(wd_qkv.dtype): self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone() @@ -876,7 +876,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl): self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1) wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() - if is_enable_nz(): + if is_enable_nz(wu_q.dtype): self.wu_q = torch_npu.npu_format_cast(wu_q, 29) qb_deq_scl = self.q_proj.deq_scale.data.clone() diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 1936703..7acabd8 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -146,7 +146,7 @@ def converting_weight_acl_format(model, format): if torch_npu.get_npu_format(module.w13_weight.data) == format: return if format == ACL_FORMAT_FRACTAL_NZ \ - and not is_enable_nz(): + and not is_enable_nz(module.w13_weight.data.dtype): return module.w13_weight.data = torch_npu.npu_format_cast( module.w13_weight.data, format) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 55b5136..d19453e 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -71,13 +71,16 @@ def is_310p(): return _IS_310P -def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool: +def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8, + vllm_config: Optional[VllmConfig] = None) -> bool: global _ENABLE_NZ if _ENABLE_NZ is None: if not vllm_config: raise ValueError( "vllm_config must be provided when _ENABLE_NZ is None") _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next" + if dtype in [torch.float16, torch.bfloat16]: + return False return _ENABLE_NZ diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ebbd608..c336c1c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2676,7 +2676,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): def _convert_torch_format(self, tensor): if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \ - and not is_enable_nz(): + and not is_enable_nz(tensor.dtype): return tensor tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) return tensor diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index b6b6008..0816f26 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -81,7 +81,7 @@ class NPUWorker(WorkerBase): # register patch for vllm from vllm_ascend.utils import adapt_patch adapt_patch() - is_enable_nz(vllm_config) + is_enable_nz(vllm_config=vllm_config) # Register ops when worker init. from vllm_ascend import ops ops.register_dummy_fusion_op()