#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
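"""Ascend (NPU) rotary position embedding implementations for vllm-ascend.

Contains an out-of-tree RoPE forward helper shared by the classes below, plus
Ascend subclasses of vllm's RotaryEmbedding, YaRNScalingRotaryEmbedding,
DeepseekScalingRotaryEmbedding and MRotaryEmbedding that dispatch to custom
Ascend kernels and torch_npu operators.
"""
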
import math
from typing import Optional, Tuple

import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
    YaRNScalingRotaryEmbedding)

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import enable_custom_op, is_310p


def _custom_rotary_embedding_enabled(query, neox_style, head_size):
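    # The custom Ascend RoPE kernel only supports fp16 queries in neox-style
    # layout with a head size that is a multiple of 32, and needs the custom
    # ops extension to be compiled in.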
    return (query.dtype == torch.float16 and neox_style
            and head_size % 32 == 0 and enable_custom_op())


def _rope_forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    is_neox_style: bool,
    offsets: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
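    """Shared out-of-tree RoPE forward used by the Ascend embedding classes.

    Prefers the custom _C_ascend kernel when eligible (and not on 310P);
    otherwise falls back to torch_npu operators, with a dedicated path for
    partial rotary dimensions (rotary_dim < head_size).
    """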
    query_shape, key_shape = query.shape, key.shape
    if self.cos_sin_cache.device != query.device:
        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
    if self.cos_sin_cache.dtype != query.dtype:
        self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
    # Adopt the custom kernel path for rotary_embedding.
    if _custom_rotary_embedding_enabled(query, is_neox_style,
                                        self.head_size) and not is_310p():
        query, key = torch.ops._C_ascend.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            is_neox_style,
        )
        return query.view(query_shape), key.view(key_shape)
    if offsets is not None:
        raise NotImplementedError(
            "Batched rotary embedding is currently not supported on NPU.")
    else:
        if self.cos is not None and \
                self.sin is not None:
            # If cos and sin were generated outside the layer, use
            # npu_apply_rotary_pos_emb to avoid redundant calculation. This
            # path requires head_size and rotary_dim to equal 128 and
            # is_neox_style to be True.
            query = query.contiguous().view(1, query.shape[0], -1,
                                            self.head_size)
            key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
            torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin)
        elif self.rotary_dim < self.head_size:
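            # Partial RoPE: rotate only the first rotary_dim channels of each
            # head and pass the remaining channels through unchanged.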
            num_tokens = query.shape[0]
            query = query.view(num_tokens, -1, self.head_size)
            key = key.view(num_tokens, -1, self.head_size)
            q_rot = query[..., :self.rotary_dim]
            q_pass = query[..., self.rotary_dim:]
            k_rot = key[..., :self.rotary_dim]
            k_pass = key[..., self.rotary_dim:]
            q_rot = q_rot.contiguous().view(num_tokens, -1)
            k_rot = k_rot.contiguous().view(num_tokens, -1)
            torch_npu._npu_rotary_embedding(
                positions,
                q_rot,
                k_rot,
                self.head_size,
                self.cos_sin_cache,
                is_neox_style,
            )
            q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
            k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
            q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
            k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
            return q, k
        else:
            # TODO: Remove the contiguous in the future.
            query = query.contiguous().view(query.shape[0], -1)
            key = key.contiguous().view(key.shape[0], -1)
            torch_npu._npu_rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                is_neox_style,
            )
    return query.view(query_shape), key.view(key_shape)


class AscendRotaryEmbedding(RotaryEmbedding):
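    """Ascend override of vllm's base RotaryEmbedding.

    On the first layer of each forward pass it materializes cos/sin tables
    from the cos_sin_cache (when head_size and rotary_dim are 128 with neox
    style) so later layers can reuse them via npu_apply_rotary_pos_emb.
    """
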
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
|
2025-09-09 14:28:14 +08:00
|
|
|
self.cos = None
|
|
|
|
|
self.sin = None
|
2025-08-25 09:32:35 +08:00
|
|
|
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
2025-04-29 17:12:03 +08:00
|
|
|
is_neox_style, dtype)

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
        is_neox_style_override: Optional[bool] = None,
    ):
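        # is_neox_style_override lets callers force a specific RoPE layout;
        # otherwise the layer's configured style is used.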
|
2025-09-09 14:28:14 +08:00
|
|
|
is_neox_style = self.is_neox_style
|
|
|
|
|
if is_neox_style_override is not None:
|
|
|
|
|
is_neox_style = is_neox_style_override
|
|
|
|
|
forward_context = get_forward_context()
|
|
|
|
|
is_first_layer = forward_context.is_first_layer
|
|
|
|
|
# Generate cos and sin outside layers to avoid repeated calculation.
|
2025-09-12 09:49:36 +08:00
|
|
|
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
|
|
|
|
|
-1] == 128:
|
2025-09-09 14:28:14 +08:00
|
|
|
if is_first_layer:
|
|
|
|
|
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
|
|
|
|
last_dim = cos_sin.size()[-1]
|
|
|
|
|
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
|
|
|
|
|
1, 1, 2).chunk(2, dim=-2)
|
|
|
|
|
# BSNH
|
|
|
|
|
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
|
|
|
|
|
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
|
|
|
|
|
forward_context.is_first_layer = False
|
|
|
|
|
return _rope_forward_oot(self, positions, query, key, is_neox_style,
|
|
|
|
|
offsets)
class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):
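    """YaRN-scaled rotary embedding routed through the Ascend forward path.

    Delegates forward_oot to AscendRotaryEmbedding so YaRN models share the
    same cos/sin caching and torch_npu kernels.
    """
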
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.cos = None
        self.sin = None
        extra_kwargs = {
            "extrapolation_factor": extrapolation_factor,
            "attn_factor": attn_factor,
            "beta_fast": beta_fast,
            "beta_slow": beta_slow
        }
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, scaling_factor, dtype, **extra_kwargs)

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
        is_neox_style_override: Optional[bool] = None,
    ):
        return AscendRotaryEmbedding.forward_oot(self, positions, query, key,
                                                 offsets,
                                                 is_neox_style_override)


class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
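    """DeepSeek-style (YaRN) rotary embedding reimplemented for Ascend.

    Adopts the native HuggingFace DeepSeek RoPE initialization and reorders
    the sin/cos cache into a [cos | sin] layout that is friendlier to Ascend
    compute.
    """
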
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        # Note: we adopt the native HuggingFace DeepSeek RoPE initialization
        # from
        # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
        # because it is more compute-friendly on Ascend.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            self._yarn_get_mscale(self.scaling_factor, float(mscale)) /
            self._yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        super(DeepseekScalingRotaryEmbedding,
              self).__init__(head_size, rotary_dim, max_position_embeddings,
                             base, is_neox_style, dtype)

        # NOTE: for Ascend-friendly computing, reorder the sin and cos cache.
        self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
        self._set_cos_sin_cache(self.max_seq_len,
                                device=NPUPlatform.device_type,
                                dtype=dtype)

    def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float:
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    def _rotate_half(self, x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def _yarn_linear_ramp_mask(self, min_value, max_value, dim):
        # Note: the usual `if min_value == max_value` special case is folded
        # into the epsilon below to avoid an MTP compilation error.
        max_value += (min_value == max_value).float() * 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) -
                       min_value) / (max_value - min_value)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    # Inverse dim formula to find dim based on number of rotations
    def _yarn_find_correction_dim(self,
                                  num_rotations,
                                  dim,
                                  base=10000,
                                  max_position_embeddings=2048):
        # Note: use torch instead of math to avoid an MTP compilation error.
        return (dim * torch.log(
            torch.tensor(max_position_embeddings) /
            (num_rotations * 2 * torch.pi))) / (2 *
                                                torch.log(torch.tensor(base)))

    # Find dim range bounds based on rotations
    def _yarn_find_correction_range(self,
                                    low_rot,
                                    high_rot,
                                    dim,
                                    base=10000,
                                    max_position_embeddings=2048):
        # Note: use torch instead of math to avoid an MTP compilation error.
        low = torch.floor(
            self._yarn_find_correction_dim(low_rot, dim, base,
                                           max_position_embeddings))
        high = torch.ceil(
            self._yarn_find_correction_dim(high_rot, dim, base,
                                           max_position_embeddings))
        # Note: use torch instead of max/min to avoid an MTP compilation error.
        return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)

    # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
    def _apply_rotary_pos_emb(self,
                              q,
                              k,
                              cos,
                              sin,
                              position_ids,
                              unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`):
                The position indices of the tokens corresponding to the query
                and key tensors. For example, this can be used to pass offset
                position ids when working with a KV-cache.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they broadcast properly to the
                dimensions of q and k. For example, cos[position_ids] and
                sin[position_ids] have shape [batch_size, seq_len, head_dim].
                If q and k have shape [batch_size, heads, seq_len, head_dim],
                set unsqueeze_dim=1; if they have shape
                [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising the query and key tensors rotated
            using the Rotary Position Embedding.
        """
        cos = cos[position_ids]
        sin = sin[position_ids]
        cos = cos[:, None, None, :]
        sin = sin[:, None, None, :]

        if len(q.shape) == 3:
            q = q[:, :, None, :]
        if len(k.shape) == 2:
            k = k[:, None, None, :]
        elif len(k.shape) == 3:
            k = k[:, :, None, :]

        b, h_q, s, d = q.shape
        q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)

        b, h_k, s, d = k.shape
        k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)

        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)

        q_embed = q_embed.view(b, h_q, d)
        k_embed = k_embed.view(b, h_k, d)

        return q_embed, k_embed

    def _set_cos_sin_cache(self, max_seq_len, device, dtype):
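        # Build the YaRN-blended inverse frequencies, then cache cos/sin both
        # separately (cos_cached/sin_cached) and in the concatenated
        # [cos | sin] layout (cos_sin_cache) expected by the Ascend kernels.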
|
2025-08-25 09:32:35 +08:00
|
|
|
dim = self.rotary_dim
|
|
|
|
|
|
|
|
|
|
freq_extra = 1.0 / (self.base**(
|
|
|
|
|
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
|
|
|
|
freq_inter = 1.0 / (self.scaling_factor * self.base**(
|
|
|
|
|
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
|
|
|
|
|
|
|
|
|
low, high = self._yarn_find_correction_range(
|
|
|
|
|
self.beta_fast,
|
|
|
|
|
self.beta_slow,
|
|
|
|
|
dim,
|
|
|
|
|
self.base,
|
|
|
|
|
self.max_position_embeddings,
|
|
|
|
|
)
|
|
|
|
|
inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(
|
|
|
|
|
low, high, dim // 2).to(device=device, dtype=torch.float32)
|
|
|
|
|
inv_freq = freq_inter * (1 -
|
|
|
|
|
inv_freq_mask) + freq_extra * inv_freq_mask
|
|
|
|
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
|
|
|
|
2025-09-08 22:03:34 +08:00
|
|
|
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
|
2025-08-25 09:32:35 +08:00
|
|
|
|
|
|
|
|
freqs = torch.outer(t, inv_freq)
|
|
|
|
|
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
|
|
|
|
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
|
|
|
|
|
cos_cached = cos_cached.to(dtype)
|
|
|
|
|
sin_cached = sin_cached.to(dtype)
|
|
|
|
|
cache = torch.cat(
|
|
|
|
|
[freqs.cos() * self.mscale,
|
|
|
|
|
freqs.sin() * self.mscale], dim=-1).to(dtype)
|
|
|
|
|
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
|
|
|
|
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
|
|
|
|
self.register_buffer("sin_cached", sin_cached, persistent=False)

    def forward(self,
                positions: torch.Tensor,
                query: torch.Tensor,
                key: torch.Tensor,
                offsets: Optional[torch.Tensor] = None):
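        # Query/key here are the positional slices used by DeepSeek's MLA
        # attention, hence the q_pe/k_pe names below.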
        if len(key.shape) == 2:
            key = key[:, None, :]
        # Note: we implement the non-neox-style method by shuffling the last
        # dim and then applying the neox-style calculation, which is more
        # compute-friendly on Ascend machines. See
        # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
        is_neox_style = True
        if self.is_neox_style is False:
            b, h_q, d = query.shape
            query = query.view(b, h_q, d // 2,
                               2).transpose(3, 2).reshape(b, h_q, d)
            b, h_k, d = key.shape
            key = key.view(b, h_k, d // 2,
                           2).transpose(3, 2).reshape(b, h_k, d)
        q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
                                       is_neox_style, offsets)
        return q_pe, k_pe


|
2025-10-25 11:41:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class AscendMRotaryEmbedding(MRotaryEmbedding):
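    """Multimodal rotary embedding (M-RoPE) backed by torch_npu.npu_mrope."""
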
    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ):
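        # The npu_mrope fast path below targets the common [16, 24, 24]
        # mrope_section (e.g. Qwen2-VL); other sections fall back to the
        # generic vllm implementation.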
        if self.mrope_section != [16, 24, 24]:
            return super().forward_oot(positions, query, key)

        mrope_section = ([0, 0, 0]
                         if positions.ndim == 1 else self.mrope_section)

        if self.cos_sin_cache.device != query.device:  # type: ignore
            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
                query.device)  # type: ignore

        if self.cos_sin_cache.dtype != query.dtype:  # type: ignore
            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
                query.dtype)  # type: ignore

        query, key = torch_npu.npu_mrope(positions.contiguous(),
                                         query.contiguous(),
                                         key.contiguous(),
                                         self.cos_sin_cache.contiguous(),
                                         self.head_size,
                                         mrope_section=mrope_section,
                                         rotary_mode='half')

        return query, key