For nz unset in bf16&fp16 (#4495)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
disable NZ for float weight case. This is only a quick fix for dev
branch.

For main branch, we'll consider more case to make it more common.


### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->

### How was this patch tested?
qwen2.5 32B
<img width="441" height="221" alt="image"
src="https://github.com/user-attachments/assets/7ae18ffd-1ce2-43d9-9960-be45250ad0da"
/>

---------

Signed-off-by: 刘哲续 <liuzhexu1@huawei.com>
Co-authored-by: 刘哲续 <liuzhexu1@huawei.com>
This commit is contained in:
henryxuxu0716
2025-11-28 17:32:25 +08:00
committed by GitHub
parent 96c362361e
commit 71acc8ddeb
10 changed files with 16 additions and 14 deletions

View File

@@ -842,7 +842,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
if is_enable_nz():
if is_enable_nz(wd_qkv.dtype):
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
@@ -876,7 +876,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
-1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
if is_enable_nz():
if is_enable_nz(wu_q.dtype):
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
qb_deq_scl = self.q_proj.deq_scale.data.clone()

View File

@@ -146,7 +146,7 @@ def converting_weight_acl_format(model, format):
if torch_npu.get_npu_format(module.w13_weight.data) == format:
return
if format == ACL_FORMAT_FRACTAL_NZ \
and not is_enable_nz():
and not is_enable_nz(module.w13_weight.data.dtype):
return
module.w13_weight.data = torch_npu.npu_format_cast(
module.w13_weight.data, format)