fix vl float model not support NZ format weight error (#3533)

### What this PR does / why we need it?
fix vl float model not support nz mm op
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
Co-authored-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
This commit is contained in:
shaopeng-666
2025-10-21 22:23:17 +08:00
committed by GitHub
parent 6f04b467de
commit 0c83eee9b1
3 changed files with 42 additions and 0 deletions

View File

@@ -42,6 +42,8 @@ from vllm.model_executor.models.qwen2_5_vl import (
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, is_enable_nz
MIN_PAD_SIZE = 64 # min_size to pad weight
MAX_PAD_SIZE = 128 # max_size to pad weight
@@ -281,6 +283,14 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
[qkv_weight_first_half_padded, qkv_weight_second_half_padded],
dim=2)
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
if is_enable_nz():
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
qkv_weight_final)
qkv_weight_final_copy = torch_npu.npu_format_cast(
qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND)
return qkv_weight_final_copy
return qkv_weight_final
def pad_proj_weight(self, data):
@@ -289,6 +299,13 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
self.half_origin_hidden_size_per_attention_head),
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
self.hidden_size, -1)
if is_enable_nz():
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
out_weight_copy = torch_npu.npu_format_cast(
out_weight_copy, ACL_FORMAT_FRACTAL_ND)
return out_weight_copy
return out_weight
def pad_qkv_weight_scale_offset(self, data):

View File

@@ -40,6 +40,8 @@ from vllm.model_executor.models.qwen2_vl import (
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, is_enable_nz
MIN_PAD_SIZE = 64 # min_size to pad weight
MAX_PAD_SIZE = 128 # max_size to pad weight
@@ -265,6 +267,14 @@ class AscendQwen2VisionTransformer(Qwen2VisionTransformer):
[qkv_weight_first_half_padded, qkv_weight_second_half_padded],
dim=2)
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
if is_enable_nz():
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
qkv_weight_final)
qkv_weight_final_copy = torch_npu.npu_format_cast(
qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND)
return qkv_weight_final_copy
return qkv_weight_final
def pad_proj_weight(self, data):
@@ -273,6 +283,13 @@ class AscendQwen2VisionTransformer(Qwen2VisionTransformer):
self.half_origin_hidden_size_per_attention_head),
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
self.hidden_size, -1)
if is_enable_nz():
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
out_weight_copy = torch_npu.npu_format_cast(
out_weight_copy, ACL_FORMAT_FRACTAL_ND)
return out_weight_copy
return out_weight
def load_weights(self, weights: Iterable[Tuple[str,