From 27e0f2c0355a45a8dbe897564edbe4e414a83dbb Mon Sep 17 00:00:00 2001
From: Angazenn <92204292+Angazenn@users.noreply.github.com>
Date: Sat, 11 Oct 2025 08:36:20 +0800
Subject: [PATCH] [Perf]Add YaRN custom op (#3355)

### What this PR does / why we need it?
YaRN scaling is used to improve long-sequence accuracy for models like
Qwen3. In vLLM, YaRN scaling refers to the `YaRNScalingRotaryEmbedding`
class, which inherits from the original `RotaryEmbedding`. Although
`YaRNScalingRotaryEmbedding` does not override the `forward` function of
`RotaryEmbedding`, using YaRN on NPU still runs into the native
implementation of `forward` in `RotaryEmbedding` rather than the
`forward_oot` implementation in vLLM-Ascend. Thus I register another
custom op here to enable the OOT implementation for YaRN in vLLM-Ascend,
similar to #3151.
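To illustrate why the extra registration is needed, here is a minimal
sketch of a name-keyed out-of-tree (OOT) registry. The names
`oot_registry`, `register_oot`, and `resolve` are hypothetical stand-ins
for vLLM's custom-op machinery, not its actual API:

```python
# Hypothetical sketch: OOT overrides are looked up by concrete class name.
oot_registry: dict[str, type] = {}


def register_oot(name: str, op_cls: type) -> None:
    """Record a platform replacement class under an op name."""
    oot_registry[name] = op_cls


def resolve(op_name: str, default_cls: type) -> type:
    """Return the override registered for this exact name, else the default."""
    return oot_registry.get(op_name, default_cls)


class RotaryEmbedding:
    def forward(self):
        return "native rotary forward"


class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    pass  # inherits forward(); defines no NPU-specific path of its own


class AscendRotaryEmbedding(RotaryEmbedding):
    def forward_oot(self):
        return "NPU rotary forward"


class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
    def forward_oot(self):
        return "NPU rotary forward (YaRN)"


register_oot("RotaryEmbedding", AscendRotaryEmbedding)

# The lookup key is the concrete class name, so the "RotaryEmbedding"
# entry never applies to the YaRN subclass and the native forward runs:
assert resolve("YaRNScalingRotaryEmbedding",
               YaRNScalingRotaryEmbedding) is YaRNScalingRotaryEmbedding

# Registering the YaRN op under its own name routes it to the NPU path:
register_oot("YaRNScalingRotaryEmbedding", AscendYaRNRotaryEmbedding)
assert resolve("YaRNScalingRotaryEmbedding",
               YaRNScalingRotaryEmbedding) is AscendYaRNRotaryEmbedding
```

Because the lookup is keyed by each op's own class name, an override
registered for `RotaryEmbedding` never reaches the YaRN subclass, which
is why this patch registers `AscendYaRNRotaryEmbedding` separately.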
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Angazenn
---
 vllm_ascend/ops/rotary_embedding.py | 44 ++++++++++++++++++++++++++++-
 vllm_ascend/utils.py                |  4 ++-
 2 files changed, 46 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 9ddf280..69102f3 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -22,7 +22,8 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding,
+    YaRNScalingRotaryEmbedding)
 
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import enable_custom_op, is_310p
@@ -153,6 +154,47 @@ class AscendRotaryEmbedding(RotaryEmbedding):
                                 offsets)
 
 
+class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+    ) -> None:
+        self.cos = None
+        self.sin = None
+        extra_kwargs = {
+            "extrapolation_factor": extrapolation_factor,
+            "attn_factor": attn_factor,
+            "beta_fast": beta_fast,
+            "beta_slow": beta_slow
+        }
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, scaling_factor, dtype, **extra_kwargs)
+
+    def forward_oot(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+        is_neox_style_override: Optional[bool] = None,
+    ):
+        return AscendRotaryEmbedding.forward_oot(self, positions, query, key,
+                                                 offsets,
+                                                 is_neox_style_override)
+
+
 class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
 
     def __init__(
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 62faa93..b591413 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -508,7 +508,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
                                         AscendQKVParallelLinear,
                                         AscendRowParallelLinear)
     from vllm_ascend.ops.rotary_embedding import (
-        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
+        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
+        AscendYaRNRotaryEmbedding)
     from vllm_ascend.ops.vocab_parallel_embedding import (
         AscendLogitsProcessor, AscendParallelLMHead,
         AscendVocabParallelEmbedding)
@@ -520,6 +521,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "RotaryEmbedding": AscendRotaryEmbedding,
         "ColumnParallelLinear": AscendColumnParallelLinear,
         "RowParallelLinear": AscendRowParallelLinear,
+        "YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,
         "MergedColumnParallelLinear": AscendMergedColumnParallelLinear,
         "QKVParallelLinear": AscendQKVParallelLinear,
         "DeepseekScalingRotaryEmbedding": AscendDeepseekScalingRotaryEmbedding,