#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import math
from typing import Optional, Tuple

import einops
import torch
import torch_npu
from vllm.config import CUDAGraphMode
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
    YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
                               get_ascend_device_type, has_rope, is_vl_model)

# Currently, the rope ops used on NPU require detached cos and sin tensors as
# inputs, while RotaryEmbedding in vLLM stores cos_sin_cache as a single
# variable. We therefore have to preprocess cos_sin_cache into separate cos
# and sin tensors. In the future, we should implement a new rope op that
# accepts cos_sin_cache directly.
# NOTE(Angazenn): MLA and SFA models use attn_metadata to pass cos and sin to
# rope in AscendMLA(SFA)Impl. However, since rope is isolated from
# AscendAttentionBackendImpl for GQA models, we cannot pass cos and sin via
# attn_metadata. As a result, rope in GQA models has to receive cos and sin
# through a different approach.
_cos_mla: Optional[torch.Tensor] = None
_sin_mla: Optional[torch.Tensor] = None
_cos_sin_cache: Optional[torch.Tensor] = None
_cos: Optional[torch.Tensor] = None
_sin: Optional[torch.Tensor] = None
_cos_slice: Optional[torch.Tensor] = None
_sin_slice: Optional[torch.Tensor] = None


def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
                    device):
    global _cos_mla
    global _sin_mla
    global _cos
    global _sin

    if (_cos_mla is not None or _sin_mla is not None or _cos is not None
            or _sin is not None):
        return

    compilation_config = vllm_config.compilation_config
    model_config = vllm_config.model_config
    max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

    if model_config.use_mla:
        if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
            rope_dim = model_config.hf_text_config.qk_rope_head_dim
            _cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
                                  1,
                                  1,
                                  rope_dim,
                                  dtype=dtype,
                                  device=device)
            _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
                                   1,
                                   1,
                                   rope_dim,
                                   dtype=dtype,
                                   device=device)
    elif not is_vl_model(vllm_config) and has_rope(vllm_config):
        rope_dim = model_config.get_head_size()
        # For models using partial rope like Qwen3-Next.
        if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
            rope_dim = int(rope_dim *
                           model_config.hf_text_config.partial_rotary_factor)
        _cos = torch.ones(1,
                          max_num_batched_tokens,
                          1,
                          rope_dim,
                          dtype=dtype,
                          device=device)
        _sin = torch.zeros(1,
                           max_num_batched_tokens,
                           1,
                           rope_dim,
                           dtype=dtype,
                           device=device)


def get_cos_and_sin_mla():
    return _cos_mla, _sin_mla


def _record_cos_sin_cache(cos_sin_cache):
    global _cos_sin_cache
    if _cos_sin_cache is not None:
        return
    _cos_sin_cache = cos_sin_cache


def update_cos_sin(positions):
    global _cos
    global _sin
    global _cos_slice
    global _sin_slice

    if _cos_sin_cache is None or _cos is None or _sin is None:
        return

    num_tokens = positions.size(0)
    _cos[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
        num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0]
    _sin[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
        num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1]
    _cos_slice = _cos[:, :num_tokens]
    _sin_slice = _sin[:, :num_tokens]


def get_cos_and_sin_slice():
    return _cos_slice, _sin_slice
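

# A minimal, self-contained sketch of the split performed by update_cos_sin
# above, assuming a cache laid out per position as [cos_half | sin_half].
# The helper below is illustrative only; it is not part of this module's API.
def _sketch_split_cos_sin(cache: torch.Tensor, positions: torch.Tensor):
    num_tokens = positions.size(0)
    # Gather the per-token rows: [num_tokens, rotary_dim].
    gathered = cache.index_select(0, positions)
    # View as [num_tokens, 2, rotary_dim // 2]: index 0 holds cos, 1 holds
    # sin. Repeating along the last dim expands each half back to rotary_dim,
    # the layout the NPU rope kernels expect.
    halves = gathered.view(num_tokens, 2, -1).repeat(1, 1, 2)
    # cos, sin: [num_tokens, 1, rotary_dim] each.
    cos, sin = halves.chunk(2, dim=-2)
    return cos, sin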


def _custom_rotary_embedding_enabled(query, neox_style, head_size):
    return (query.dtype == torch.float16 and neox_style
            and head_size % 32 == 0 and enable_custom_op())


def _rope_forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    is_neox_style: bool,
    offsets: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    query_shape, key_shape = query.shape, key.shape
    if self.cos_sin_cache.device != query.device:
        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
    if self.cos_sin_cache.dtype != query.dtype:
        self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
    # Adopt the custom kernel path for rotary_embedding.
    if _custom_rotary_embedding_enabled(
            query, is_neox_style, self.head_size) and get_ascend_device_type(
            ) != AscendDeviceType._310P:
        query, key = torch.ops._C_ascend.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            is_neox_style,
        )
        return query.view(query_shape), key.view(key_shape)
    if offsets is not None:
        raise NotImplementedError(
            "Batched rotary embedding is currently not supported on NPU.")
    else:
        cos, sin = get_cos_and_sin_slice()
        if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
                -1] == 128 and cos is not None and sin is not None:
            # If cos and sin are generated outside, use
            # npu_apply_rotary_pos_emb to avoid redundant computation. This
            # path requires head_size and rotary_dim to equal 128 and
            # neox_style to be True.
            query = query.contiguous().view(1, query.shape[0], -1,
                                            self.head_size)
            key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
            # Although this function modifies in-place, please retain the
            # function's return value; otherwise, the graph fusion operation
            # may fail.
            query, key = torch_npu.npu_apply_rotary_pos_emb(
                query, key, cos, sin)
        elif self.rotary_dim < self.head_size:
            num_tokens = query.shape[0]
            query = query.view(num_tokens, -1, self.head_size)
            key = key.view(num_tokens, -1, self.head_size)
            q_rot = query[..., :self.rotary_dim]
            q_pass = query[..., self.rotary_dim:]
            k_rot = key[..., :self.rotary_dim]
            k_pass = key[..., self.rotary_dim:]
            q_rot = q_rot.contiguous().view(num_tokens, -1)
            k_rot = k_rot.contiguous().view(num_tokens, -1)
            torch_npu._npu_rotary_embedding(
                positions,
                q_rot,
                k_rot,
                self.head_size,
                self.cos_sin_cache,
                is_neox_style,
            )
            q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
            k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
            q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
            k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
            return q, k
        else:
            # TODO: Remove the contiguous in the future.
            query = query.contiguous().view(query.shape[0], -1)
            key = key.contiguous().view(key.shape[0], -1)
            torch_npu._npu_rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                is_neox_style,
            )
    return query.view(query_shape), key.view(key_shape)
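

# A minimal sketch of the partial-rope branch above, assuming a neox-style
# rotation with rotary_dim < head_size: only the leading rotary_dim channels
# are rotated and the remainder passes through unchanged. The helper below is
# illustrative only; it is not part of this module's API, and cos/sin are
# assumed to broadcast against the rotated slice.
def _sketch_partial_rope(q: torch.Tensor, cos: torch.Tensor,
                         sin: torch.Tensor, rotary_dim: int) -> torch.Tensor:
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    half = rotary_dim // 2
    # Neox-style rotate-half: pair (-x2, x1) with (cos, sin).
    rotated = torch.cat((-q_rot[..., half:], q_rot[..., :half]), dim=-1)
    return torch.cat((q_rot * cos + rotated * sin, q_pass), dim=-1)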


class AscendRotaryEmbedding(RotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)
        _record_cos_sin_cache(self.cos_sin_cache)

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
        is_neox_style_override: Optional[bool] = None,
    ):
        is_neox_style = self.is_neox_style
        if is_neox_style_override is not None:
            is_neox_style = is_neox_style_override
        return _rope_forward_oot(self, positions, query, key, is_neox_style,
                                 offsets)


class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        extra_kwargs = {
            "extrapolation_factor": extrapolation_factor,
            "attn_factor": attn_factor,
            "beta_fast": beta_fast,
            "beta_slow": beta_slow
        }
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, scaling_factor, dtype, **extra_kwargs)
        _record_cos_sin_cache(self.cos_sin_cache)

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
        is_neox_style_override: Optional[bool] = None,
    ):
        return AscendRotaryEmbedding.forward_oot(self, positions, query, key,
                                                 offsets,
                                                 is_neox_style_override)


class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        # Note: we adopt the native huggingface deepseek rope initialization
        # code from
        # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
        # because it is more compute-friendly on Ascend hardware.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            self._yarn_get_mscale(self.scaling_factor, float(mscale)) /
            self._yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        super(DeepseekScalingRotaryEmbedding,
              self).__init__(head_size, rotary_dim, max_position_embeddings,
                             base, is_neox_style, dtype)

        # NOTE: For Ascend-friendly computing, reorder the sin and cos cache.
        self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
        self._set_cos_sin_cache(self.max_seq_len,
                                device=NPUPlatform.device_type,
                                dtype=dtype)

    def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float:
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0
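
    # Worked example for the formula above (illustrative values): with
    # scale=40 and mscale=1.0 the correction is
    # 0.1 * 1.0 * math.log(40) + 1.0 ≈ 1.369, while any scale <= 1 leaves
    # attention magnitudes untouched (factor 1.0).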

    def _rotate_half(self, x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def _yarn_linear_ramp_mask(self, min_value, max_value, dim):
        # Note: an explicit if-branch is avoided here to work around an MTP
        # compilation error.
        max_value += (min_value == max_value).float() * 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) -
                       min_value) / (max_value - min_value)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func
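
    # For example, with min_value=4, max_value=28 and dim=32, the mask is 0
    # for dims up to 4, rises linearly to 1 at dim 28, and stays 1 afterwards;
    # the 0.001 nudge above only guards the min_value == max_value case
    # against division by zero.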

    # Inverse dim formula to find dim based on the number of rotations.
    def _yarn_find_correction_dim(self,
                                  num_rotations,
                                  dim,
                                  base=10000,
                                  max_position_embeddings=2048):
        # Note: use torch instead of math to avoid an MTP compilation error.
        return (dim * torch.log(
            torch.tensor(max_position_embeddings) /
            (num_rotations * 2 * torch.pi))) / (2 *
                                                torch.log(torch.tensor(base)))

    # Find dim range bounds based on rotations.
    def _yarn_find_correction_range(self,
                                    low_rot,
                                    high_rot,
                                    dim,
                                    base=10000,
                                    max_position_embeddings=2048):
        # Note: use torch instead of math to avoid an MTP compilation error.
        low = torch.floor(
            self._yarn_find_correction_dim(low_rot, dim, base,
                                           max_position_embeddings))
        high = torch.ceil(
            self._yarn_find_correction_dim(high_rot, dim, base,
                                           max_position_embeddings))
        # Note: use torch.clamp instead of max/min to avoid an MTP
        # compilation error.
        return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)

    # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
    def _apply_rotary_pos_emb(self,
                              q,
                              k,
                              cos,
                              sin,
                              position_ids,
                              unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`):
                The position indices of the tokens corresponding to the query
                and key tensors. For example, this can be used to pass
                offsetted position ids when working with a KV-cache.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along
                which to unsqueeze cos[position_ids] and sin[position_ids] so
                that they can be properly broadcasted to the dimensions of q
                and k. For example, note that cos[position_ids] and
                sin[position_ids] have the shape [batch_size, seq_len,
                head_dim]. Then, if q and k have the shape [batch_size, heads,
                seq_len, head_dim], setting unsqueeze_dim=1 makes
                cos[position_ids] and sin[position_ids] broadcastable to the
                shapes of q and k. Similarly, if q and k have the shape
                [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors
            rotated using the Rotary Position Embedding.
        """
        cos = cos[position_ids]
        sin = sin[position_ids]
        cos = cos[:, None, None, :]
        sin = sin[:, None, None, :]

        if len(q.shape) == 3:
            q = q[:, :, None, :]
        if len(k.shape) == 2:
            k = k[:, None, None, :]
        elif len(k.shape) == 3:
            k = k[:, :, None, :]

        b, h_q, s, d = q.shape
        q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)

        b, h_k, s, d = k.shape
        k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)

        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)

        q_embed = q_embed.view(b, h_q, d)
        k_embed = k_embed.view(b, h_k, d)

        return q_embed, k_embed

    def _set_cos_sin_cache(self, max_seq_len, device, dtype):
        dim = self.rotary_dim

        freq_extra = 1.0 / (self.base**(
            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        freq_inter = 1.0 / (self.scaling_factor * self.base**(
            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))

        low, high = self._yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.max_position_embeddings,
        )
        inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(
            low, high, dim // 2).to(device=device, dtype=torch.float32)
        inv_freq = freq_inter * (1 -
                                 inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(max_seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(t, inv_freq)
        cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
        sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
        cos_cached = cos_cached.to(dtype)
        sin_cached = sin_cached.to(dtype)
        cache = torch.cat(
            [freqs.cos() * self.mscale,
             freqs.sin() * self.mscale], dim=-1).to(dtype)
        self.register_buffer("cos_sin_cache", cache, persistent=False)
        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)

    def forward(self,
                positions: torch.Tensor,
                query: torch.Tensor,
                key: torch.Tensor,
                offsets: Optional[torch.Tensor] = None):
        if len(key.shape) == 2:
            key = key[:, None, :]
        # Note: we implement the non-neox-style method by shuffling the last
        # dim into the neox layout and then applying the neox-style
        # calculation, which is more compute-friendly on Ascend hardware. See:
        # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
        is_neox_style = True
        if self.is_neox_style is False:
            b, h_q, d = query.shape
            query = query.view(b, h_q, d // 2,
                               2).transpose(3, 2).reshape(b, h_q, d)
            b, h_k, d = key.shape
            key = key.view(b, h_k, d // 2, 2).transpose(3,
                                                        2).reshape(b, h_k, d)
        q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
                                       is_neox_style, offsets)
        return q_pe, k_pe
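

# A minimal sketch of the layout shuffle used in forward above: it converts a
# GPT-J-style interleaved last dim [x0, y0, x1, y1, ...] into the neox-style
# blocked layout [x0, x1, ..., y0, y1, ...], after which a neox-style rope
# kernel produces the same result. Illustrative only; not part of this
# module's API.
def _sketch_deinterleave_last_dim(x: torch.Tensor) -> torch.Tensor:
    *lead, d = x.shape
    return x.view(*lead, d // 2, 2).transpose(-1, -2).reshape(*lead, d)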


class AscendMRotaryEmbedding(MRotaryEmbedding):

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ):
        if self.mrope_section != [16, 24, 24] or \
                get_ascend_device_type() == AscendDeviceType.A5:
            return super().forward_oot(positions, query, key)

        mrope_section = ([0, 0, 0]
                         if positions.ndim == 1 else self.mrope_section)
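
        # Note: [16, 24, 24] is the mrope_section used by the Qwen2/2.5-VL
        # family; the three sections sum to 64, half of the 128-dim rotary
        # space, which is the only layout this npu_mrope fast path covers.
        # For 1-D text-only positions a zeroed section is passed, covering
        # the plain-rope case.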
        if self.cos_sin_cache.device != query.device:  # type: ignore
            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
                query.device)  # type: ignore
        if self.cos_sin_cache.dtype != query.dtype:  # type: ignore
            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
                query.dtype)  # type: ignore

        query, key = torch_npu.npu_mrope(positions.contiguous(),
                                         query.contiguous(),
                                         key.contiguous(),
                                         self.cos_sin_cache.contiguous(),
                                         self.head_size,
                                         mrope_section=mrope_section,
                                         rotary_mode='half')

        return query, key


class AscendApplyRotaryEmb(ApplyRotaryEmb):

    def __init__(
        self,
        enforce_enable: bool = False,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> None:
        super().__init__(
            enforce_enable=enforce_enable,
            is_neox_style=is_neox_style,
            enable_fp32_compute=enable_fp32_compute,
        )

    def forward_oot(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        head_dim = x.shape[-1]

        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x = x.float()
            cos = cos.float()
            sin = sin.float()

        # cos, sin: [seq_len, head_dim // 2]
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        # cos, sin: [1, seq_len, 1, head_dim]
        cos = cos.reshape(1, -1, 1, head_dim)
        sin = sin.reshape(1, -1, 1, head_dim)

        if len(x.shape) == 3:
            # x: [seq_len, num_heads, head_size]
            x = x.unsqueeze(0)
            # x: [1, seq_len, num_heads, head_size]
            output = torch_npu.npu_rotary_mul(x, cos, sin).squeeze(0)
        else:
            assert len(x.shape) == 4
            # x: [2 * b, s, head, head_dim]
            qk = einops.rearrange(
                x, "(two b) s head head_dim -> b s two head head_dim", two=2)
            # q, k: [b, s, head, head_dim]
            q, k = qk[:, :, 0], qk[:, :, 1]
            q = torch_npu.npu_rotary_mul(q, cos, sin)
            k = torch_npu.npu_rotary_mul(k, cos, sin)
            output = torch.cat([q, k], dim=0)

        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        return output
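

# A minimal usage sketch for the 4-D branch of AscendApplyRotaryEmb above,
# assuming query and key have been stacked along the batch dim (toy names and
# shapes; illustrative only):
#
#     emb = AscendApplyRotaryEmb()
#     qk = torch.cat([q, k], dim=0)        # q, k: [b, s, head, head_dim]
#     cos = sin = torch.randn(s, head_dim // 2)
#     out = emb.forward_oot(qk, cos, sin)  # rotated q, k, still stacked
#     q_out, k_out = out.chunk(2, dim=0)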