[Bugfix] fix tensor not same device in qwen2_5_vl_without_padding (#2051)

bugfix cherry-pick from v0.9.1-dev
https://github.com/vllm-project/vllm-ascend/pull/2007
### What this PR does / why we need it?
Minimal code to reproduce the issue:
```python
# test.py
from vllm import LLM, SamplingParams
 
prompts = [
    "Hello, my name is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="Qwen2.5-VL-7B-Instruct", max_model_len=26240)
 
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    
```
```bash
export USE_OPTIMIZED_MODEL=0
python test.py
```
The following exception is raised:
```
[rank0]:   File "/home/xxx/vllm_ascend/models/qwen2_5_vl_without_padding.py", line 84, in forward
[rank0]:     q = torch_npu.npu_rotary_mul(q, cos, sin)
[rank0]:   File "/home/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, npu:0 and cpu! (when checking argument for argument r1 in method wrapper__npu_rotary_mul)
```

In `AscendQwen2_5_VisionAttention_Without_Padding`, the call
`torch_npu.npu_rotary_mul(q, cos, sin)` receives `cos`/`sin` on the CPU while
`q` is on the NPU, which raises the device-mismatch error shown above.

`qwen2_5_vl_without_padding.py` needs this bugfix because
`AscendQwen2_5_VisionTransformer_Without_Padding.rot_pos_emb` in
`qwen2_5_vl_without_padding.py` comes from vLLM, where `inv_freq` is created
on the CPU:

40d86ee412/vllm/model_executor/models/qwen2_5_vl.py (L482)
```python
inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float, device='cpu') / dim))
```
`qwen2_5_vl.py` does not need this fix because
`AscendQwen2_5_VisionRotaryEmbedding` in `qwen2_5_vl.py` reimplements the
rotary embedding, so its `inv_freq` is created on the device (no
`device='cpu'` argument is passed):
```python
inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
```

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
CI passed with the newly added and existing tests.


- vLLM version: v0.10.0
- vLLM main:
18cc33dd60

Signed-off-by: pjgao <gaopengju3@huawei.com>
Co-authored-by: pjgao <gaopengju3@huawei.com>
This commit is contained in:
Joey Gao
2025-07-31 15:18:54 +08:00
committed by GitHub
parent 72eceff94d
commit 6192bc95c0
2 changed files with 12 additions and 1 deletions

View File

@@ -231,6 +231,8 @@ class TestAscendQwen2_5_VisionTransformer_Without_Padding(PytestBase):
vision_config.in_channels = 3
vision_config.hidden_act = "gelu"
vision_config.depth = 0
vision_config.hidden_size = 1280
vision_config.num_heads = 16
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
@@ -239,6 +241,10 @@ class TestAscendQwen2_5_VisionTransformer_Without_Padding(PytestBase):
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer.__init__",
return_value=None,
)
mocker_vision_rotary_embedding = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionRotaryEmbedding.__init__",
return_value=None,
)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionBlock_Without_Padding.__init__",
return_value=None,
@@ -264,7 +270,7 @@ class TestAscendQwen2_5_VisionTransformer_Without_Padding(PytestBase):
args, kwargs = mocker_vit.call_args
assert args == (vision_config, norm_eps, None, "")
assert not kwargs
mocker_vision_rotary_embedding.assert_called_once()
return vision_transformer
def test_init_vision_transformer(self, mocker: MockerFixture):

View File

@@ -41,6 +41,8 @@ from vllm.model_executor.models.qwen2_5_vl import (
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
@@ -160,6 +162,9 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
super().__init__(vision_config, norm_eps, quant_config, prefix)
norm_layer = partial(RMSNorm, eps=norm_eps)
self.interleaved = interleaved
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
2)
self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding(
patch_size=vision_config.patch_size,
temporal_patch_size=vision_config.temporal_patch_size,