[UT] Align input arguments with Ascend(Yarn)RotaryEmbedding with vLLM and add ut (#7358)

### What this PR does / why we need it?
This PR adds missing arguments in `AscendRotaryEmbedding` and
`AscendYarnRotaryEmbedding` to conform with vLLM. In addition, corresponding
unit tests are introduced.

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2026-03-24 16:02:56 +08:00
committed by GitHub
parent 568b6d0601
commit bdb65319a9
2 changed files with 344 additions and 1 deletions

View File

@@ -222,8 +222,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
base: float,
is_neox_style: bool,
dtype: torch.dtype,
init_cache: bool = True,
) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)
super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, init_cache)
vllm_config = get_current_vllm_config()
self.use_mtp = vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp"
_record_cos_sin_cache(self.cos_sin_cache)
@@ -264,6 +265,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
apply_yarn_scaling: bool = True,
truncate: bool = False,
) -> None:
extra_kwargs = {
@@ -271,6 +273,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
"attn_factor": attn_factor,
"beta_fast": beta_fast,
"beta_slow": beta_slow,
"apply_yarn_scaling": apply_yarn_scaling,
# TODO: actual truncation is currently not supported; this is an adaptation for extra parameters to stay compatible with vLLM
"truncate": truncate,
}