fix qwen3vl mrope op (#4484)
### What this PR does / why we need it?
The Qwen2.5-VL mrope precision problem will be solved once this PR is merged.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tested on G8600 with the textVQA dataset.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: 李少鹏 <lishaopeng21@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -24,7 +24,6 @@ from vllm.forward_context import get_forward_context
|
|||||||
from vllm.model_executor.layers.rotary_embedding import (
|
from vllm.model_executor.layers.rotary_embedding import (
|
||||||
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
|
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
|
||||||
YaRNScalingRotaryEmbedding)
|
YaRNScalingRotaryEmbedding)
|
||||||
from vllm.platforms import CpuArchEnum
|
|
||||||
|
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
|
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
|
||||||
@@ -411,10 +410,7 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
):
|
):
|
||||||
# TODO: This judgment will be removed once the mrope precision issue is fixed
|
if self.mrope_section != [16, 24, 24]:
|
||||||
if self.mrope_section != [
|
|
||||||
16, 24, 24
|
|
||||||
] or NPUPlatform.get_cpu_architecture() == CpuArchEnum.X86:
|
|
||||||
return super().forward_oot(positions, query, key)
|
return super().forward_oot(positions, query, key)
|
||||||
|
|
||||||
import torch_npu
|
import torch_npu
|
||||||
@@ -429,7 +425,7 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
|
|||||||
self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore
|
self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore
|
||||||
query.dtype) # type: ignore
|
query.dtype) # type: ignore
|
||||||
|
|
||||||
query, key = torch_npu.npu_mrope(positions,
|
query, key = torch_npu.npu_mrope(positions.contiguous(),
|
||||||
query.contiguous(),
|
query.contiguous(),
|
||||||
key.contiguous(),
|
key.contiguous(),
|
||||||
self.cos_sin_cache.contiguous(),
|
self.cos_sin_cache.contiguous(),
|
||||||
|
|||||||
Reference in New Issue
Block a user