fix vl float model not support NZ format weight error (#3533)
### What this PR does / why we need it?
Fixes an error where VL float models did not support NZ-format weights for the matmul op.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
---------
Signed-off-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
Co-authored-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
```diff
@@ -370,6 +370,10 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
         mocker.patch("torch.nn.Module.__setattr__")
         mocker.patch("torch.nn.Module.__getattr__")
         mocker.patch("torch.nn.Module.__delattr__")
+        mocker.patch(
+            "torch_npu.npu_format_cast",
+            return_value=torch.rand((384, 300)),
+        )
         res = attention.pad_qkv_weight(torch.rand((300, 300)))
         assert res.shape == (384, 300)
```
```diff
@@ -378,6 +382,10 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
         mocker.patch("torch.nn.Module.__setattr__")
         mocker.patch("torch.nn.Module.__getattr__")
         mocker.patch("torch.nn.Module.__delattr__")
+        mocker.patch(
+            "torch_npu.npu_format_cast",
+            return_value=torch.rand((300, 384)),
+        )
         res = attention.pad_proj_weight(torch.rand((300, 300)))
         assert res.shape == (300, 384)
```
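The tests expect `pad_qkv_weight` to take a `(300, 300)` weight to `(384, 300)` and `pad_proj_weight` to take it to `(300, 384)`, i.e. one dimension is rounded up to the next multiple of a hardware block size. A minimal sketch of that padding, assuming a block size of 128 (384 = 3 × 128); `pad_to_block` is a hypothetical helper for illustration, not the repo's actual implementation:

```python
import torch
import torch.nn.functional as F

# Assumed NZ block size; 300 rounds up to 384 under this assumption.
BLOCK = 128

def pad_to_block(weight: torch.Tensor, dim: int) -> torch.Tensor:
    """Zero-pad `weight` along `dim` up to the next multiple of BLOCK."""
    pad_len = (-weight.shape[dim]) % BLOCK
    if pad_len == 0:
        return weight
    # F.pad lists pads from the last dimension backwards: (left, right, ...).
    pad = [0, 0] * weight.dim()
    pad[2 * (weight.dim() - 1 - dim) + 1] = pad_len  # pad on the right side
    return F.pad(weight, pad)

w = torch.rand((300, 300))
print(pad_to_block(w, 0).shape)  # torch.Size([384, 300]), like pad_qkv_weight
print(pad_to_block(w, 1).shape)  # torch.Size([300, 384]), like pad_proj_weight
```

In the tests themselves this arithmetic never runs: `torch_npu.npu_format_cast` is mocked to return a tensor of the expected padded shape, so the assertions only check that the padding path is wired through the format cast.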