[UT] Align input arguments with Ascend(Yarn)RotaryEmbedding with vLLM and add ut (#7358)
### What this PR does / why we need it?
This PR adds missing arguments in `AscendRotaryEmbedding` and
`AscendYarnRotaryEmbedding` to conform with vLLM. Besides, corresponding
unit tests are introduced.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -222,8 +222,9 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
init_cache: bool = True,
|
||||
) -> None:
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, init_cache)
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.use_mtp = vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp"
|
||||
_record_cos_sin_cache(self.cos_sin_cache)
|
||||
@@ -264,6 +265,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
apply_yarn_scaling: bool = True,
|
||||
truncate: bool = False,
|
||||
) -> None:
|
||||
extra_kwargs = {
|
||||
@@ -271,6 +273,7 @@ class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
|
||||
"attn_factor": attn_factor,
|
||||
"beta_fast": beta_fast,
|
||||
"beta_slow": beta_slow,
|
||||
"apply_yarn_scaling": apply_yarn_scaling,
|
||||
# TODO: actual truncation is currently not supported; this extra parameter exists only for compatibility with vLLM
|
||||
"truncate": truncate,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user