Drop torchair (#4814)

ACLGraph is now stable and fast, so it's time to drop the torchair graph mode.

TODO: some logic that was added to adapt to torchair should be cleaned up as
well. We'll do that in a follow-up PR.
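
For users, migrating is mostly a matter of deleting the `torchair_graph_config` entry from `additional_config`. A minimal before/after sketch, using the illustrative model path from the docs removed below:

```python
from vllm import LLM

# Before this change, torchair graph mode had to be enabled explicitly:
# llm = LLM(model="path/to/DeepSeek-R1-0528",
#           additional_config={"torchair_graph_config": {"enabled": True}})

# After this change, ACLGraph is the default graph mode, so no extra
# config is needed. Use enforce_eager=True only to opt out of graph mode.
llm = LLM(model="path/to/DeepSeek-R1-0528")
outputs = llm.generate("Hello, how are you?")
```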

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Authored by wangxiyuan on 2025-12-10 09:20:40 +08:00; committed by GitHub
Parent: ba9cda9dfd
Commit: 835b4c8f1d
84 changed files with 77 additions and 16881 deletions

View File

@@ -107,7 +107,6 @@ jobs:
# ------------------------------------ v1 spec decode test ------------------------------------ #
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
e2e-2-cards:
@@ -170,10 +169,6 @@ jobs:
if: ${{ inputs.type == 'light' }}
run: |
pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_qwen3_moe_with_torchair
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv2lite_with_torchair
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv2lite_with_torchair_v1scheduler
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv2lite_with_nz
- name: Run vllm-project/vllm-ascend test (full)
env:
@@ -183,7 +178,6 @@ jobs:
run: |
pytest -sv tests/e2e/multicard/test_quantization.py
pytest -sv tests/e2e/multicard/test_aclgraph_capture_replay.py
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
pytest -sv tests/e2e/multicard/test_data_parallel.py
pytest -sv tests/e2e/multicard/test_expert_parallel.py

View File

@@ -127,9 +127,6 @@ jobs:
- name: multi-node-deepseek-dp
config_file_path: DeepSeek-R1-W8A8-A2.yaml
size: 2
- name: multi-node-deepseek-dp-torchair
config_file_path: DeepSeek-R1-W8A8-A2-torchair.yaml
size: 2
uses: ./.github/workflows/_e2e_nightly_multi_node.yaml
with:
soc_version: a2

View File

@@ -134,8 +134,6 @@ jobs:
run: |
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/arm64-linux/devlib
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \
--ignore tests/ut/torchair/models/test_torchair_deepseek_mtp.py \
--ignore tests/ut/torchair/models/test_torchair_deepseek_v2.py \
--ignore tests/ut/model_loader/netloader/test_netloader_elastic.py \
--ignore tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
--ignore tests/ut/kv_connector/test_remote_decode_lifecycle.py \

View File

@@ -12,7 +12,7 @@ repos:
- id: codespell
args: [
--toml, pyproject.toml,
'--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
'--skip', 'csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
'-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND,ND'
]
additional_dependencies:

View File

@@ -251,7 +251,6 @@ This will reproduce the E2E test. See [vllm_ascend_test.yaml](https://github.com
- Offline test example: [`tests/e2e/singlecard/test_offline_inference.py`](https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/test_offline_inference.py)
- Online test examples: [`tests/e2e/singlecard/test_prompt_embedding.py`](https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/test_prompt_embedding.py)
- Correctness test example: [`tests/e2e/singlecard/test_aclgraph.py`](https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/test_aclgraph.py)
- Reduced Layer model test example: [test_torchair_graph_mode.py - DeepSeek-V3-Pruning](https://github.com/vllm-project/vllm-ascend/blob/20767a043cccb3764214930d4695e53941de87ec/tests/e2e/multicard/test_torchair_graph_mode.py#L48)
CI resources are limited, so you might need to reduce the number of layers of a model. Below is an example of how to generate a reduced-layer model (a hypothetical sketch of the config edit is shown after step 1):
1. Fork the original model repo on ModelScope. All the files in the repo except the weights are required.
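For illustration only, a hypothetical sketch of the layer-reduction edit, assuming a Hugging Face style `config.json` with a `num_hidden_layers` key (verify the key name for the target model; the pruned weights must be produced separately):

```python
import json

# Hypothetical helper: shrink a forked model's config to a few hidden
# layers so the reduced model fits CI resources. Only config.json is
# edited here.
def reduce_layers(config_path: str, num_layers: int = 2) -> None:
    with open(config_path) as f:
        config = json.load(f)
    config["num_hidden_layers"] = num_layers  # assumed HF-style key
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

reduce_layers("DeepSeek-V3-Pruning/config.json")
```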

View File

@@ -6,7 +6,7 @@ MTP boosts inference performance by parallelizing the prediction of multiple tok
## How to Use MTP
To enable MTP for DeepSeek-V3 models, add the following parameter when starting the service:
--speculative_config ' {"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} '
--speculative_config ' {"method": "mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} '
- `num_speculative_tokens`: The number of speculative tokens, which enables the model to predict multiple tokens at once, if provided. It defaults to the number in the draft model config if present; otherwise, it is required.
- `disable_padded_drafter_batch`: Disable input padding for speculative decoding. If set to True, speculative input batches can contain sequences of different lengths, which may only be supported by certain attention backends. This currently only affects the MTP method of speculation; the default is False.
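The same configuration can also be passed offline; a minimal sketch, assuming a local model path (placeholder):

```python
from vllm import LLM, SamplingParams

# Illustrative sketch: enable MTP speculative decoding with the renamed
# "mtp" method. The model path is a placeholder.
llm = LLM(
    model="path/to/DeepSeek-V3",
    speculative_config={
        "method": "mtp",
        "num_speculative_tokens": 1,
        "disable_padded_drafter_batch": False,
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=32))
```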
@@ -74,21 +74,18 @@ If the bonus token is accepted, the MTP model performs inference for (num_specul
### Method Validation
- Currently, the spec_decode scenario only supports methods such as ngram, eagle, eagle3, and deepseek_mtp. If an incorrect parameter is passed for the method, the code will raise an error to alert the user that an incorrect method was provided.
- Currently, the spec_decode scenario only supports the ngram, eagle, eagle3, and mtp methods. If an incorrect method is passed, the code raises an error to alert the user.
```
def get_spec_decode_method(method,
vllm_config,
device,
runner,
is_torchair_graph=False):
runner):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
elif method in ["eagle", "eagle3"]:
return EagleProposer(vllm_config, device, runner)
elif method == 'deepseek_mtp':
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
elif method == 'mtp':
return MtpProposer(vllm_config, device, runner)
else:
raise ValueError("Unknown speculative decoding method: "

View File

@@ -128,9 +128,8 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \
--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \
--compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \
--additional-config '{"torchair_graph_config":{"enabled":false}}'
```
### Multi-node Deployment
@@ -190,9 +189,8 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.94 \
--speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \
--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \
--compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \
--additional-config '{"torchair_graph_config":{"enabled":false}}'
```
**Node 1**
@@ -247,9 +245,8 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.94 \
--speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \
--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \
--compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \
--additional-config '{"torchair_graph_config":{"enabled":false}}'
```
### Prefill-Decode Disaggregation
@@ -421,7 +418,7 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
--gpu-memory-utilization 0.9 \
--quantization ascend \
--no-enable-prefix-caching \
--speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \
--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \
--additional-config '{"recompute_scheduler_enable":true,"enable_shared_expert_dp": true}' \
--kv-transfer-config \
'{"kv_connector": "MooncakeConnector",

View File

@@ -173,8 +173,7 @@ vllm serve vllm-ascend/DeepSeek-V3.2-Exp-W8A8 \
--enable-expert-parallel \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
--gpu-memory-utilization 0.92
```
### Multi-node Deployment
@@ -225,8 +224,7 @@ vllm serve /root/.cache/Modelers_Park/DeepSeek-V3.2-Exp \
--max-num-batched-tokens 17450 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
--gpu-memory-utilization 0.9
```
**Node 1**
@@ -269,8 +267,7 @@ vllm serve /root/.cache/Modelers_Park/DeepSeek-V3.2-Exp \
--enable-expert-parallel \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
--gpu-memory-utilization 0.92
```
::::
@@ -316,8 +313,7 @@ vllm serve vllm-ascend/DeepSeek-V3.2-Exp-W8A8 \
--trust-remote-code \
--quantization ascend \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
--gpu-memory-utilization 0.9
```
**Node 1**
@@ -362,8 +358,7 @@ vllm serve vllm-ascend/DeepSeek-V3.2-Exp-W8A8 \
--trust-remote-code \
--quantization ascend \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
--gpu-memory-utilization 0.92
```
::::

View File

@@ -136,8 +136,7 @@ vllm serve vllm-ascend/DeepSeek-V3.1-W8A8 \
--max-num-batched-tokens 8192 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--additional-config '{"torchair_graph_config":{"enabled":true}}'
--gpu-memory-utilization 0.9
```
**Node 1**
@@ -181,8 +180,7 @@ vllm serve vllm-ascend/DeepSeek-V3.1-W8A8 \
--enable-expert-parallel \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--additional-config '{"torchair_graph_config":{"enabled":true}}'
--gpu-memory-utilization 0.92
```
The deployment view looks like:

View File

@@ -92,8 +92,7 @@ vllm serve /home/cache/weights/Kimi-K2-Instruct-W8A8 \
--max-num-batched-tokens 8192 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--additional-config '{"torchair_graph_config":{"enabled":true}}'
--gpu-memory-utilization 0.9
```
**Node 1**
@@ -136,8 +135,7 @@ vllm serve /home/cache/weights/Kimi-K2-Instruct-W8A8 \
--enable-expert-parallel \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--additional-config '{"torchair_graph_config":{"enabled":true}}'
--gpu-memory-utilization 0.92
```
The deployment view looks like:

View File

@@ -153,12 +153,7 @@ if __name__ == "__main__":
enable_expert_parallel=True,
distributed_executor_backend="mp",
max_model_len=1024,
trust_remote_code=True,
additional_config={
'torchair_graph_config': {
'enabled': True,
}
})
trust_remote_code=True)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:

View File

@@ -27,7 +27,6 @@ The following table lists additional configuration options available in vLLM Asc
| Name | Type | Default | Description |
|-------------------------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------|
| `xlite_graph_config` | dict | `{}` | Configuration options for xlite graph mode |
| `torchair_graph_config` | dict | `{}` | Configuration options for torchair graph mode |
| `weight_prefetch_config` | dict | `{}` | Configuration options for weight prefetch |
| `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. |
| `expert_map_path` | str | `None` | When using expert load balancing for an MoE model, an expert map path needs to be passed in. |
@@ -52,21 +51,6 @@ The details of each configuration option are as follows:
| `enabled` | bool | `False` | Whether to enable xlite graph mode. Currently only Llama or Qwen dense series models are supported. |
| `full_mode` | bool | `False` | Whether to enable xlite for both the prefill and decode stages. By default, xlite is only enabled for the decode stage. |
**torchair_graph_config**
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE are supported. |
| `mode` | str | `None` | When using reduce-overhead mode for torchair, it needs to be set. |
| `enable_multistream_mla`| bool | `False` | Whether to put vector operators of MLA to another stream. This option only takes effect on models using MLA (for example, DeepSeek). |
| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization. |
| `enable_frozen_parameter` | bool | `True` | Whether to fix the memory address of weights during inference to reduce the input address refresh time during graph execution. |
| `use_cached_graph` | bool | `False` | Whether to use cached graph. |
| `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache. |
| `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty. |
| `enable_kv_nz`| bool | `False` | Whether to enable KV Cache NZ layout. This option only takes effect on models using MLA (for example, DeepSeek). |
| `enable_super_kernel` | bool | `False` | Whether to enable super kernel to fuse operators in deepseek moe layers. This option only takes effects on moe models using dynamic w8a8 quantization.|
**weight_prefetch_config**
| Name | Type | Default | Description |
@@ -80,13 +64,6 @@ An example of additional configuration is as follows:
```
{
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": True,
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": False,
"enable_kv_nz": False
},
"weight_prefetch_config": {
"enabled": True,
"prefetch_ratio": {

View File

@@ -10,9 +10,8 @@ This guide provides instructions for using Ascend Graph Mode with vLLM Ascend. P
From v0.9.1rc1 with the V1 Engine, vLLM Ascend runs models in graph mode by default to keep the same behavior as vLLM. If you hit any issues, please feel free to open an issue on GitHub and fall back to eager mode temporarily by setting `enforce_eager=True` when initializing the model.
There are three kinds for graph mode supported by vLLM Ascend:
There are two kinds of graph mode supported by vLLM Ascend:
- **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, Qwen and Deepseek series models are well tested.
- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported.
- **XliteGraph**: This is the euler xlite graph mode. In v0.11.0, only Llama and Qwen dense series models are supported.
## Using ACLGraph
@@ -35,29 +34,6 @@ Online example:
vllm serve Qwen/Qwen2-7B-Instruct
```
## Using TorchAirGraph
If you want to run DeepSeek series models with the graph mode, you should use [TorchAirGraph](https://www.hiascend.com/document/detail/zh/Pytorch/700/modthirdparty/torchairuseguide/torchair_0002.html). In this case, additional configuration is required.
Offline example:
```python
import os
from vllm import LLM
# TorchAirGraph only works without chunked-prefill now
model = LLM(model="path/to/DeepSeek-R1-0528", additional_config={"torchair_graph_config": {"enabled": True}})
outputs = model.generate("Hello, how are you?")
```
Online example:
```shell
vllm serve path/to/DeepSeek-R1-0528 --additional-config='{"torchair_graph_config": {"enabled": true}}'
```
You can find more details about additional configuration [here](../configuration/additional_config.md).
## Using XliteGraph
If you want to run Llama or Qwen dense series models with xlite graph mode, please install xlite and set `xlite_graph_config`.
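A hypothetical sketch, assuming xlite is installed and using the `enabled` option from the additional-config table above:

```python
from vllm import LLM

# Hypothetical sketch: enable xlite graph mode through additional_config,
# using the option names from the additional-config table.
llm = LLM(model="Qwen/Qwen2-7B-Instruct",
          additional_config={"xlite_graph_config": {"enabled": True}})
outputs = llm.generate("Hello, how are you?")
```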
@@ -87,7 +63,7 @@ You can find more details about xlite [here](https://gitee.com/openeuler/GVirt/b
## Fallback to the Eager Mode
If `ACLGraph`, `TorchAirGraph` and `XliteGraph` all fail to run, you should fallback to the eager mode.
If both `ACLGraph` and `XliteGraph` fail to run, you should fall back to eager mode.
Offline example:
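A minimal sketch, assuming the `enforce_eager` flag described above (model name illustrative):

```python
from vllm import LLM

# Minimal eager-mode fallback sketch: skip graph capture entirely.
llm = LLM(model="Qwen/Qwen2-7B-Instruct", enforce_eager=True)
outputs = llm.generate("Hello, how are you?")
```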

View File

@@ -104,22 +104,3 @@ First, make sure you specify `ascend` as the quantization method. Second, check
### 2. How to solve the error "Could not locate the configuration_deepseek.py"?
Please convert DeepSeek series models using ModelSlim `br_release_MindStudio_8.1.RC2_TR5_20260624`, where the missing `configuration_deepseek.py` error has been fixed.
### 3. What should be considered when converting DeepSeek series models with ModelSlim?
When the MLA portion of the weights used the `W8A8_DYNAMIC` quantization with the torchair graph mode enabled, modify the configuration file in the CANN package to prevent incorrect inference results.
The operation steps are as follows:
1. Search in the CANN package directory, for example:
find /usr/local/Ascend/ -name fusion_config.json
2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` and `"MultiAddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows:
```bash
{
"Switch":{
"GraphFusion":{
"AddRmsNormDynamicQuantFusionPass":"off",
"MultiAddRmsNormDynamicQuantFusionPass":"off",
```

View File

@@ -27,5 +27,4 @@ vllm serve Qwen/Qwen1.5-MoE-A2.7B \
--max-num-batched-tokens 4096 \
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--enforce-eager \
--additional-config '{"torchair_graph_config":{"enabled":false, "use_cached_graph":false}}'
--enforce-eager

View File

@@ -1,59 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import pytest
import vllm # noqa: F401
import vllm_ascend # noqa: F401
from tests.e2e.conftest import VllmRunner
# Pangu local model path
MODELS = [
"IntervitensInc/pangu-pro-moe-model",
]
# set additional config for ascend scheduler and torchair graph
ADDITIONAL_CONFIG = [{
"additional_config": {
"torchair_graph_config": {
"enabled": True
}
}
}]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enfore_eager", [True, False])
@pytest.mark.parametrize("additional_config", ADDITIONAL_CONFIG)
def test_pangu_model(model: str, dtype: str, max_tokens: int,
enfore_eager: bool, additional_config: dict) -> None:
if enfore_eager:
additional_config = {}
example_prompts = [
"Hello, my name is",
"The future of AI is",
]
with VllmRunner(model,
tensor_parallel_size=4,
dtype=dtype,
max_model_len=1024,
enforce_eager=True,
enable_expert_parallel=True,
additional_config=additional_config,
distributed_executor_backend="mp") as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -78,9 +78,6 @@ def test_models_distributed_DeepSeek_multistream_moe():
tensor_parallel_size=2,
distributed_executor_backend="mp",
additional_config={
"torchair_graph_config": {
"enabled": True,
},
"enable_multistream_moe": True,
"refresh": True,
},
@@ -144,17 +141,12 @@ def test_models_distributed_DeepSeek_W4A8DYNAMIC(model):
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download(model),
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
enforce_eager=True,
enable_expert_parallel=True,
additional_config={"torchair_graph_config": {
"enabled": False,
}},
) as vllm_model:
with VllmRunner(snapshot_download(model),
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
enforce_eager=True,
enable_expert_parallel=True) as vllm_model:
vllm_model.generate_greedy(prompts, max_tokens)

View File

@@ -1,290 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/multicard/test_torchair_graph_mode.py`.
"""
import os
from typing import Dict
import pytest
from tests.e2e.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
def _deepseek_torchair_test_fixture(
additional_config: Dict,
*,
tensor_parallel_size=2,
use_v1_schduler=False,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
kwargs = {}
if not use_v1_schduler:
kwargs = {
"refresh": True,
}
additional_config.update(**kwargs)
with VllmRunner(
"vllm-ascend/DeepSeek-V3-Pruning",
dtype="half",
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="mp",
additional_config=additional_config,
) as vllm_model:
# use greedy sampler to make sure the generated results are fix
vllm_output = vllm_model.generate_greedy(example_prompts, 5)
# NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
# DeepSeek-V3 with 2 hidden layers, thus the golden results seems
# inaccurate. This will only change if accuracy improves with the
# official weights of DeepSeek-V3.
golden_results = [
'Hello, my name is下载早点向前很有่อง',
'The president of the United States isSender)## physiological Albany',
'The capital of France is Rocky转角 hospitalizedinterval sparked',
'The future of AI is её asegο BIOS一扫',
]
assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
assert golden_results[i] == vllm_output[i][1]
print(f"Generated text: {vllm_output[i][1]!r}")
def test_e2e_deepseekv3_with_torchair():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_torchair_test_fixture(additional_config)
def test_e2e_deepseekv3_with_torchair_ms_mla():
additional_config = {
"torchair_graph_config": {
"enabled": True,
"enable_multistream_mla": True,
},
}
_deepseek_torchair_test_fixture(additional_config)
@pytest.mark.skip("accuracy test failed. Fix me")
def test_e2e_deepseekv3_with_torchair_v1scheduler():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_torchair_test_fixture(additional_config, use_v1_schduler=True)
def _pangu_torchair_test_fixture(
additional_config: Dict,
*,
tensor_parallel_size=2,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# torchair is only work without chunked-prefill now
kwargs = {
"refresh": True,
}
additional_config.update(**kwargs)
with VllmRunner(
"vllm-ascend/pangu-pro-moe-pruing",
dtype="half",
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="mp",
additional_config=additional_config,
enable_expert_parallel=True,
) as vllm_model:
# use greedy sampler to make sure the generated results are fix
vllm_output = vllm_model.generate_greedy(example_prompts, 5)
# NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
# with 2 hidden layers, thus the golden results seems inaccurate.
# This will only change if accuracy changes with the official weights
# of PanguProMoE.
golden_results = [
'Hello, my name is Remempondeprecatedmiot忱',
'The president of the United States is Remem下的一个 rever ceremoni Segnali',
'The capital of France is Rememvoud administrativ Remem投',
'The future of AI isotope Segnali Zoeken精细化 supus',
]
assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
assert golden_results[i] == vllm_output[i][1]
print(f"Generated text: {vllm_output[i][1]!r}")
@pytest.mark.skip("skipping test_e2e_pangu_with_torchair")
def test_e2e_pangu_with_torchair():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_pangu_torchair_test_fixture(additional_config)
def _qwen_torchair_test_fixture(
model,
tp,
enable_expert_parallel,
):
# The current access control does not support 16 cards,
# so the MC2 operator in Qwen's graph mode cannot run.
# Once 16-card support is available,
# this e2e can be switched to graph mode.
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"refresh": True,
}
with VllmRunner(
model,
dtype="half",
tensor_parallel_size=tp,
distributed_executor_backend="mp",
enforce_eager=True,
additional_config=additional_config,
enable_expert_parallel=enable_expert_parallel,
) as vllm_model:
# use greedy sampler to make sure the generated results are fix
vllm_output = vllm_model.generate_greedy(example_prompts, 5)
# NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
# with 2 hidden layers, thus the golden results seems inaccurate.
# This will only change if accuracy changes with the official weights
# of PanguProMoE.
golden_results = [
'Hello, my name is Remempondeprecatedmiot忱',
'The president of the United States is Remem下的一个 rever ceremoni Segnali',
'The capital of France is Rememvoud administrativ Remem投',
'The future of AI isotope Segnali Zoeken精细化 supus',
]
assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
print(f"Generated text: {vllm_output[i][1]!r}")
def test_e2e_qwen2_with_torchair():
_qwen_torchair_test_fixture("Qwen/Qwen2.5-0.5B-Instruct", 2, False)
def test_e2e_qwen3_moe_with_torchair():
_qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True)
# test deepseek-v2-lite
def _deepseek_v2_lite_torchair_test_fixure(
additional_config: Dict,
*,
tensor_parallel_size=2,
use_v1_schduler=False,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
kwargs = {}
if not use_v1_schduler:
kwargs = {
"refresh": True,
}
additional_config.update(**kwargs)
with VllmRunner(
"deepseek-ai/DeepSeek-V2-Lite",
dtype="half",
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="mp",
additional_config=additional_config,
) as vllm_model:
vllm_output = vllm_model.generate_greedy(example_prompts, 5)
# NOTE: deepseek-ai/DeepSeek-V2-Lite is a random weight of
# DeepSeek-V2-Lite with 2 hidden layers, thus the golden results seems
# inaccurate. This will only change if accuracy improves with the
# official weights of DeepSeek-V2-Lite.
for i in range(len(vllm_output)):
generated_text = vllm_output[i][1]
assert len(
generated_text.strip()) > 0, f"The {i}-th output is null, failed"
def test_e2e_deepseekv2lite_with_torchair():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_v2_lite_torchair_test_fixure(additional_config)
def test_e2e_deepseekv2lite_with_torchair_v1scheduler():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_v2_lite_torchair_test_fixure(additional_config,
use_v1_schduler=True)
# kv_cache enable e2e test
def test_e2e_deepseekv2lite_with_nz():
additional_config = {
"torchair_graph_config": {
"enabled": True,
"enable_kv_nz": True,
},
}
_deepseek_v2_lite_torchair_test_fixure(additional_config)

View File

@@ -73,7 +73,6 @@ async def test_models(model: str, mode: str) -> None:
"VLLM_RPC_TIMEOUT": "3600000",
"VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS": "3600000"
}
additional_config: dict[str, Any] = {}
speculative_config = {"num_speculative_tokens": 2, "method": "mtp"}
compilation_config = {
"cudagraph_capture_sizes": [56],
@@ -104,7 +103,6 @@ async def test_models(model: str, mode: str) -> None:
["--speculative-config",
json.dumps(speculative_config)])
server_args.extend(["--gpu-memory-utilization", "0.92"])
additional_config["torchair_graph_config"] = {"enabled": True}
aisbench_cases = aisbench_gsm8k
if mode == "mtp3":
env_dict["HCCL_OP_EXPANSION_MODE"] = "AIV"
@@ -117,9 +115,7 @@ async def test_models(model: str, mode: str) -> None:
server_args.extend(
["--compilation-config",
json.dumps(compilation_config)])
additional_config["torchair_graph_config"] = {"enabled": False}
aisbench_cases = aisbench_aime
server_args.extend(["--additional-config", json.dumps(additional_config)])
request_keyword_args: dict[str, Any] = {
**api_keyword_args,
}

View File

@@ -74,13 +74,6 @@ async def test_models(model: str) -> None:
"PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True",
}
additional_config = {
"torchair_graph_config": {
"enabled": True,
"enable_multistream_moe": False,
"enable_multistream_mla": True,
"graph_batch_size": [16],
"use_cached_graph": True
},
"chunked_prefill_for_mla": True,
"enable_weight_nz_layout": True
}

View File

@@ -29,7 +29,6 @@ MODELS = [
]
MODES = [
"torchair",
"single",
"aclgraph",
"aclgraph_mlapo",
@@ -78,13 +77,6 @@ async def test_models(model: str, mode: str) -> None:
}
speculative_config = {"num_speculative_tokens": 1, "method": "mtp"}
additional_config = {
"torchair_graph_config": {
"enabled": True,
"enable_multistream_moe": False,
"enable_multistream_mla": True,
"graph_batch_sizes": [16],
"use_cached_graph": True
},
"chunked_prefill_for_mla": True,
"enable_weight_nz_layout": True
}
@@ -99,12 +91,8 @@ async def test_models(model: str, mode: str) -> None:
]
if mode == "single":
server_args.append("--enforce-eager")
additional_config["torchair_graph_config"] = {"enabled": False}
if mode == "aclgraph":
additional_config["torchair_graph_config"] = {"enabled": False}
if mode == "aclgraph_mlapo":
env_dict["VLLM_ASCEND_ENABLE_MLAPO"] = "1"
additional_config["torchair_graph_config"] = {"enabled": False}
server_args.extend(["--additional-config", json.dumps(additional_config)])
request_keyword_args: dict[str, Any] = {
**api_keyword_args,

View File

@@ -68,9 +68,6 @@ async def test_models(model: str) -> None:
"cudagraph_mode": "FULL_DECODE_ONLY"
}
additional_config: dict[str, Any] = {
"torchair_graph_config": {
"enabled": True
},
"enable_shared_expert_dp": False,
"multistream_overlap_shared_expert": False,
"dynamic_eplb": True,

View File

@@ -72,27 +72,13 @@ async def test_models(model: str, tp_size: int, dp_size: int,
port = get_open_port()
env_dict = {"HCCL_BUFFSIZE": "1024", "VLLM_ASCEND_ENABLE_MLAPO": "0"}
server_args = [
"--no-enable-prefix-caching",
"--enable-expert-parallel",
"--no-enable-prefix-caching", "--enable-expert-parallel",
"--tensor-parallel-size",
str(tp_size),
"--data-parallel-size",
str(dp_size),
"--port",
str(port),
"--max-model-len",
"16384",
"--max-num-batched-tokens",
"16384",
"--block-size",
"16",
"--trust-remote-code",
"--quantization",
"ascend",
"--gpu-memory-utilization",
"0.9",
"--additional-config",
'{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}',
str(tp_size), "--data-parallel-size",
str(dp_size), "--port",
str(port), "--max-model-len", "16384", "--max-num-batched-tokens",
"16384", "--block-size", "16", "--trust-remote-code", "--quantization",
"ascend", "--gpu-memory-utilization", "0.9"
]
if full_graph:
server_args += [

View File

@@ -1,64 +0,0 @@
test_name: "test DeepSeek-R1-W8A8 torchair on A2"
model: "vllm-ascend/DeepSeek-R1-0528-W8A8"
num_nodes: 2
npu_per_node: 8
env_common:
VLLM_USE_MODELSCOPE: true
HCCL_BUFFSIZE: 1024
SERVER_PORT: 8080
OMP_PROC_BIND: false
OMP_NUM_THREADS: 10
deployment:
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
--host 0.0.0.0
--port $SERVER_PORT
--data-parallel-size 4
--data-parallel-size-local 2
--data-parallel-address $LOCAL_IP
--data-parallel-rpc-port 13399
--no-enable-prefix-caching
--max-num-seqs 16
--tensor-parallel-size 4
--max-model-len 36864
--max-num-batched-tokens 6000
--enable-expert-parallel
--trust-remote-code
--quantization ascend
--gpu-memory-utilization 0.9
--speculative-config '{"num_speculative_tokens": 1, "method":"mtp"}'
--additional-config '{"torchair_graph_config":{"enabled":true,"enable_multistream_moe":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}'
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
--headless
--data-parallel-size 4
--data-parallel-rpc-port 13399
--data-parallel-size-local 2
--data-parallel-start-rank 2
--data-parallel-address $MASTER_IP
--no-enable-prefix-caching
--max-num-seqs 16
--tensor-parallel-size 4
--max-model-len 36864
--max-num-batched-tokens 6000
--enable-expert-parallel
--trust-remote-code
--quantization ascend
--gpu-memory-utilization 0.9
--speculative-config '{"num_speculative_tokens": 1, "method":"mtp"}'
--additional-config '{"torchair_graph_config":{"enabled":true,"enable_multistream_moe":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}'
benchmarks:
acc:
case_type: accuracy
dataset_path: vllm-ascend/gsm8k
request_conf: vllm_api_general_chat
dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
max_out_len: 32768
batch_size: 512
baseline: 95
threshold: 5

View File

@@ -58,7 +58,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
'{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
-
server_cmd: >
@@ -96,7 +96,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
'{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
@@ -135,7 +135,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
'{"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
@@ -173,7 +173,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
'{"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
benchmarks:
perf:
case_type: performance

View File

@@ -57,7 +57,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}'
'{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}'
-
server_cmd: >
@@ -95,7 +95,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}'
'{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}'
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
@@ -134,7 +134,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true}'
'{"multistream_overlap_shared_expert":true}'
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
@@ -172,7 +172,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true}'
'{"multistream_overlap_shared_expert":true}'
benchmarks:
perf:
case_type: performance

View File

@@ -82,7 +82,6 @@ deployment:
--trust-remote-code
--no-enable-prefix-caching
--gpu-memory-utilization 0.9
--additional-config '{"torchair_graph_config":{"enabled":true}}'
--kv-transfer-config
'{"kv_connector": "MooncakeConnector",
"kv_role": "kv_consumer",

View File

@@ -29,7 +29,6 @@ deployment:
--trust-remote-code
--no-enable-prefix-caching
--gpu-memory-utilization 0.9
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
-
server_cmd: >
@@ -49,5 +48,4 @@ deployment:
--trust-remote-code
--no-enable-prefix-caching
--gpu-memory-utilization 0.92
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
benchmarks:

View File

@@ -1,106 +0,0 @@
from __future__ import annotations
import pytest
from vllm import SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode
from tests.e2e.conftest import VllmRunner
@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)
@pytest.fixture
def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"
def mtp_torchair_correctness(
sampling_config: SamplingParams,
model_name: str,
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using mtp speculative decoding.
'''
with VllmRunner(model_name,
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
enforce_eager=False,
additional_config={
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4],
},
"multistream_overlap_shared_expert": "True"
}) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
graph_mode_str = "PIECEWISE"
if graph_mode == CUDAGraphMode.FULL:
graph_mode_str = "FULL"
with VllmRunner(model_name,
tensor_parallel_size=1,
max_num_seqs=256,
gpu_memory_utilization=0.7,
distributed_executor_backend="mp",
enable_expert_parallel=True,
speculative_config={
"method": "mtp",
"num_speculative_tokens": 1,
},
enforce_eager=False,
max_model_len=2000,
compilation_config=CompilationConfig(
cudagraph_mode=graph_mode_str),
additional_config={
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4],
},
"multistream_overlap_shared_expert": "True"
}) as spec_llm:
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
def test_mtp_torchair_correctness_piecewise(
sampling_config: SamplingParams,
model_name: str,
):
mtp_torchair_correctness(sampling_config, model_name)
def test_mtp_torchair_correctness_full(
sampling_config: SamplingParams,
model_name: str,
):
mtp_torchair_correctness(sampling_config, model_name, CUDAGraphMode.FULL)

View File

@@ -118,7 +118,6 @@ def mock_dist_env(mocker: MockerFixture):
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
expert_map_path=None
)), \

View File

@@ -110,11 +110,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled,
mock_soc_version, mock__c):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Setup mock for custom kernel path
mock__c.rotary_embedding.return_value = self.query, self.key
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
@@ -139,9 +134,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
@@ -165,9 +157,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_with_offsets(self):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
@@ -190,9 +179,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test neox_style override
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
@@ -219,9 +205,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_rotary_dim_less_than_head_size(
self, mock_npu_rotary, mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# test case when rotary_dim < head_size
org_rotary_dim = self.layer.rotary_dim
self.layer.rotary_dim = self.layer.head_size // 2
@@ -415,7 +398,6 @@ class TestAscendMRotaryEmbedding(unittest.TestCase):
mrope_section=self.mrope_section)
self.mock_config = MagicMock()
self.mock_config.torchair_graph_config.enabled = False
def _create_vllm_config(self):
vllm_config = VllmConfig()

View File

@@ -33,7 +33,6 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
mock_get_ep_group.return_value = mock_ep_group
mock_ascend_config = Mock()
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
mock_ascend_config.enable_chunked_prefill = False
mock_get_ascend_config.return_value = mock_ascend_config
mock_mc2_group = Mock(device_group=0)

View File

@@ -13,15 +13,10 @@
# This file is a part of the vllm-ascend project.
#
import os
from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.config import VllmConfig
from tests.ut.base import TestBase
from vllm_ascend.ascend_config import (_check_torchair_supported,
check_ascend_config,
clear_ascend_config, get_ascend_config,
from vllm_ascend.ascend_config import (clear_ascend_config, get_ascend_config,
init_ascend_config)
@@ -45,17 +40,6 @@ class TestAscendConfig(TestBase):
self.assertIsNone(ascend_config.expert_map_path)
self.assertFalse(ascend_config.multistream_overlap_shared_expert)
torchair_graph_config = ascend_config.torchair_graph_config
self.assertFalse(torchair_graph_config.enabled)
self.assertEqual(torchair_graph_config.mode, '')
self.assertFalse(torchair_graph_config.use_cached_graph)
self.assertEqual(torchair_graph_config.graph_batch_sizes, [])
self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
self.assertFalse(torchair_graph_config.enable_multistream_mla)
self.assertTrue(torchair_graph_config.enable_view_optimize)
self.assertTrue(torchair_graph_config.enable_frozen_parameter)
self.assertFalse(torchair_graph_config.enable_kv_nz)
ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertTrue(ascend_compilation_config.enable_quantization_fusion)
@@ -63,16 +47,6 @@ class TestAscendConfig(TestBase):
def test_init_ascend_config_with_additional_config(self):
test_vllm_config = VllmConfig()
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": True,
"graph_batch_sizes": [1, 2, 4],
"graph_batch_sizes_init": False,
"enable_multistream_mla": True,
"enable_view_optimize": True,
"enable_frozen_parameter": True,
"enable_kv_nz": True
},
"ascend_compilation_config": {
"enable_quantization_fusion": False,
},
@@ -84,65 +58,9 @@ class TestAscendConfig(TestBase):
self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
self.assertTrue(ascend_config.multistream_overlap_shared_expert)
torchair_graph_config = ascend_config.torchair_graph_config
self.assertTrue(torchair_graph_config.enabled)
self.assertTrue(torchair_graph_config.use_cached_graph)
self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4])
self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
self.assertTrue(torchair_graph_config.enable_multistream_mla)
self.assertTrue(torchair_graph_config.enable_view_optimize)
self.assertTrue(torchair_graph_config.enable_frozen_parameter)
self.assertTrue(torchair_graph_config.enable_kv_nz)
ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertFalse(ascend_compilation_config.enable_quantization_fusion)
@_clean_up_ascend_config
def test_init_ascend_config_with_refresh(self):
test_vllm_config = VllmConfig()
ascend_config = init_ascend_config(test_vllm_config)
self.assertFalse(ascend_config.torchair_graph_config.enabled)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertFalse(ascend_config.torchair_graph_config.enabled)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True,
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertTrue(ascend_config.torchair_graph_config.enabled)
@_clean_up_ascend_config
def test_init_ascend_config_with_wrong_input(self):
test_vllm_config = VllmConfig()
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"graph_batch_sizes": "fake_size",
},
"refresh": True,
}
with self.assertRaises(TypeError):
init_ascend_config(test_vllm_config)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": True,
},
"refresh": True,
}
with self.assertRaises(ValueError):
init_ascend_config(test_vllm_config)
@_clean_up_ascend_config
def test_get_ascend_config(self):
test_vllm_config = VllmConfig()
@@ -162,203 +80,3 @@ class TestAscendConfig(TestBase):
clear_ascend_config()
with self.assertRaises(RuntimeError):
get_ascend_config()
@_clean_up_ascend_config
def test_check_ascend_config_pass(self):
test_vllm_config = VllmConfig()
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
@_clean_up_ascend_config
def test_check_ascend_config_wrong_case(self):
test_vllm_config = VllmConfig()
# torchair + eager mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
enforce_eager = True
check_ascend_config(test_vllm_config, enforce_eager)
# torchair + non deepseek model
with self.assertRaises(NotImplementedError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
fake_model_config = ModelConfig(model=model_path)
fake_model_config.hf_config = PretrainedConfig()
fake_model_config.hf_config.model_type = "llama"
test_vllm_config.model_config = fake_model_config
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
def test_check_torchair_supported(self):
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
('qwen', True), ('llama', False)]
for model_type, expected_output in test_cases:
self.assertEqual(_check_torchair_supported(model_type),
expected_output)
@_clean_up_ascend_config
def test_ascend_config_load_error(self):
test_vllm_config = VllmConfig()
# graph_batch_sizes should be list.
with self.assertRaises(TypeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"graph_batch_sizes": "fake_size",
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# use_cached_graph should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"use_cached_graph": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# use_cached_kv_cache_bytes should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"use_cached_kv_cache_bytes": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# graph_batch_sizes should not be set without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes": [1, 2, 4],
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# use_cached_kv_cache_bytes is valid only when torchair graph mode and use_cached_graph are enabled
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"use_cached_kv_cache_bytes": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# graph_batch_sizes_init should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes_init": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# enable_multistream_mla should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"enable_multistream_mla": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# mode should not be configured without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"mode": 'max-autotune',
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# enable_kv_nz should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"enable_kv_nz": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"lmhead_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=2)
init_ascend_config(test_vllm_config)
with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"oproj_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=2)
init_ascend_config(test_vllm_config)
with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"oproj_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=1)
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
test_vllm_config.model_config = ModelConfig(model=model_path,
enforce_eager=True)
init_ascend_config(test_vllm_config)

View File

@@ -31,7 +31,6 @@ class TestNPUPlatform(TestBase):
@staticmethod
def mock_vllm_ascend_config():
mock_ascend_config = MagicMock()
mock_ascend_config.torchair_graph_config.enabled = False
mock_ascend_config.xlite_graph_config.enabled = False
mock_ascend_config.enable_shared_expert_dp = False
return mock_ascend_config
@@ -403,47 +402,6 @@ class TestNPUPlatform(TestBase):
CUDAGraphMode.NONE,
)
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
@patch("vllm_ascend.utils.update_default_aclgraph_sizes")
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_torchair_enabled_compilation(
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
mock_update_default, mock_soc_version):
mock_update_default.return_value = MagicMock()
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
mock_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = mock_ascend_config
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
vllm_config.scheduler_config = MagicMock()
vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertTrue("Torchair compilation enabled" in cm.output[0])
self.assertEqual(
vllm_config.compilation_config.mode,
CompilationMode.NONE,
)
self.assertEqual(
vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@@ -503,16 +461,6 @@ class TestNPUPlatform(TestBase):
"vllm_ascend.worker.worker_v1.NPUWorker",
)
test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
test_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = test_ascend_config
vllm_config.parallel_config.worker_cls = "auto"
self.platform.check_and_update_config(vllm_config)
self.assertEqual(
vllm_config.parallel_config.worker_cls,
"vllm_ascend.torchair.torchair_worker.NPUTorchairWorker",
)
test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
test_ascend_config.xlite_graph_config.enabled = True
mock_init_ascend.return_value = test_ascend_config
@@ -550,14 +498,7 @@ class TestNPUPlatform(TestBase):
self.platform.check_and_update_config(vllm_config)
self.assertEqual(vllm_config.compilation_config.custom_ops, [])
@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_config.enable_shared_expert_dp = False
mock_get_ascend_config.return_value = mock_config
def test_get_attn_backend_cls_use_v1_and_mla(self):
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
@@ -570,56 +511,7 @@ class TestNPUPlatform(TestBase):
self.assertEqual(result,
"vllm_ascend.attention.mla_v1.AscendMLABackend")
@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_mla_and_torchair(
self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
#use_sfa=False,
use_mla=True,
)
self.assertEqual(
result,
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend")
@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_and_torchair(self,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
#use_sfa=False,
use_mla=False,
)
self.assertEqual(
result,
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
)
@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
def test_get_attn_backend_cls_use_v1_only(self):
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,

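With the torchair branches removed, get_attn_backend_cls no longer consults the Ascend config, which is why the @patch('vllm_ascend.platform.get_ascend_config') decorators disappear above. A sketch of the surviving call shape, using only the arguments and the MLA result visible in this diff (NPUPlatform's import path is assumed):

from vllm_ascend.platform import NPUPlatform

backend = NPUPlatform.get_attn_backend_cls(
    selected_backend="ascend",
    head_size=64,
    dtype="float16",
    kv_cache_dtype="float16",
    block_size=64,
    use_mla=True,  # MLA models resolve to the MLA backend string below
)
assert backend == "vllm_ascend.attention.mla_v1.AscendMLABackend"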
View File

@@ -1,61 +0,0 @@
from unittest.mock import Mock
import pytest
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import PytestBase
from vllm_ascend.torchair.models.qwen3_moe import CustomSparseMoeBlock
class TestCustomSparseMoeBlock(PytestBase):
@pytest.fixture
def setup_csmb(self, mocker: MockerFixture):
config = PretrainedConfig(num_experts=64,
hidden_size=2048,
num_experts_per_tok=2,
moe_intermediate_size=1408,
norm_topk_prob=True)
mocker.patch(
'vllm_ascend.torchair.models.qwen3_moe.get_tensor_model_parallel_world_size',
return_value=10)
mocker.patch(
'vllm.model_executor.layers.linear.ReplicatedLinear.__init__',
return_value=None)
mocker.patch(
'vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__init__',
return_value=None)
tp_group = Mock(spec=GroupCoordinator)
tp_group.rank_in_group = 0
tp_group.world_size = 1
tp_group.device_group = Mock()
dp_group = Mock(spec=GroupCoordinator)
dp_group.rank_in_group = 0
dp_group.world_size = 1
ep_group = Mock(spec=GroupCoordinator)
ep_group.rank_in_group = 0
ep_group.world_size = 1
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_tp_group',
return_value=tp_group)
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_dp_group',
return_value=dp_group)
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_ep_group',
return_value=ep_group)
ascend_config = mocker.MagicMock()
ascend_config.max_num_batched_tokens = 2048
ascend_config.max_model_len = 1024
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=ascend_config)
custom_moe_block = CustomSparseMoeBlock(config, None, "")
return custom_moe_block
def test_init(self, mocker: MockerFixture, setup_csmb):
custom_moe_block = setup_csmb
assert isinstance(custom_moe_block, CustomSparseMoeBlock)

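The deleted file above shows the distributed-mocking pattern these torchair tests shared: every parallel group (tp/dp/ep) becomes a Mock(spec=GroupCoordinator) pinned to rank 0 and world size 1. A reusable sketch of that pattern (the helper name is invented for illustration):

from unittest.mock import Mock

from vllm.distributed.parallel_state import GroupCoordinator


def make_fake_group(rank: int = 0, world_size: int = 1) -> Mock:
    # spec=GroupCoordinator keeps reads honest: accessing an attribute the
    # real class lacks raises AttributeError instead of silently mocking.
    group = Mock(spec=GroupCoordinator)
    group.rank_in_group = rank
    group.world_size = world_size
    group.device_group = Mock()
    return group


tp_group, dp_group, ep_group = (make_fake_group() for _ in range(3))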
View File

@@ -1,206 +0,0 @@
import pytest
import torch
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from tests.ut.base import PytestBase
from vllm_ascend.torchair.models.torchair_deepseek_mtp import (
TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor,
TorchairDeepSeekMultiTokenPredictorLayer)
class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
@pytest.fixture
def setup_mtp_layer(self, mocker: MockerFixture):
config = PretrainedConfig(vocab_size=1000,
hidden_size=768,
rms_norm_eps=1e-5)
mocker.patch(
'vllm_ascend.torchair.models.torchair_deepseek_mtp.get_tensor_model_parallel_world_size',
return_value=1)
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
return_value=None)
mocker.patch(
"vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekShareHead.__init__",
return_value=None)
mocker_deepseek_v2_decode_layer = mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
ascend_config = mocker.MagicMock()
ascend_config.max_num_batched_tokens = 2048
ascend_config.max_model_len = 1024
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=ascend_config)
mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
mocker_deepseek_v2_decode_layer.assert_called_once()
return mtp_layer
def test_init(self, mocker: MockerFixture, setup_mtp_layer):
mtp_layer = setup_mtp_layer
assert isinstance(mtp_layer, TorchairDeepSeekMultiTokenPredictorLayer)
def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
mtp_layer = setup_mtp_layer
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch.object(mtp_layer,
'eh_proj',
return_value=torch.randn(2, 3, 768))
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
torch.randn(2, 3, 768))
mtp_layer.enorm.return_value = torch.randn(2, 3, 768)
mtp_layer.hnorm.return_value = torch.randn(2, 3, 768)
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
kv_cache = torch.randn(2, 3, 768)
previous_hidden_states = torch.randn(2, 3, 768)
inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])
output = mtp_layer(input_ids, positions, kv_cache, None,
previous_hidden_states, inputs_embeds, 0)
assert output.shape == (3, 768)
class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
@pytest.fixture
def setup_predictor(self, mocker: MockerFixture):
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
mock_model_config = mocker.MagicMock(spec=ModelConfig)
mock_hf_config = mocker.MagicMock()
mock_hf_config.num_hidden_layers = 12
mock_hf_config.num_nextn_predict_layers = 3
mock_hf_config.vocab_size = 30000
mock_model_config.hf_config = mock_hf_config
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = CacheConfig()
mock_vllm_config.quant_config = mocker.MagicMock()
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
ascend_config = mocker.MagicMock()
ascend_config.max_num_batched_tokens = 2048
ascend_config.max_model_len = 1024
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=ascend_config)
predictor = TorchairDeepSeekMultiTokenPredictor(
vllm_config=mock_vllm_config)
return predictor
def test_init(self, mocker: MockerFixture, setup_predictor):
predictor = setup_predictor
assert predictor.num_mtp_layers == 3
assert isinstance(predictor, TorchairDeepSeekMultiTokenPredictor)
@pytest.mark.parametrize(
'kv_caches, inputs_embeds',
[(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
inputs_embeds):
predictor = setup_predictor
mock_layer = mocker.MagicMock()
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
predictor.layers_list = [mock_layer]
# TODO: decide whether this override is needed
# predictor.num_mtp_layers = 1
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
return_value=torch.tensor([[1.0, 2.0, 3.0]]))
output = predictor.forward(input_ids, positions, kv_caches, None, None,
inputs_embeds, 0)
mock_layer.assert_called_once()
assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))
def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
predictor = setup_predictor
mock_layer = mocker.MagicMock()
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
predictor.layers_list = [mock_layer]
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch(
"vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
return_value=None)
predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
result_logits = predictor.compute_logits(hidden_states=hidden_states)
predictor.logits_processor.assert_called_once()
assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
class TestTorchairDeepSeekMTP(PytestBase):
@pytest.fixture
def setup_mtp(self, mocker: MockerFixture):
vllm_config = mocker.MagicMock()
vllm_config.model_config.hf_config.num_hidden_layers = 12
vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
vllm_config.cache_config = mocker.MagicMock()
vllm_config.quant_config = mocker.MagicMock()
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
ascend_config = mocker.MagicMock()
ascend_config.max_num_batched_tokens = 2048
ascend_config.max_model_len = 1024
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=ascend_config)
mtp = TorchairDeepSeekMTP(vllm_config=vllm_config)
return mtp
def test_init(self, mocker: MockerFixture, setup_mtp):
mtp = setup_mtp
assert isinstance(mtp, TorchairDeepSeekMTP)
def test_forward(self, mocker: MockerFixture, setup_mtp):
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
spec_step_idx = 0
setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
previous_hidden_states, inputs_embeds,
spec_step_idx)
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))

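The mocks in test_forward above pin down the MTP layer's data flow: the embedding and previous-hidden streams are each RMS-normalized, concatenated on the last dimension, projected back to hidden size by eh_proj, then run through an ordinary decoder block whose residual is added back. A toy stand-in of that flow (torch.nn.RMSNorm needs a recent PyTorch; the residual handling is inferred from the mocks, not shown verbatim in the diff):

import torch
from torch import nn


class TinyMTPLayer(nn.Module):

    def __init__(self, hidden: int = 768):
        super().__init__()
        self.enorm = nn.RMSNorm(hidden)
        self.hnorm = nn.RMSNorm(hidden)
        self.eh_proj = nn.Linear(2 * hidden, hidden)
        # Stand-in for the DeepseekV2 decoder layer: (hidden_states, residual)
        self.mtp_block = lambda h: (h, torch.zeros_like(h))

    def forward(self, inputs_embeds, previous_hidden_states):
        fused = torch.cat([self.enorm(inputs_embeds),
                           self.hnorm(previous_hidden_states)], dim=-1)
        hidden_states, residual = self.mtp_block(self.eh_proj(fused))
        return hidden_states + residual


out = TinyMTPLayer()(torch.randn(3, 768), torch.randn(3, 768))
assert out.shape == (3, 768)  # same shape test_forward asserts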
View File

@@ -1,366 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
import torch
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.transformers_utils.config import patch_rope_parameters
from vllm_ascend.torchair.models.torchair_deepseek_v2 import (
TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM,
TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention,
TorchairDeepseekV2MLP, TorchairDeepseekV2MoE,
TorchairDeepseekV2RowParallelLinear,
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
TorchairDeepseekV2SiluAndMul)
@pytest.fixture
def base_config():
config = PretrainedConfig(
hidden_size=128,
num_attention_heads=8,
num_hidden_layers=2,
intermediate_size=256,
hidden_act="silu",
rms_norm_eps=1e-6,
rope_theta=10000.0,
max_position_embeddings=2048,
n_routed_experts=4,
n_shared_experts=1,
moe_intermediate_size=256,
num_experts_per_tok=2,
routed_scaling_factor=1.0,
first_k_dense_replace=0,
moe_layer_freq=1,
kv_lora_rank=16,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
topk_method="noaux_tc",
scoring_func="softmax",
norm_topk_prob=True,
n_group=1,
topk_group=1,
vocab_size=10000,
)
patch_rope_parameters(config)
return config
@pytest.fixture
def vllm_config(base_config):
model_config = SimpleNamespace(
hf_config=base_config,
tensor_parallel_size=1,
dtype=torch.float32,
use_mla=False,
quant_config=None,
max_model_len=2048,
)
cache_config = CacheConfig()
vllm_config = Mock()
vllm_config.model_config = model_config
vllm_config.cache_config = cache_config
vllm_config.quant_config = None
return vllm_config
@pytest.fixture
def mock_distributed():
tp_group = Mock(spec=GroupCoordinator)
tp_group.rank_in_group = 0
tp_group.world_size = 1
tp_group.device_group = Mock()
dp_group = Mock(spec=GroupCoordinator)
dp_group.rank_in_group = 0
dp_group.world_size = 1
ep_group = Mock(spec=GroupCoordinator)
ep_group.rank_in_group = 0
ep_group.world_size = 1
pp_group = Mock(spec=GroupCoordinator)
pp_group.rank_in_group = 0
pp_group.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mlp_tp_group = Mock(spec=GroupCoordinator)
mlp_tp_group.rank_in_group = 0
mlp_tp_group.world_size = 1
mlp_tp_group.all_gather = Mock(return_value=torch.randn(2, 4, 128))
mock_vllm_config = Mock()
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
with patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tp_group", return_value=tp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_ep_group", return_value=ep_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_dp_group", return_value=dp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch('vllm.distributed.parallel_state.get_dcp_group', return_value=dcp_group), \
patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)), \
patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1),\
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
yield
@pytest.fixture
def mock_forward_context():
forward_context = Mock(in_profile_run=False, with_prefill=False)
with patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.get_forward_context",
return_value=forward_context):
yield
@pytest.fixture
def patch_attention_init():
try:
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
DeepseekV2Attention
original_init = DeepseekV2Attention.__init__
def patched_init(self, *args, **kwargs):
kwargs.pop("decoder_layer", None)
if 'vllm_config' not in kwargs:
mock_vllm_config = Mock()
mock_vllm_config.model_config = Mock()
mock_vllm_config.model_config.hf_config = Mock()
mock_vllm_config.model_config.hf_config.hidden_size = 128
mock_vllm_config.model_config.dtype = torch.float32
mock_vllm_config.model_config.quant_config = None
mock_vllm_config.cache_config = CacheConfig()
kwargs['vllm_config'] = mock_vllm_config
return original_init(self, *args, **kwargs)
DeepseekV2Attention.__init__ = patched_init
yield
DeepseekV2Attention.__init__ = original_init
except ImportError:
yield
def test_torchair_deepseek_v2_silu_and_mul():
torch.set_default_device("cpu")
silu = TorchairDeepseekV2SiluAndMul()
assert silu.weight_scale is None
x = torch.randn(2, 4)
output = silu.forward_oot(x)
assert output.shape == (2, 2)
weight_scale = Mock(return_value=torch.tensor(0.1))
silu = TorchairDeepseekV2SiluAndMul(weight_scale=weight_scale)
quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32)
dynamic_scale = torch.randn(2, 1)
with patch("torch_npu.npu_dequant_swiglu_quant",
return_value=torch.randn(2, 4)):
output = silu.forward_oot((quant_x, dynamic_scale))
assert output.shape == (2, 4)
def test_torchair_deepseek_v2_merged_replicated_linear(mock_distributed):
linear = TorchairDeepseekV2MergedReplicatedLinear(input_size=128,
output_sizes=[64, 64],
bias=False,
quant_config=None)
assert linear.output_sizes == [64, 64]
param = Mock()
param.data = torch.zeros(128, 128)
param.output_dim = 1
param.is_gguf_weight = False
param.is_gguf_weight_type = False
loaded_weight = torch.randn(128, 64)
linear.weight_loader(param, loaded_weight, loaded_shard_id=0)
with pytest.raises(AssertionError):
linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0)
@pytest.mark.parametrize("cls", [
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
TorchairDeepseekV2RowParallelLinear
])
def test_row_parallel_linear(cls, mock_distributed, mock_forward_context):
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
linear.quant_method = Mock()
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
input_ = torch.randn(2, 4, 128)
with patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.split_tensor_along_last_dim",
return_value=[torch.randn(2, 4, 64)]):
linear.input_is_parallel = False
output = linear(input_, is_prefill=True)
assert output[0].shape == (2, 4, 64)
linear.input_is_parallel = True
output = linear(input_, is_prefill=False)
assert output[0].shape == (2, 4, 64)
def test_torchair_deepseek_v2_mlp(mock_distributed, base_config):
mlp = TorchairDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="silu",
quant_config=None)
assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul)
with patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig"
) as mock_quant_config:
mock_quant_config.name = "w8a8dynamic"
with pytest.raises(NotImplementedError):
TorchairDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="silu",
quant_config=mock_quant_config,
force_replicate=False)
with pytest.raises(ValueError):
TorchairDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="relu",
quant_config=None)
def test_torchair_deepseek_v2_moe(mock_distributed, base_config,
mock_forward_context):
base_config.n_shared_experts = 1
moe = TorchairDeepseekV2MoE(config=base_config,
quant_config=None,
prefix="mlp")
assert moe.top_k == 2
x = torch.randn(2, 4, 128)
attn_metadata = Mock(num_prefills=1)
with patch(
"vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__",
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
output = moe(x, attn_metadata)
assert output.shape == (2, 4, 128)
@patch("torch_npu.npu_rms_norm")
def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
attn = TorchairDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
num_heads=8,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
q_lora_rank=16,
kv_lora_rank=16,
cache_config=CacheConfig(),
quant_config=None,
prefix="layers.0.self_attn")
assert attn.debug_layer_idx == 0
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(attn.mla_attn,
"__call__",
return_value=torch.randn(2, 4, 128)):
with pytest.raises(AssertionError):
attn(positions, x)
attn = TorchairDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
num_heads=8,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
q_lora_rank=None,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
@patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm")
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
mock_rms_norm, mock_add_norm,
mock_distributed, base_config,
vllm_config, mock_forward_context,
patch_attention_init):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
torch.randn(2, 128))
base_config.n_routed_experts = 4
layer = TorchairDeepseekV2DecoderLayer(
config=base_config,
prefix="layers.0",
model_config=vllm_config.model_config,
cache_config=CacheConfig(),
quant_config=None)
assert isinstance(layer.mlp, TorchairDeepseekV2MoE)
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \
patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))):
hidden_states, residual = layer(positions, x, None)
assert hidden_states.shape == (2, 4, 128)
base_config.n_routed_experts = None
layer = TorchairDeepseekV2DecoderLayer(
config=base_config,
prefix="layers.0",
model_config=vllm_config.model_config,
quant_config=None)
assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config,
patch_attention_init):
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
input_ids = torch.randint(0, 10000, (2, 4))
positions = torch.arange(4).repeat(2, 1)
with patch.object(model.model,
"forward",
return_value=torch.randn(2, 4, 128)):
output = model(input_ids, positions)
assert output.shape == (2, 4, 128)
weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
with patch(
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
):
loaded = model.load_weights(weights)
assert loaded is not None

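The first assertions in test_torchair_deepseek_v2_silu_and_mul (input (2, 4) -> output (2, 2)) encode the usual gated-SiLU contract: the last dimension packs the gate and up projections side by side, and the activation halves it. A reference sketch of that contract (not the fused NPU kernel):

import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x packs [gate | up] along the last dim; the output is half as wide.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up


assert silu_and_mul_reference(torch.randn(2, 4)).shape == (2, 2)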
View File

@@ -1,423 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, TypedDict
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
import vllm_ascend
from vllm_ascend.ascend_forward_context import get_fused_moe_state
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import (
TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import adapt_patch # noqa E402
from vllm_ascend.utils import AscendDeviceType
adapt_patch(True)
def mock_ep_and_mc2_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.rank = 0
mock_group.world_size = 4
mock_group.device_group = "mock_group_ep"
mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
return mock_group
def mock_dp_and_tp_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.world_size = 2
mock_group.device_group = "mock_group"
mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
return mock_group
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
# init dist env patch
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
with patch('torch.npu.is_available', return_value=True), \
patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.tensor_model_parallel_all_reduce',
return_value=torch.randn(5, 32)), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
enable_shared_expert_dp=False,
expert_map_path=None,
init_redundancy_expert=2,
)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context',
return_value=MagicMock(
max_tokens_across_dp=10,
dp_metadata=dp_metadata,
)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)):
yield
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
# init moe env patch
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
None
)), \
patch('torch_npu.npu_moe_init_routing', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_compute_expert_tokens", return_value=(
torch.randn(8, 2)
)), \
patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_distribute_combine", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_grouped_matmul", return_value=(
[torch.randn(16, 2)]
)), \
patch("torch_npu.npu_swiglu", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_finalize_routing", return_value=(
torch.randn(16, 2)
)):
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
torch.randn(16, 2))), \
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
torch.randn(16, 2))):
yield
else:
yield
@pytest.fixture
def default_moe_config():
"""default moe config"""
return {
'num_experts': 8,
'top_k': 2,
'hidden_size': 512,
'intermediate_size': 1024
}
@pytest.fixture
def moe_method(mock_dist_env):
moe = MagicMock()
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
moe.moe_parallel_config.use_ep = False
moe.moe_parallel_config.dp_size = 1
return TorchairAscendUnquantizedFusedMoEMethod(moe)
class Device(TypedDict):
device_id: int
device_expert: List[int]
class Layer(TypedDict):
layer_id: int
device_count: int
device_list: List[Device]
class MockData(TypedDict):
moe_layer_count: int
layer_list: List[Layer]
class MockQuantMethod(nn.Module):
def __init__(self, shared_experts, num_tokens):
super().__init__()
if shared_experts:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
torch.randn(num_tokens, 10)))
else:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
class MockFusedMoEMethod(FusedMoEMethodBase):
moe = MagicMock()
def __init__(self):
super().__init__(self.moe)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
pass
def apply(self, hidden_states: torch.Tensor,
expert_weights: torch.Tensor) -> torch.Tensor:
pass
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
pass
class TestTorchairAscendFusedMoe:
@pytest.fixture
def test_init_no_quant(self, mock_dist_env, default_moe_config):
layer = TorchairAscendFusedMoE(**default_moe_config)
layer.w13_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['intermediate_size'] * 2,
default_moe_config['hidden_size']))
layer.w2_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['hidden_size'],
default_moe_config['intermediate_size']))
assert layer.num_experts == default_moe_config['num_experts']
assert layer.top_k == default_moe_config['top_k']
assert hasattr(layer, 'w13_weight')
assert hasattr(layer, 'w2_weight')
# check group_topk
with pytest.raises(AssertionError):
error_config = default_moe_config.copy()
error_config['use_grouped_topk'] = True
layer = TorchairAscendFusedMoE(**error_config)
# check scoring_func
with pytest.raises(ValueError):
error_config = default_moe_config.copy()
error_config['scoring_func'] = "random"
layer = TorchairAscendFusedMoE(**error_config)
@pytest.fixture
def test_init_with_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
mock_quant_config.is_layer_skipped_ascend.return_value = False
with patch("vllm_ascend.quantization.quant_config.get_quant_method"):
moe = TorchairAscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert isinstance(moe.quant_method, AscendFusedMoEMethod)
@pytest.fixture
def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
mock_quant_config.is_layer_skipped_ascend.return_value = True
moe = TorchairAscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert isinstance(moe.quant_method,
TorchairAscendUnquantizedFusedMoEMethod)
@pytest.fixture
@pytest.mark.parametrize(
"others_param",
[[None,
MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
[2, None, False, 5, None], [None, None, True, 5, None],
[None, None, False, 1, None], [None, None, True, 5, 1],
[None, None, False, 5, 1]])
def test_forward(self, mock_dist_env, default_moe_config, others_param):
"""
1 test has shared_experts
2 test has top_k
3 test is_prefill is True
4 test single num_tokens (decode)
5 test ep_size is 1 and is_prefill is True
6 test ep_size is 1 and is_prefill is False
"""
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
inputs = torch.randn(num_tokens, 32)
router_logits = torch.randn(num_tokens, 8)
moe = TorchairAscendFusedMoE(**default_moe_config)
if ep_size == 1:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
dtype=torch.bool),
padded_num_tokens=num_tokens)
with patch(
"vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
router_logits,
is_prefill=is_prefill,
top_k=top_k,
shared_experts=shared_experts)
moe.quant_method.apply.assert_called_once()
if shared_experts:
assert output[0].shape == (num_tokens, 32)
assert output[1].shape == (num_tokens, 10)
else:
assert output.shape == (num_tokens, 32)
@pytest.fixture
def test_forward_ms_fused_moe_comp(self, mock_dist_env,
default_moe_config):
inputs = torch.randn(5, 32)
router_logits = torch.randn(5, 8)
moe = TorchairAscendFusedMoE(**default_moe_config)
moe.quant_method = MockQuantMethod(None, 5)
output = moe._forward_ms_fused_moe_comp(inputs,
router_logits,
is_prefill=False,
real_top_k=1)
moe.quant_method.apply.assert_called_once()
assert output.shape == (5, 32)
class TestTorchairAscendUnquantizedFusedMoEMethod:
def test_process_weights_after_loading(self, moe_method, mock_dist_env):
layer = MagicMock()
layer.w13_weight.data = torch.randn(16, 32)
layer.w2_weight.data = torch.randn(16, 32)
moe_method.process_weights_after_loading(layer)
assert isinstance(layer.w13_weight, torch.nn.Parameter)
assert isinstance(layer.w2_weight, torch.nn.Parameter)
assert not layer.w13_weight.requires_grad
assert not layer.w2_weight.requires_grad
@pytest.mark.parametrize("others_param",
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test is_deepseek_v3_r1=True and fused_experts_with_all2all
2 test use_select_experts and fused_experts
3 test use select_gating_topk_softmax_experts and fused_experts
4 test use select_experts and fused_experts_with_all2all_buffer
"""
global_num_experts, ep_size = others_param
is_prefill = False
global_redundant_expert_num = vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config(
).init_redundancy_expert
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1))
with patch(
"vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
layer.w13_weight = torch.randn(8, 16, 1)
layer.w2_weight = torch.randn(16, 8, 1)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=global_num_experts,
is_prefill=is_prefill)
if ep_size == 1:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape
@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test use_select_experts and fused_experts_with_mc2
2 test use_select_experts and fused_experts_with_all2all_buffer
3 test use_select_experts and fused_experts_with_all2all
4 test use_select_experts and fused_experts
"""
ep_size = others_param
is_prefill = False
forward_context = MagicMock(
fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_device_type", return_value=AscendDeviceType._910_93):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
if ep_size == 1:
x = x.view(-1, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
layer.w13_weight = torch.randn(8, 16, 1)
layer.w2_weight = torch.randn(16, 8, 1)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=128,
expert_map=expert_map,
is_prefill=is_prefill)
if ep_size == 16 or ep_size == 1:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape

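Several fixtures above mock determine_expert_map to (3, tensor([0, 1, 2, -1, ...])), which is the expert-parallel layout contract the apply tests rely on: the first element is the local expert count, and the tensor maps each global expert id either to a local slot or to -1 when another rank owns it. A sketch of building such a map with simple contiguous partitioning (the real determine_expert_map may distribute experts differently):

import torch


def make_expert_map(global_num_experts: int, ep_rank: int, ep_size: int):
    per_rank = global_num_experts // ep_size
    start = ep_rank * per_rank
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int64)
    # Only locally owned experts get a non-negative (local) id.
    expert_map[start:start + per_rank] = torch.arange(per_rank)
    return per_rank, expert_map


# rank 0 of 2 over 8 experts -> (4, tensor([0, 1, 2, 3, -1, -1, -1, -1]))
print(make_expert_map(8, 0, 2))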
View File

@@ -1,333 +0,0 @@
import math
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
_set_cos_sin_cache, custom_rotary_embedding_enabled,
native_rope_deepseek_forward, rope_forward_oot, rotate_half,
yarn_find_correction_dim, yarn_get_mscale)
from vllm_ascend.utils import AscendDeviceType
class TestCustomRotaryEmbeddingEnabled(TestBase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
def test_custom_rotary_embedding_enabled(self):
# Test when all conditions are True
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertTrue(result)
# Test when dtype is not float16
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
query = self.query.to(torch.float32)
result = custom_rotary_embedding_enabled(query, True,
self.head_size)
self.assertFalse(result)
# Test when neox_style is False
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, False,
self.head_size)
self.assertFalse(result)
# Test when head_size is not divisible by 32
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size + 1)
self.assertFalse(result)
# Test when custom op is disabled
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=False):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertFalse(result)
class TestRopeForwardOot(TestBase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
def test_rope_forward_oot_torchair_enabled_base(self,
mock_get_ascend_config):
# Setup mock for torchair enabled
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
self.query, self.key)
self.mock_self.forward_native.assert_called_once_with(
self.positions, self.query, self.key, None)
self.assertTrue(torch.equal(result_q, self.query))
self.assertTrue(torch.equal(result_k, self.key))
@patch('torch.ops._C_ascend')
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled,
mock_soc_version,
mock_get_ascend_config, mock__c):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Setup mock for custom kernel path
mock__c.rotary_embedding.return_value = self.query, self.key
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
self.query, self.key)
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
non_contig_query, non_contig_key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, non_contig_query.shape)
self.assertEqual(result_k.shape, non_contig_key.shape)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
rope_forward_oot(self.mock_self, self.positions, self.query,
self.key, offsets)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test neox_style override
result_q, result_k = rope_forward_oot(self.mock_self,
self.positions,
self.query,
self.key,
is_neox_style_override=False)
# Check that neox_style=False was passed to the NPU function
args, kwargs = mock_npu_rotary.call_args
self.assertFalse(args[-1])
class MockRopeModule:
def __init__(self, max_seq_len=2048, is_neox_style=True):
self.max_seq_len = max_seq_len
self.is_neox_style = is_neox_style
self.cos_cached = None
self.sin_cached = None
self.rotary_dim = 1
self.base = 1
self.beta_fast = 32
self.beta_slow = 1
self.max_position_embeddings = 4096
self.mscale = 1.0
self.scaling_factor = 40
def register_buffer(self):
pass
class TestSetSinCosCache(TestBase):
def test_set_cos_sin_cache(self):
module = MockRopeModule()
with patch.object(module, "register_buffer") as mock_register_buffer:
_set_cos_sin_cache(module,
1024,
device="cpu",
dtype=torch.bfloat16)
mock_register_buffer.assert_called()
class TestNativeRopeDeepseekForward(TestBase):
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_key_reshaping(
self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == (1, 128)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_non_neox_style(
self, mock_rope_forward_oot):
module = MockRopeModule(is_neox_style=False)
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
class TestRotateHalf(TestBase):
def test_rotate_half_even_dim(self):
# Test with even dimension
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
result = rotate_half(x)
self.assertTrue(torch.allclose(result, expected))
class TestYarnFindCorrectionDim(TestBase):
def test_basic_case(self):
# Test with standard values
num_rotations = 100
dim = 512
base = 10000
max_position_embeddings = 2048
result = yarn_find_correction_dim(num_rotations, dim, base,
max_position_embeddings)
# Calculate expected value manually
expected = (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 *
torch.log(torch.tensor(base)))
self.assertTrue(torch.allclose(result, expected))
class TestYarnGetMscale(TestBase):
def test_scale_less_than_or_equal_1(self):
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)
def test_scale_greater_than_1(self):
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
(math.e, 1.0, 1.0 + 0.1)]
for scale, mscale, expected in test_cases:
result = yarn_get_mscale(scale, mscale)
self.assertAlmostEqual(
result,
expected,
places=6,
msg=f"Failed for scale={scale}, mscale={mscale}")

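The parametrized cases in TestYarnGetMscale fully determine the scaling rule: the multiplier is 1.0 whenever scale <= 1 and grows logarithmically in scale otherwise. A reference sketch matching the expected values in the test table:

import math


def yarn_get_mscale_reference(scale: float = 1.0, mscale: float = 1.0) -> float:
    if scale <= 1.0:
        return 1.0  # at or below the original context length: no correction
    # e.g. (scale=5.0, mscale=2.0) -> 1.0 + 0.2 * ln(5), as the test expects
    return 0.1 * mscale * math.log(scale) + 1.0


assert yarn_get_mscale_reference(2.0) == 1.0 + 0.1 * math.log(2.0)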
View File

@@ -1,296 +0,0 @@
from unittest.mock import Mock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
TorchairAscendW4A8DynamicFusedMoEMethod,
TorchairAscendW4A8DynamicLinearMethod)
class TestAscendW4A8DynamicLinearMethod(TestBase):
@patch('vllm.distributed.get_tensor_model_parallel_world_size')
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
)
def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
mock_get_tp_world_size.return_value = 1
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(
quant_description={"group_size": 256})
mock_get_current_vllm_config.return_value = mock_vllm_config
self.method = TorchairAscendW4A8DynamicLinearMethod()
self.method.group_size = 8
def test_get_weight(self):
weight = self.method.get_weight(8, 32, torch.bfloat16)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (32, 8))
# new quant version weight
self.method.new_quant_version = True
weight = self.method.get_weight(8, 32, torch.bfloat16)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (16, 8))
self.assertEqual(weight["_packed_dim"], 0)
self.assertEqual(weight["_packed_factor"], 2)
def test_get_pergroup_param(self):
params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
self.assertEqual(params["weight_scale"].dtype, torch.bfloat16)
self.assertEqual(params["weight_scale"].shape, (32, 1))
self.assertEqual(params["weight_offset"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset"].shape, (32, 1))
self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16)
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
# new quant version weight
self.method.new_quant_version = True
params = self.method.get_pergroup_param(8,
32,
torch.bfloat16,
layer_type="column")
self.assertEqual(params["scale_bias"].dtype, torch.float32)
self.assertEqual(params["scale_bias"].shape, (32, 1))
params = self.method.get_pergroup_param(8,
32,
torch.bfloat16,
layer_type="row")
self.assertEqual(params["scale_bias"].dtype, torch.float32)
self.assertEqual(params["scale_bias"].shape, (32, 16))
@patch('torch_npu.npu_convert_weight_to_int4pack')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu,
mock_npu_convert_weight):
mock_npu.side_effect = lambda: torch.zeros(
(1, 32), dtype=torch.float32)
mock_npu_convert_weight.return_value = torch.zeros((32, 4),
dtype=torch.int32)
# old quant version weight
layer = torch.nn.Module()
layer.weight = torch.nn.Parameter(torch.zeros((32, 8),
dtype=torch.int8),
requires_grad=False)
layer.weight_scale = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
layer.weight_offset = torch.nn.Parameter(torch.empty_like(
layer.weight_scale.data),
requires_grad=False)
layer.weight_scale_second = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
layer.weight_offset_second = torch.nn.Parameter(torch.empty_like(
layer.weight_scale_second.data),
requires_grad=False)
self.method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "weight_scale_bias"))
self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32)
# new quant version weight
self.method.new_quant_version = True
new_layer = torch.nn.Module()
new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
dtype=torch.int8),
requires_grad=False)
new_layer.weight_scale = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
new_layer.weight_offset = torch.nn.Parameter(torch.empty_like(
new_layer.weight_scale.data),
requires_grad=False)
new_layer.weight_scale_second = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
new_layer.weight_offset_second = torch.nn.Parameter(
torch.empty_like(new_layer.weight_scale_second.data),
requires_grad=False)
new_layer.scale_bias = torch.nn.Parameter(torch.zeros(
(32, 1), dtype=torch.float32),
requires_grad=False)
self.method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
self.assertTrue(hasattr(new_layer, "weight_scale_second"))
self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32))
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
experts = 8
input_size = 16
output_size = 56
group_size = 2
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
)
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group')
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ascend_config'
)
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group'
)
@patch('torch.distributed.get_rank', return_value=0)
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
mock_get_ep_group, get_current_vllm_config):
mock_ascend_config = Mock()
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
mock_get_ascend_config.return_value = mock_ascend_config
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(quant_description={
"group_size": self.group_size,
"version": "0.0.0"
})
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
get_current_vllm_config.return_value = mock_vllm_config
self.quant_method = TorchairAscendW4A8DynamicFusedMoEMethod()
def test_get_weight(self):
# old quant version w4a8 weight
param_dict = self.quant_method.get_weight(self.experts,
self.input_size,
self.output_size,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
self.assertEqual(param_dict["w13_weight"].shape,
(self.experts, 2 * self.input_size, self.output_size))
# new quant version weight
self.quant_method.new_quant_version = True
param_dict = self.quant_method.get_weight(self.experts,
self.input_size,
self.output_size,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
self.assertEqual(param_dict["w13_weight"].shape,
(self.experts, self.input_size, self.output_size))
def test_get_dynamic_quant_param(self):
# old quant version weight
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w13_weight_scale"].shape,
(self.experts, 2 * self.input_size, 1))
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
torch.float32)
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
(self.experts, 2 * self.input_size,
self.output_size // self.group_size))
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w2_weight_scale"].shape,
(self.experts, self.output_size, 1))
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
torch.float32)
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
(self.experts, self.output_size,
self.input_size // self.group_size))
# new quant version weight
self.quant_method.new_quant_version = True
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
self.assertEqual(
param_dict["w2_scale_bias"].shape,
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
# per-channel weight
self.quant_method.is_per_channel_weight = True
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
pergroup_param = [
"w13_weight_scale_second", "w13_weight_offset_second",
"w2_weight_scale_second", "w2_weight_offset_second"
]
is_contains = any(key in param_dict for key in pergroup_param)
self.assertFalse(is_contains)
def build_layer(self,
is_new_quant_version=True,
is_per_channel_weight=False):
layer = torch.nn.Module()
if is_new_quant_version:
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8),
requires_grad=False)
w13_scale_bias = torch.zeros(
(self.experts, 2 * self.input_size, 1), dtype=torch.float32)
layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w2_scale_bias = torch.zeros((self.experts, self.output_size,
16 // self.quant_method.tp_size),
dtype=torch.float32)
layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
else:
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, 2 * self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size, self.input_size),
dtype=torch.int8),
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size, 1), dtype=torch.float32),
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size, 1), dtype=torch.float32),
requires_grad=False)
if not is_per_channel_weight:
layer.w13_weight_scale_second = torch.nn.Parameter(
torch.ones((self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.float32),
requires_grad=False)
layer.w13_weight_offset_second = torch.nn.Parameter(
torch.empty_like(layer.w13_weight_scale_second.data),
requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(
torch.ones((self.experts, self.output_size,
self.input_size // self.group_size),
dtype=torch.float32),
requires_grad=False)
layer.w2_weight_offset_second = torch.nn.Parameter(
torch.empty_like(layer.w2_weight_scale_second.data),
requires_grad=False)
return layer
@patch('torch_npu.npu_quantize')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
mock_npu.return_value = torch.Tensor()
mock_npu_quantize.return_value = torch.Tensor()
# old quant version weight
layer = self.build_layer(is_new_quant_version=False)
self.quant_method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "w13_scale_bias"))
self.assertEqual(layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
self.assertTrue(hasattr(layer, "w2_scale_bias"))
self.assertEqual(layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
# new quant version weight
self.quant_method.new_quant_version = True
new_layer = self.build_layer(is_new_quant_version=True)
self.quant_method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(new_layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))
self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
# per-channel weight
self.quant_method.is_per_channel_weight = True
per_channel_layer = self.build_layer(is_new_quant_version=True,
is_per_channel_weight=True)
self.quant_method.process_weights_after_loading(per_channel_layer)
self.assertEqual(per_channel_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))


View File

@@ -1,129 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
from vllm_ascend.utils import AscendDeviceType
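# Covers the all2all and mc2 fused-experts paths of the torchair w8a8
# dynamic quantization method, with all NPU kernels mocked out.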
class TestAscendW8A8FusedMoEMethod(TestBase):
def setUp(self):
self.hidden_size = 128
self.num_tokens = 128
self.placeholder = torch.randn(self.num_tokens,
self.hidden_size,
dtype=torch.bfloat16)
@patch("torch.distributed.all_to_all_single")
@patch("torch_npu.npu_moe_re_routing")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_moe_finalize_routing")
@patch("torch_npu.npu_moe_init_routing_quant")
def test_torchair_fused_experts_with_all2all(
self, mock_npu_moe_init_routing_quant, mock_moe_finalize_routing,
mock_dynamic_quant, mock_swiglu, mock_grouped_matmul,
mock_moe_re_routing, mock_all_to_all_single):
expert_map = MagicMock()
ep_group = MagicMock()
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
mock_npu_moe_init_routing_quant.return_value = (
placeholder_int8, placeholder_ones, placeholder_ones,
torch.bincount(placeholder_ones, minlength=len(expert_map)),
torch.randn(self.num_tokens))
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
torch.randint(0,
100,
(self.num_tokens, ),
dtype=torch.int32),
self.placeholder)
mock_grouped_matmul.return_value = self.placeholder
mock_swiglu.return_value = self.placeholder
mock_dynamic_quant.return_value = (
placeholder_int8,
torch.randn(self.num_tokens),
)
mock_moe_finalize_routing.return_value = self.placeholder
result = torchair_fused_experts_with_all2all(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=8,
expert_map=expert_map,
ep_group=ep_group,
log2phy=None,
global_redundant_expert_num=256,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))
@patch.dict('os.environ', {
'HCCL_INTRA_ROCE_ENABLE': '0',
'HCCL_INTRA_PCIE_ENABLE': '1'
})
@patch(
"vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_ascend_device_type"
)
@patch(
'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_mc2_group'
)
@patch('torch_npu.npu_moe_distribute_combine_v2')
@patch('torch_npu.npu_moe_distribute_dispatch_v2')
@patch(
'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.torchair_apply_mlp_decode'
)
def test_torchair_fused_experts_with_mc2_a2_optimization(
self, mock_mlp_decode, mock_dispatch, mock_combine, mock_get_group,
mock_ascend_soc_version):
"""Test expert_scales is passed in A2 SOC version with mc2 optimization"""
# Setup mocks
mock_ascend_soc_version.return_value = AscendDeviceType._910B
mock_group = MagicMock()
mock_group.rank_in_group = 0
mock_group.world_size = 4
mock_get_group.return_value = mock_group
mock_combine.return_value = self.placeholder
mock_dispatch.return_value = (torch.randn(32, 1024), torch.randn(1),
torch.randint(0, 32, (32, )),
torch.randint(1, 5, (8, )),
torch.randint(1, 5, (4, )), None,
torch.randn(32))
mock_mlp_decode.return_value = self.placeholder
result = torchair_fused_experts_with_mc2(
hidden_states=self.placeholder,
w1=self.placeholder,
w2=self.placeholder,
w1_scale=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=2,
mc2_mask=self.placeholder)
# Check that expert_scales was passed to dispatch
call_args = mock_dispatch.call_args[1]
self.assertIn('expert_scales', call_args)
self.assertIsInstance(result, torch.Tensor)
self.assertEqual(result.shape, self.placeholder.shape)

View File

@@ -1,95 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.torchair.torchair_attention import \
AscendAttentionTorchairBackendImpl
class TestAscendAttentionTorchairBackendImpl(TestBase):
@patch("torch.zeros")
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator)) # TODO
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2) # TODO
@patch("vllm.config.get_current_vllm_config") # TODO
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") # TODO
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp,
mock_zeros):
mock_tp.world_size = 2 # TODO
ascend_config.torchair_graph_config.enabled = True # TODO
ascend_config.torchair_graph_config.enable_kv_nz = False # TODO
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
num_heads = 32
head_size = 128 # TODO
scale = 0.1 # TODO
num_kv_heads = 4
kv_cache_dtype = "auto"
attn_type = AttentionType.DECODER
mock_zeros.return_value = torch.ones((),
device='cpu',
dtype=torch.int32)
self.impl = AscendAttentionTorchairBackendImpl(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=kv_cache_dtype,
blocksparse_params=None,
logits_soft_cap=None,
attn_type=attn_type,
kv_sharing_target_layer_name=None)
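# Decode-only forward: paged KV update and fused infer attention are mocked,
# so this only checks the output shape plumbing.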
@patch("torch_npu.npu_scatter_nd_update_")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_with_decode_only(self, mock_fused, _):
layer = MagicMock()
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
seq_len = 1
num_tokens = 100
num_blocks = 256
block_size = 4
query = torch.randn(num_tokens, seq_len,
self.impl.num_heads * self.impl.head_size)
key = torch.randn(num_tokens, seq_len,
self.impl.num_kv_heads * self.impl.head_size)
value = torch.randn(num_tokens, seq_len,
self.impl.num_kv_heads * self.impl.head_size)
kv_cache = (torch.randn(num_blocks, block_size,
self.impl.num_heads * self.impl.head_size),
torch.randn(num_blocks, block_size,
self.impl.num_heads * self.impl.head_size))
output = torch.randn(num_tokens, self.impl.num_heads,
self.impl.head_size)
decode = MagicMock() # TODO
decode.seq_lens_list = [2] * num_tokens
decode.block_table = torch.ones(num_tokens, 8, dtype=torch.int32)
decode.attn_mask = None
metadata = MagicMock()
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.slot_mapping = torch.arange(num_tokens, dtype=torch.int32)
metadata.decode = decode
mock_fused.return_value = (torch.ones(num_tokens, self.impl.num_heads,
self.impl.head_size),
torch.ones(1))
result = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)
self.assertEqual(result.shape[0], num_tokens)

View File

@@ -1,887 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
import torch
from torch import nn
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.torchair.torchair_mla import (
AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata,
AscendMLATorchairImpl, AscendMLATorchairMetadata,
AscendMLATorchairMetadataBuilder, AscendMLATorchairPrefillMetadata)
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
class TestAscendMLATorchairBackend(TestBase):
def test_get_name(self):
self.assertEqual(AscendMLATorchairBackend.get_name(),
"ASCEND_MLA_TORCHAIR")
def test_get_builder_cls(self):
self.assertEqual(AscendMLATorchairBackend.get_builder_cls(),
AscendMLATorchairMetadataBuilder)
def test_get_kv_cache_shape(self):
result = AscendMLATorchairBackend.get_kv_cache_shape(2, 4, 8, 128)
self.assertEqual(result, (2, 4, 8, 128))
def test_get_impl_cls(self):
result = AscendMLATorchairBackend.get_impl_cls()
self.assertEqual(result, AscendMLATorchairImpl)
class TestAscendMLATorchairPrefillMetadata(TestBase):
def test_ascend_mla_prefill_metadata_default(self):
attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
query_lens = [1, 2]
seq_lens = [2, 2]
context_lens = torch.tensor([1, 2])
input_positions = torch.tensor([0, 1, 0, 1])
query_start_loc = torch.tensor([0, 1, 3])
block_table = torch.tensor([[0, 1], [2, 3]])
max_query_len = 2
max_seq_lens = 2
metadata = AscendMLATorchairPrefillMetadata(
attn_mask=attn_mask,
query_lens=query_lens,
seq_lens=seq_lens,
context_lens=context_lens,
input_positions=input_positions,
query_start_loc=query_start_loc,
block_table=block_table,
max_query_len=max_query_len,
max_seq_lens=max_seq_lens)
self.assertIs(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.context_lens, context_lens)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertIs(metadata.block_table, block_table)
self.assertEqual(metadata.max_query_len, max_query_len)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertIsNone(metadata.chunked_context)
def test_ascend_mla_prefill_metadata_with_chunked_context(self):
cu_seq_lens = torch.tensor([0, 2, 4])
starts = torch.tensor([0, 2])
seq_tot = [2, 2]
max_seq_lens = [2, 2]
workspace = torch.randn(2, 4)
chunk_seq_lens = torch.tensor([2, 2])
chunked_context = AscendMLATorchairPrefillMetadata.TorchairChunkedContextMetadata(
cu_seq_lens=cu_seq_lens,
starts=starts,
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens)
metadata = AscendMLATorchairPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
query_lens=[1, 2],
seq_lens=[2, 2],
context_lens=torch.tensor([1, 2]),
input_positions=torch.tensor([0, 1, 0, 1]),
query_start_loc=torch.tensor([0, 1, 3]),
block_table=torch.tensor([[0, 1], [2, 3]]),
max_query_len=2,
max_seq_lens=2,
chunked_context=chunked_context)
self.assertIsNotNone(metadata.chunked_context)
self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
self.assertIs(metadata.chunked_context.starts, starts)
self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
chunk_seq_lens)
class TestAscendMLATorchairDecodeMetadata(TestBase):
def test_ascend_mla_decode_metadata_default(self):
input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
seq_lens = torch.tensor([[2], [3]])
max_seq_lens = 4
seq_lens_list = [2, 3]
attn_mask = None
metadata = AscendMLATorchairDecodeMetadata(input_positions,
block_table, seq_lens,
max_seq_lens, seq_lens_list,
attn_mask)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
self.assertIs(metadata.seq_lens, seq_lens)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
self.assertIsNone(metadata.attn_mask)
class TestAscendMLATorchairMetadata(TestBase):
def test_ascend_mla_metadata_default(self):
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
query_start_loc = torch.tensor([1, 2, 3, 4])
seq_lens = [30, 50]
block_tables = torch.randint(0, 100, (100, 4))
num_decodes = 4
num_decode_tokens = 8
num_prefills = 8
num_input_tokens = 2
query_lens = None
head_dim = None
attn_mask = None
attn_state = AscendAttentionState.ChunkedPrefill
decode = None
prefill = None
metadata = AscendMLATorchairMetadata(
num_actual_tokens, slot_mapping, query_start_loc, seq_lens,
block_tables, num_decodes, num_decode_tokens, num_prefills,
num_input_tokens, query_lens, head_dim, attn_mask, attn_state,
decode, prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.block_tables, block_tables)
self.assertEqual(metadata.num_decodes, num_decodes)
self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
self.assertEqual(metadata.num_prefills, num_prefills)
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.head_dim, head_dim)
self.assertEqual(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.attn_state, attn_state)
self.assertEqual(metadata.decode, decode)
self.assertEqual(metadata.prefill, prefill)
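# Metadata-builder tests: construction defaults, batch reordering, and
# graph-runner block-table padding.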
class TestAscendMLATorchairMetadataBuilder(TestBase):
def test_ascend_mla_metadata_builder_default(self):
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
return_value=ascend_config):
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
self.assertEqual(builder.block_size,
mock_vllm_config.cache_config.block_size)
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)
self.assertEqual(builder.torchair_graph_enabled, True)
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_reorder_batch_with_torchair_graph(self, ascend_config):
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
scheduler_output.scheduled_spec_decode_tokens = {
0: [1],
1: [],
2: [1, 1],
3: []
}
input_batch.swap_states = MagicMock()
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertFalse(modified)
input_batch.swap_states.assert_not_called()
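# Without the torchair graph, decodes must be swapped ahead of prefills,
# so exactly one swap is expected here.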
def test_reorder_batch_without_torchair_graph(self):
ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
return_value=ascend_config):
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2}
scheduler_output.scheduled_spec_decode_tokens = {
0: [],
1: [1],
2: [],
3: []
}
input_batch.swap_states = MagicMock()
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertTrue(modified)
input_batch.swap_states.assert_called_once_with(1, 2)
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))
@pytest.mark.skip(reason="Skipping this test temporarily.")
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 4)
self.assertTrue(torch.equal(result, block_tables[:, :4]))
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_get_graph_runner_block_tables_from_numpy(self,
mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))
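# build_torchair_graph_dummy should emit decode-only metadata with
# placeholder sin/cos caches sized to the rope dim.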
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_build_dummy(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendMLATorchairMetadataBuilder(
None,
None,
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
builder.rope_dim = 64
with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=3,
num_actual_tokens=3,
decode_token_per_req=1,
actual_seq_lengths_q=[0, 1, 2],
attn_mask=torch.zeros((1, 1), dtype=torch.bool),
spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
)
metadata = builder.build_torchair_graph_dummy(common_attn_metadata)
sin_golden = torch.ones(3,
1,
1,
64,
dtype=torch.float16,
device=mock_device)
cos_golden = torch.ones(3,
1,
1,
64,
dtype=torch.float16,
device=mock_device)
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
self.assertEqual(metadata.num_input_tokens, 3)
self.assertEqual(metadata.num_actual_tokens, 3)
self.assertEqual(metadata.num_decodes, 1)
self.assertEqual(metadata.num_decode_tokens, 1)
self.assertEqual(metadata.num_prefills, 0)
self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly)
self.assertIsNone(metadata.prefill)
self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
self.assertEqual(metadata.block_tables.shape[0], 3)
self.assertEqual(metadata.block_tables.shape[1], 64)
self.assertEqual(metadata.seq_lens.shape[0], 3)
self.assertEqual(metadata.slot_mapping.shape[0], 3)
self.assertEqual(metadata.query_start_loc.shape[0], 3)
assert torch.equal(sin_golden, metadata.decode.sin)
assert torch.equal(cos_golden, metadata.decode.cos)
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_build_decode(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
model = MagicMock(spec=nn.Module)
model.model = MagicMock(spec=nn.Module)
builder = AscendMLATorchairMetadataBuilder(
None,
None,
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
builder.rope_dim = 64
builder.sin_cache = torch.tensor([10, 10])
builder.cos_cache = torch.tensor([10, 10])
with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
seq_lens_cpu=torch.tensor([1, 1, 1]),
num_reqs=3,
num_actual_tokens=3,
max_query_len=1,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([1, 1]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None)
metadata = builder.build(1, common_attn_metadata, model)
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
self.assertEqual(metadata.num_input_tokens, 0)
self.assertEqual(metadata.num_actual_tokens, 3)
self.assertEqual(metadata.num_decodes, 3)
self.assertEqual(metadata.num_decode_tokens, 3)
self.assertEqual(metadata.num_prefills, 0)
self.assertEqual(metadata.attn_state,
AscendAttentionState.ChunkedPrefill)
self.assertIsNone(metadata.prefill)
self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
self.assertEqual(metadata.block_tables.shape[0], 3)
self.assertEqual(metadata.block_tables.shape[1], 10)
self.assertEqual(metadata.seq_lens.shape[0], 3)
self.assertEqual(metadata.slot_mapping.shape[0], 3)
self.assertEqual(metadata.query_start_loc.shape[0], 4)
class TestAscendMLATorchairImpl(TestBase):
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2)
@patch("vllm.config.get_current_vllm_config")
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
mock_tp.world_size = 2
ascend_config.torchair_graph_config.enabled = True
ascend_config.torchair_graph_config.enable_kv_nz = False
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
num_heads = 256
head_size = 1024
scale = 0.1
num_kv_heads = 8
kv_cache_dtype = "auto"
kv_a_layernorm = MagicMock()
kv_a_layernorm.weight = torch.randn(96)
kv_a_layernorm.variance_epsilon = 1e-6
kwargs = {
"q_lora_rank": 64,
"kv_lora_rank": 32,
"qk_nope_head_dim": 64,
"qk_rope_head_dim": 32,
"qk_head_dim": 96,
"v_head_dim": 128,
"rotary_emb": MagicMock(),
"q_proj": MagicMock(),
"q_b_proj": MagicMock(),
"kv_b_proj": MagicMock(),
"o_proj": MagicMock(),
"kv_a_proj_with_mqa": MagicMock(),
"kv_a_layernorm": kv_a_layernorm,
}
self.impl = AscendMLATorchairImpl(num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=kv_cache_dtype,
blocksparse_params=None,
logits_soft_cap=None,
attn_type=None,
kv_sharing_target_layer_name=None,
**kwargs)
def test_init(self):
self.assertEqual(self.impl.num_heads, 256)
self.assertEqual(self.impl.head_size, 1024)
self.assertEqual(self.impl.scale, 0.1)
self.assertEqual(self.impl.num_kv_heads, 8)
self.assertEqual(self.impl.kv_cache_dtype, "auto")
self.assertEqual(self.impl.q_lora_rank, 64)
self.assertEqual(self.impl.kv_lora_rank, 32)
self.assertEqual(self.impl.qk_nope_head_dim, 64)
self.assertEqual(self.impl.qk_rope_head_dim, 32)
self.assertEqual(self.impl.qk_head_dim, 96)
self.assertEqual(self.impl.v_head_dim, 128)
self.assertIsNotNone(self.impl.rotary_emb)
self.assertIsNotNone(self.impl.q_proj)
self.assertIsNotNone(self.impl.kv_b_proj)
self.assertIsNotNone(self.impl.o_proj)
self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
self.assertIsNotNone(self.impl.kv_a_layernorm)
self.assertEqual(self.impl.num_queries_per_kv, 32)
self.assertEqual(self.impl.tp_size, 2)
self.assertTrue(self.impl.torchair_graph_enabled)
def test_v_up_proj_and_o_proj(self):
batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads,
self.impl.kv_lora_rank)
self.impl.o_proj.return_value = (torch.randn(
batch_size, self.impl.num_heads * self.impl.v_head_dim), )
if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
self.impl.W_UV = torch.randn(self.impl.num_heads,
self.impl.kv_lora_rank,
self.impl.v_head_dim)
result = self.impl._v_up_proj_and_o_proj(x)
self.assertEqual(result.shape[0], batch_size)
self.assertEqual(result.shape[1],
self.impl.num_heads * self.impl.v_head_dim)
def test_q_proj_and_k_up_proj(self):
batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
q_proj_output = torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_head_dim)
self.impl.q_proj.return_value = (q_proj_output, )
if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None:
self.impl.W_UK_T = torch.randn(self.impl.num_heads,
self.impl.qk_nope_head_dim,
self.impl.kv_lora_rank)
result = self.impl._q_proj_and_k_up_proj(x)
ql_nope, q_pe = result
self.assertEqual(ql_nope.shape[0], batch_size)
self.assertEqual(ql_nope.shape[1], self.impl.num_heads)
self.assertEqual(ql_nope.shape[2], self.impl.kv_lora_rank)
self.assertEqual(q_pe.shape[0], batch_size)
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
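# Weight absorption: kv_b_proj is split into the W_UK_T and W_UV matrices
# used by the MLA decode path.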
def test_process_weights_after_loading(self):
layer = MagicMock(spec=LinearBase)
layer.input_size_per_partition = 10
quant_method = MagicMock()
apply = MagicMock()
quant_method.apply = apply
layer.quant_method = quant_method
shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
self.impl.v_head_dim)
shape_1 = self.impl.kv_lora_rank
layer.weight = torch.randn(shape_0, shape_1)
self.impl.kv_b_proj = layer
apply.return_value = layer.weight.T
self.impl.process_weights_after_loading(torch.bfloat16)
self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)
self.assertEqual(self.impl.W_UK_T.shape[1], self.impl.qk_nope_head_dim)
self.assertEqual(self.impl.W_UK_T.shape[2], self.impl.kv_lora_rank)
self.assertEqual(self.impl.W_UV.shape[0], self.impl.num_heads)
self.assertEqual(self.impl.W_UV.shape[1], self.impl.kv_lora_rank)
self.assertEqual(self.impl.W_UV.shape[2], self.impl.v_head_dim)
def test_compute_prefill_context_none(self):
batch_size = 4
kv_cache = torch.randn(10, 1, 1, 192)
query = torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_head_dim)
metadata = MagicMock()
metadata.prefill = None
prefix_out = torch.randn(2, 16, 128)
prefix_lse = torch.randn(2, 16, 8)
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
metadata, prefix_out,
prefix_lse)
self.assertTrue(torch.equal(prefix_out, out))
self.assertTrue(torch.equal(prefix_lse, lse))
@patch("torch_npu.atb.npu_paged_cache_load")
@patch("torch_npu.atb.npu_ring_mla")
def test_compute_prefill_context(self, mock_ring, mock_load):
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
latent_kv_dim = self.impl.kv_lora_rank
num_blocks, block_size = 100, 20
query = torch.randn(S, N, D)
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
kv_cache = [kv_cache_0, kv_cache_1]
prefix_out = torch.randn(S, N, 128)
prefix_lse = torch.randn(S, N)
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
chunk_ctx.starts = [torch.tensor([0])]
prefill_meta = MagicMock()
prefill_meta.chunked_context = chunk_ctx
prefill_meta.query_lens = [8]
prefill_meta.block_table = torch.randint(0, 100, (S, 4))
meta = MagicMock()
meta.prefill = prefill_meta
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
meta, prefix_out,
prefix_lse)
mock_load.assert_called_once()
mock_ring.assert_called_once()
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv(self, mock_kv_cache):
batch_size = 2
hidden = torch.randn(batch_size, 128)
cos = torch.randn(batch_size, 32)
sin = torch.randn(batch_size, 32)
kv_cache = (torch.randn(
4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
torch.randn(
4, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
slots = torch.arange(batch_size, dtype=torch.long)
proj_out = torch.randn(
batch_size, self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
mock_kv_cache.return_value = (torch.randn(batch_size,
self.impl.num_kv_heads, 1,
self.impl.qk_rope_head_dim),
torch.randn(batch_size,
self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank),
None, None)
k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots)
self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden)
mock_kv_cache.assert_called_once()
self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1,
self.impl.qk_rope_head_dim))
self.assertEqual(
k_nope.shape,
(batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank))
self.assertEqual(kv.shape,
(batch_size, self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv_prefill(self, mock_kv):
B, N, S, H = 2, self.impl.num_kv_heads, 1, 128
hidden_states = torch.randn(B, N, S, H)
cos = torch.randn(B, S, 32)
sin = torch.randn(B, S, 32)
kv_cache = (
torch.randn(100, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
torch.randn(100, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
)
slots = torch.arange(B * S, dtype=torch.long)
proj_out = torch.randn(
B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
mock_kv.return_value = (None, None,
torch.randn(B, self.impl.num_kv_heads, S,
self.impl.qk_rope_head_dim),
torch.randn(B, self.impl.num_kv_heads, S,
self.impl.kv_lora_rank))
k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin,
kv_cache, slots)
self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states)
mock_kv.assert_called_once()
self.assertEqual(
k_pe.shape,
(B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim))
self.assertEqual(
k_nope.shape,
(B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank))
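# rope_single should apply interleaved RoPE while preserving the (B, N, D)
# layout of the input.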
@patch("torch_npu.npu_interleave_rope")
def test_rope_single(self, mock_rope):
B, N, D = 2, 16, 1024
x = torch.randn(B, N, D)
cos = torch.randn(B, N, 1, D)
sin = torch.randn(B, N, 1, D)
mock_rope.return_value = x.view(B, N, 1, D)
result = self.impl.rope_single(x, cos, sin)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], D)
mock_rope.assert_called_once()
@patch(
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._v_up_proj_and_o_proj"
)
@patch("torch_npu._npu_paged_attention_mla")
def test_forward_decode_without_graph(self, mock_page_attention_mla,
mock_up_proj):
self.impl.running_in_graph = False
self.impl.running_chunkprefilll_with_torchair = False
num_tokens = 100
num_blocks = 256
block_size = 4
q_nope = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_nope_head_dim)
q_pe = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim)
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
self.impl.num_heads,
self.impl.kv_lora_rank)
metadata = MagicMock()
metadata.decode = MagicMock()
metadata.decode.block_table = MagicMock()
metadata.decode.seq_lens = 10
mock_page_attention_mla.return_value = torch.randn(
num_tokens, self.impl.num_heads, self.impl.kv_lora_rank)
mock_up_proj.return_value = torch.randn(num_tokens,
self.impl.num_heads,
self.impl.v_head_dim)
result = self.impl._forward_decode(q_nope, q_pe, None, None,
kv_c_and_k_pe_cache, metadata)
self.assertEqual(result.shape[0], num_tokens)
self.assertEqual(result.shape[1], self.impl.num_heads)
self.assertEqual(result.shape[2], self.impl.v_head_dim)
mock_up_proj.assert_called_once()
mock_page_attention_mla.assert_called_once()
@patch(
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._forward_prefill"
)
@patch("torch_npu._npu_reshape_and_cache")
def test_forward_without_graph(self, _, mock_forward_prefill):
self.impl.running_in_graph = False
self.impl.torchair_graph_enabled = False
num_tokens = 100
num_blocks = 256
block_size = 4
rotary_emb_return_value = (torch.randn(num_tokens, 16,
self.impl.kv_lora_rank),
torch.randn(0, 1, self.impl.kv_lora_rank))
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
1, num_blocks, 128)
hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
self.impl.kv_lora_rank)
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.kv_lora_rank),
torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.qk_rope_head_dim))
output = torch.randn(num_tokens, self.impl.num_heads,
self.impl.v_head_dim)
metadata = MagicMock()
metadata.num_decodes = 0
metadata.num_prefills = num_tokens
mock_forward_prefill.return_value = torch.randn(
0, self.impl.num_heads * self.impl.v_head_dim)
result = self.impl.forward(None, hidden_states_or_q_c,
hidden_states_or_kv_c_normed, k_pe,
kv_cache, metadata, output, False)
self.assertEqual(result.shape[0], num_tokens)

View File

@@ -1,45 +0,0 @@
from unittest.mock import MagicMock, Mock
import pytest
import torch
from pytest_mock import MockerFixture
from vllm.config import VllmConfig
from tests.ut.base import PytestBase
from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
class TestNPUTorchairModelRunner(PytestBase):
@pytest.fixture
def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
mocker.patch.object(NPUTorchairModelRunner, "__init__",
lambda self, *args, **kwargs: None)
runner = NPUTorchairModelRunner(Mock(), Mock())
runner.device = torch.device("cpu")
runner.vllm_config = MagicMock(spec=VllmConfig)
runner.speculative_config = MagicMock(
method="mtp",
num_speculative_tokens=4,
disable_padded_drafter_batch=False)
runner.ascend_config = MagicMock(enable_shared_expert_dp=False,
torchair_graph_config=MagicMock(
use_cached_graph=True,
graph_batch_sizes=[1, 2, 4]))
runner.decode_token_per_req = 2
runner.is_kv_consumer = True
runner.max_num_reqs = 100
runner.model_config = MagicMock(hf_config=MagicMock(index_topk=2))
runner.attn_backend = MagicMock(get_builder_cls=lambda: Mock())
return runner
def test_init(self, mocker: MockerFixture,
setup_npu_torchair_model_runner):
runner = setup_npu_torchair_model_runner
assert isinstance(runner, NPUTorchairModelRunner)

View File

@@ -1,78 +0,0 @@
from unittest.mock import MagicMock, Mock
import pytest
import torch
from pytest_mock import MockerFixture
from vllm.config import CacheConfig, VllmConfig
from tests.ut.base import PytestBase
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
class TestTorchairMtpProposer(PytestBase):
@pytest.fixture
def setup_torchair_mtp_proposer(self, mocker: MockerFixture):
vllm_config = MagicMock(spec=VllmConfig)
vllm_config.device_config = MagicMock()
vllm_config.device_config.device = torch.device("cpu")
vllm_config.speculative_config = MagicMock()
vllm_config.speculative_config.draft_model_config = MagicMock()
vllm_config.speculative_config.draft_model_config.dtype = torch.float16
vllm_config.speculative_config.method = "mtp"
vllm_config.speculative_config.num_speculative_tokens = 5
vllm_config.load_config = MagicMock()
cache_config = CacheConfig(block_size=16)
vllm_config.cache_config = cache_config
vllm_config.scheduler_config = MagicMock(max_num_batched_tokens=1024,
max_num_seqs=64)
device = torch.device("cpu")
runner = MagicMock()
runner.pcp_size = 1
runner.dcp_size = 1
runner.pcp_rank = 0
runner.max_num_tokens = 1024
runner.max_num_reqs = 10
runner._use_aclgraph.return_value = True
mocker.patch(
"vllm_ascend.torchair.torchair_mtp_proposer.MtpProposer.__init__",
return_value=None)
mock_set_default_dtype = mocker.patch(
'vllm.utils.torch_utils.set_default_torch_dtype')
mock_set_default_dtype.return_value.__enter__.return_value = None
mock_model_loader = MagicMock()
mocker.patch("vllm.model_executor.model_loader.get_model_loader",
return_value=mock_model_loader)
mock_layers = {
"target_attn_layer_1": Mock(),
"draft_attn_layer_2": Mock()
}
mocker.patch("vllm.config.get_layers_from_vllm_config",
return_value=mock_layers)
mock_set_current = mocker.patch("vllm.config.set_current_vllm_config")
mock_set_current.return_value.__enter__.return_value = None
mock_torchair_deepseek_mtp = MagicMock()
mock_torchair_deepseek_mtp.to.return_value = mock_torchair_deepseek_mtp
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP",
return_value=mock_torchair_deepseek_mtp)
mocker.patch(
"vllm.model_executor.model_loader.utils.process_weights_after_loading"
)
proposer = TorchairMtpProposer(vllm_config, device, runner)
proposer.vllm_config = vllm_config
proposer.device = device
proposer.runner = runner
proposer.speculative_config = vllm_config.speculative_config
proposer.draft_model_config = vllm_config.speculative_config.draft_model_config
proposer.method = vllm_config.speculative_config.method
return proposer, mock_model_loader, mock_torchair_deepseek_mtp
def test_init(self, setup_torchair_mtp_proposer):
proposer, _, _ = setup_torchair_mtp_proposer
assert isinstance(proposer, TorchairMtpProposer)

View File

@@ -1,340 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.torchair.torchair_sfa import (
AscendSFATorchairBackend, AscendSFATorchairDecodeMetadata,
AscendSFATorchairImpl, AscendSFATorchairMetadata,
AscendSFATorchairMetadataBuilder, AscendSFATorchairPrefillMetadata)
class TestAscendSFATorchairBackend(TestBase):
def test_get_name(self):
self.assertEqual(AscendSFATorchairBackend.get_name(),
"ASCEND_SFA_TORCHAIR")
def test_get_builder_cls(self):
self.assertEqual(AscendSFATorchairBackend.get_builder_cls(),
AscendSFATorchairMetadataBuilder)
def test_get_kv_cache_shape(self):
result = AscendSFATorchairBackend.get_kv_cache_shape(2, 4, 8, 128)
self.assertEqual(result, (2, 4, 8, 128))
def test_get_impl_cls(self):
result = AscendSFATorchairBackend.get_impl_cls()
self.assertEqual(result, AscendSFATorchairImpl)
class TestAscendSFATorchairPrefillMetadata(TestBase):
def test_ascend_sfa_prefill_metadata_default(self):
attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
query_lens = [1, 2]
seq_lens = [2, 2]
context_lens = torch.tensor([1, 2])
input_positions = torch.tensor([0, 1, 0, 1])
query_start_loc = torch.tensor([0, 1, 3])
block_table = torch.tensor([[0, 1], [2, 3]])
max_query_len = 2
max_seq_lens = 2
metadata = AscendSFATorchairPrefillMetadata(
attn_mask=attn_mask,
query_lens=query_lens,
seq_lens=seq_lens,
context_lens=context_lens,
input_positions=input_positions,
query_start_loc=query_start_loc,
block_table=block_table,
max_query_len=max_query_len,
sin=None,
cos=None,
max_seq_lens=max_seq_lens)
self.assertIs(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.context_lens, context_lens)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertIs(metadata.block_table, block_table)
self.assertEqual(metadata.max_query_len, max_query_len)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertIsNone(metadata.chunked_context)
def test_ascend_sfa_prefill_metadata_with_chunked_context(self):
cu_seq_lens = torch.tensor([0, 2, 4])
starts = torch.tensor([0, 2])
seq_tot = [2, 2]
max_seq_lens = [2, 2]
workspace = torch.randn(2, 4)
chunk_seq_lens = torch.tensor([2, 2])
chunked_context = AscendSFATorchairPrefillMetadata.TorchairChunkedContextMetadata(
cu_seq_lens=cu_seq_lens,
starts=starts,
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens)
metadata = AscendSFATorchairPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
query_lens=[1, 2],
seq_lens=[2, 2],
context_lens=torch.tensor([1, 2]),
input_positions=torch.tensor([0, 1, 0, 1]),
query_start_loc=torch.tensor([0, 1, 3]),
block_table=torch.tensor([[0, 1], [2, 3]]),
max_query_len=2,
max_seq_lens=2,
sin=None,
cos=None,
chunked_context=chunked_context)
self.assertIsNotNone(metadata.chunked_context)
self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
self.assertIs(metadata.chunked_context.starts, starts)
self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
class TestAscendSFATorchairDecodeMetadata(TestBase):
def test_ascend_sfa_decode_metadata_default(self):
input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
seq_lens = torch.tensor([[2], [3]])
max_seq_lens = 4
seq_lens_list = [2, 3]
attn_mask = None
metadata = AscendSFATorchairDecodeMetadata(input_positions,
block_table, seq_lens,
max_seq_lens, seq_lens_list,
None, None, attn_mask)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
self.assertIs(metadata.seq_lens, seq_lens)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
self.assertIsNone(metadata.attn_mask)
class TestAscendSFATorchairMetadata(TestBase):
def test_ascend_sfa_metadata_default(self):
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
query_start_loc = torch.tensor([1, 2, 3, 4])
seq_lens = [30, 50]
block_tables = torch.randint(0, 100, (100, 4))
num_decodes = 4
num_decode_tokens = 8
num_prefills = 8
num_input_tokens = 2
query_lens = None
head_dim = None
attn_mask = None
attn_state = AscendAttentionState.ChunkedPrefill
decode = None
prefill = None
metadata = AscendSFATorchairMetadata(
num_actual_tokens, slot_mapping, query_start_loc, seq_lens,
block_tables, num_decodes, num_decode_tokens, num_prefills,
num_input_tokens, query_lens, head_dim, attn_mask, attn_state,
decode, prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.block_tables, block_tables)
self.assertEqual(metadata.num_decodes, num_decodes)
self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
self.assertEqual(metadata.num_prefills, num_prefills)
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.head_dim, head_dim)
self.assertEqual(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.attn_state, attn_state)
self.assertEqual(metadata.decode, decode)
self.assertEqual(metadata.prefill, prefill)
class TestAscendSFATorchairMetadataBuilder(TestBase):
def test_ascend_sfa_metadata_builder_default(self):
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
with patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config",
return_value=ascend_config):
builder = AscendSFATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
self.assertEqual(builder.block_size,
mock_vllm_config.cache_config.block_size)
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)
self.assertEqual(builder.torchair_graph_enabled, True)
self.assertEqual(builder.max_blocks, (mock_vllm_config.model_config.max_model_len +
mock_vllm_config.cache_config.block_size - 1) \
// mock_vllm_config.cache_config.block_size)
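# With the torchair graph enabled, reorder_batch should be a no-op.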
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
def test_reorder_batch_with_torchair_graph(self, ascend_config):
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
builder = AscendSFATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
scheduler_output.scheduled_spec_decode_tokens = {
0: [1],
1: [],
2: [1, 1],
3: []
}
input_batch.swap_states = MagicMock()
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertFalse(modified)
input_batch.swap_states.assert_not_called()
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendSFATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendSFATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
builder.max_blocks = 4
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 4)
self.assertTrue(torch.equal(result, block_tables[:, :4]))
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
def test_get_graph_runner_block_tables_from_numpy(self,
mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendSFATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))

View File

@@ -1,111 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from tests.ut.base import TestBase
init_cache_hf_modules_path = "vllm.utils.import_utils.init_cached_hf_modules"
class TestNPUTorchairWorker(TestBase):
def setUp(self):
self.cache_config_mock = MagicMock(spec=CacheConfig)
self.cache_config_mock.cache_type = "auto"
self.model_config_mock = MagicMock(spec=ModelConfig)
self.model_config_mock.dtype = torch.float16
self.model_config_mock.trust_remote_code = False
self.hf_config_mock = MagicMock()
self.hf_config_mock.model_type = "test_model"
if hasattr(self.hf_config_mock, 'index_topk'):
delattr(self.hf_config_mock, 'index_topk')
self.model_config_mock.hf_config = self.hf_config_mock
self.parallel_config_mock = MagicMock(spec=ParallelConfig)
self.vllm_config_mock = MagicMock(spec=VllmConfig)
self.vllm_config_mock.cache_config = self.cache_config_mock
self.vllm_config_mock.model_config = self.model_config_mock
self.vllm_config_mock.parallel_config = self.parallel_config_mock
self.vllm_config_mock.additional_config = None
self.vllm_config_mock.load_config = None
self.vllm_config_mock.scheduler_config = None
self.vllm_config_mock.device_config = None
self.vllm_config_mock.compilation_config = None
self.local_rank = 0
self.rank = 0
self.distributed_init_method = "tcp://localhost:12345"
self.is_driver_worker = False
@patch(
"vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment"
)
@patch("vllm_ascend.worker.worker_v1.NPUPlatform")
def test_init_device(self, mock_platform, mock_init_dist_env):
from vllm_ascend.worker.worker_v1 import NPUWorker
mock_platform.mem_get_info.return_value = (1000, 2000)
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
worker = NPUWorker()
worker.local_rank = 1
worker.model_config = MagicMock()
worker.model_config.seed = 42
worker.vllm_config = MagicMock()
worker.parallel_config = MagicMock()
worker.parallel_config.local_world_size = 0
worker.parallel_config.data_parallel_size = 1
result = worker._init_device()
mock_platform.set_device.assert_called_once()
call_args = mock_platform.set_device.call_args[0][0]
self.assertEqual(str(call_args), "npu:1")
mock_platform.empty_cache.assert_called_once()
mock_platform.seed_everything.assert_called_once_with(42)
mock_platform.mem_get_info.assert_called_once()
mock_init_dist_env.assert_called_once()
self.assertEqual(str(result), "npu:1")
self.assertEqual(worker.init_npu_memory, 1000)
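# The torchair worker subclass must run the same device-init sequence as
# the base NPUWorker.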
@patch(
"vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment"
)
@patch("vllm_ascend.worker.worker_v1.NPUPlatform")
def test_init_device_torchair_worker(self, mock_platform,
mock_init_dist_env):
from vllm_ascend.torchair.torchair_worker import NPUTorchairWorker
mock_platform.mem_get_info.return_value = (1000, 2000)
with patch.object(NPUTorchairWorker, "__init__",
lambda x, **kwargs: None):
worker = NPUTorchairWorker()
worker.local_rank = 1
worker.model_config = MagicMock()
worker.model_config.seed = 42
worker.vllm_config = MagicMock()
worker.parallel_config = MagicMock()
worker.parallel_config.local_world_size = 0
worker.parallel_config.data_parallel_size = 1
result = worker._init_device()
mock_platform.set_device.assert_called_once()
call_args = mock_platform.set_device.call_args[0][0]
self.assertEqual(str(call_args), "npu:1")
mock_platform.empty_cache.assert_called_once()
mock_platform.seed_everything.assert_called_once_with(42)
mock_platform.mem_get_info.assert_called_once()
mock_init_dist_env.assert_called_once()
self.assertEqual(str(result), "npu:1")
self.assertEqual(worker.init_npu_memory, 1000)

View File

@@ -1,164 +0,0 @@
import os
from concurrent.futures import ThreadPoolExecutor
from unittest import mock
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair import utils
class TestTorchairUtils(TestBase):
def test_get_torchair_current_work_dir(self):
cache_dir = utils.TORCHAIR_CACHE_DIR
work_dir = utils._get_torchair_current_work_dir()
self.assertEqual(cache_dir, work_dir)
work_dir = utils._get_torchair_current_work_dir("test")
self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
def test_torchair_cache_dir(self):
utils.write_kv_cache_bytes_to_file(0, 100)
self.assertTrue(utils.check_torchair_cache_exist(),
"Create torchair cache dir failed")
self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
"Create kv cache bytes cache dir failed")
kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
self.assertEqual(100, kv_cache_bytes)
utils.delete_torchair_cache_file()
self.assertFalse(utils.check_torchair_cache_exist(),
"Delete torchair cache dir failed")
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
"Delete kv cache bytes cache dir failed")
def test_torchair_cache_dir_multiple_ranks(self):
ranks = [0, 1, 2, 3]
values = [100, 200, 300, 400]
with ThreadPoolExecutor() as executor:
executor.map(utils.write_kv_cache_bytes_to_file, ranks, values)
for rank, expected in zip(ranks, values):
self.assertEqual(expected,
utils.read_kv_cache_bytes_from_file(rank))
utils.delete_torchair_cache_file()
self.assertFalse(utils.check_torchair_cache_exist(),
"Delete torchair cache dir failed")
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
"Delete kv cache bytes cache dir failed")
def test_delete_torchair_cache_file_multiple_times(self):
utils.write_kv_cache_bytes_to_file(0, 100)
utils.delete_torchair_cache_file()
for i in range(5):
try:
utils.delete_torchair_cache_file()
except FileNotFoundError:
self.fail(
f"Unexpected FileNotFoundError on delete call #{i+2}")
@patch('vllm.ModelRegistry')
def test_register_torchair_model(self, mock_model_registry):
mock_registry = MagicMock()
mock_model_registry.return_value = mock_registry
utils.register_torchair_model()
self.assertEqual(mock_model_registry.register_model.call_count, 7)
call_args_list = mock_model_registry.register_model.call_args_list
expected_registrations = [
("DeepSeekMTPModel",
"vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
),
("DeepseekV2ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
),
("DeepseekV3ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
),
("DeepseekV32ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
),
("Qwen2ForCausalLM",
"vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM"),
("Qwen3MoeForCausalLM",
"vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM"
),
("PanguProMoEForCausalLM",
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)
]
for i, (expected_name,
expected_path) in enumerate(expected_registrations):
args, kwargs = call_args_list[i]
self.assertEqual(args[0], expected_name)
self.assertEqual(args[1], expected_path)
@mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format_to_nz(self, mock_npu_cast,
mock_get_format, mock_is_nz):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = 1
mock_npu_cast.return_value = 1
mock_is_nz.return_value = 1
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
self.assertEqual(fused_moe.w13_weight.data, 1)
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format_format_true(self, mock_npu_cast,
mock_get_format):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
mock_npu_cast.return_value = 1
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
mock_npu_cast.assert_not_called()
@mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format_no_nz(self, mock_npu_cast,
mock_get_format, mock_is_nz):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = 1
mock_npu_cast.return_value = 1
mock_is_nz.return_value = 0
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
mock_npu_cast.assert_not_called()

View File

@@ -18,15 +18,6 @@ from uuid import uuid4
from vllm.logger import logger
TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]
def _check_torchair_supported(model_type: str):
for supported_model in TORCHAIR_MODEL_LIST:
if supported_model in model_type.lower():
return True
return False
def check_kv_extra_config(vllm_config):
@@ -66,11 +57,6 @@ class AscendConfig:
def __init__(self, vllm_config):
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
torchair_graph_config = additional_config.get("torchair_graph_config",
{})
self.torchair_graph_config = TorchairGraphConfig(
torchair_graph_config, vllm_config, additional_config)
xlite_graph_config = additional_config.get("xlite_graph_config", {})
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config,
@@ -107,8 +93,8 @@ class AscendConfig:
self.chunked_prefill_for_mla = additional_config.get(
"chunked_prefill_for_mla", False)
self.enable_shared_expert_dp = additional_config.get(
"enable_shared_expert_dp", False
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
"enable_shared_expert_dp",
False) and vllm_config.parallel_config.enable_expert_parallel
if self.enable_shared_expert_dp:
from vllm_ascend.utils import enable_sp
assert enable_sp(vllm_config=vllm_config,
@@ -215,86 +201,6 @@ class AscendCompilationConfig:
# Add more compilation related configs here as needed
class TorchairGraphConfig:
"""
Configuration Object for torchair_graph_config from additional_config
"""
def __init__(self, torchair_graph_config, vllm_config, additional_config):
self.enabled = torchair_graph_config.get("enabled", False)
self.mode = torchair_graph_config.get("mode", '')
self.use_cached_graph = torchair_graph_config.get(
"use_cached_graph", False)
self.use_cached_kv_cache_bytes = torchair_graph_config.get(
"use_cached_kv_cache_bytes", False)
self.graph_batch_sizes = torchair_graph_config.get(
"graph_batch_sizes", [])
self.graph_batch_sizes_init = torchair_graph_config.get(
"graph_batch_sizes_init", False)
self.enable_multistream_mla = torchair_graph_config.get(
"enable_multistream_mla", False)
self.enable_view_optimize = torchair_graph_config.get(
"enable_view_optimize", True)
self.enable_frozen_parameter = torchair_graph_config.get(
"enable_frozen_parameter", True)
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
self.enable_super_kernel = torchair_graph_config.get(
"enable_super_kernel", False)
if not isinstance(self.graph_batch_sizes, list):
raise TypeError("graph_batch_sizes must be list[int]")
if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0:
raise ValueError(
"graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
)
if not self.enabled:
if self.mode:
raise RuntimeError(
"mode is valid only when Torchair graph mode is enabled")
if self.use_cached_graph:
raise RuntimeError(
"use_cached_graph is valid only when Torchair graph mode is enabled"
)
if self.use_cached_kv_cache_bytes:
raise RuntimeError(
"use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled"
)
if self.graph_batch_sizes:
raise RuntimeError(
"graph_batch_sizes is valid only when Torchair graph mode is enabled"
)
if self.graph_batch_sizes_init:
raise RuntimeError(
"graph_batch_sizes_init is valid only when Torchair graph mode is enabled"
)
if self.enable_multistream_mla:
raise RuntimeError(
"enable_multistream_mla is valid only when Torchair graph mode is enabled"
)
if self.enable_kv_nz:
raise RuntimeError(
"enable_kv_nz is valid only when Torchair graph mode is enabled"
)
if self.enable_super_kernel:
raise RuntimeError(
"enable_super_kernel is valid only when Torchair graph mode is enabled"
)
if self.enable_super_kernel:
if vllm_config.parallel_config.tensor_parallel_size != 1:
raise RuntimeError(
"enable_super_kernel is valid only when tensor_parallel_size is 1"
)
if not additional_config.get("multistream_overlap_shared_expert",
False):
raise RuntimeError(
"enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled"
)
if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
raise RuntimeError(
"use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"
)
class XliteGraphConfig:
"""
Configuration Object for xlite_graph_config from additional_config
@@ -382,39 +288,7 @@ def get_ascend_config():
def check_ascend_config(vllm_config, enforce_eager):
ascend_config = get_ascend_config()
# for eager mode
if enforce_eager:
# torchair_graph cannot be enabled with eager mode.
if ascend_config.torchair_graph_config.enabled:
raise RuntimeError(
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
)
# for graph mode
else:
# torchair_graph case
if ascend_config.torchair_graph_config.enabled:
# torchair_graph is supported for deepseek/pangu/qwen model only.
if vllm_config.model_config:
model_type = vllm_config.model_config.hf_config.model_type
if not _check_torchair_supported(model_type):
raise NotImplementedError(
"Torchair graph mode only works with following model types:"
f"{TORCHAIR_MODEL_LIST}.")
if ascend_config.enable_shared_expert_dp:
logger.warning(
"enable_shared_expert_dp is not supported for torchair graph mode currently, "
"it has been disabled automatically.")
# aclgraph case
else:
if ascend_config.ascend_compilation_config.enable_quantization_fusion:
logger.info(
"Quantization fusion enabled! op fusion on quantization are expected. "
)
if vllm_config.model_config:
model_type = vllm_config.model_config.hf_config.model_type
if "qwen" not in model_type:
logger.warning(
"ACL Graph is currently experimental. Please "
"raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
" if you encourage any Error")
if ascend_config.ascend_compilation_config.enable_quantization_fusion:
logger.info(
"Quantization fusion enabled! op fusion on quantization are expected. "
)
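Note: after this hunk, the eager/graph branches of check_ascend_config collapse entirely. A minimal sketch of the resulting function, assuming nothing outside this hunk remains in its body (names taken from the code above; enforce_eager stays in the signature but is no longer branched on):

def check_ascend_config(vllm_config, enforce_eager):
    # The torchair/aclgraph split is gone; only the fusion hint remains.
    ascend_config = get_ascend_config()
    if ascend_config.ascend_compilation_config.enable_quantization_fusion:
        logger.info(
            "Quantization fusion enabled! op fusion on quantization is expected. ")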

View File

@@ -857,7 +857,6 @@ class AscendMLAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
vllm_config = get_current_vllm_config()
self.ring_mla_mask_size = 512
@@ -1248,7 +1247,7 @@ class AscendMLAImpl(MLAAttentionImpl):
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
cache_mode = "PA"
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
@@ -1276,7 +1275,7 @@ class AscendMLAImpl(MLAAttentionImpl):
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
cache_mode = "PA"
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
@@ -1318,18 +1317,11 @@ class AscendMLAImpl(MLAAttentionImpl):
# shape of k_nope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
actual_seq_lengths = None
if self.enable_kv_nz:
k_nope = k_nope.view(-1, self.num_kv_heads,
self.kv_lora_rank // 16, block_size, 16)
k_pe = k_pe.view(-1, self.num_kv_heads,
self.qk_rope_head_dim // 16, block_size, 16)
input_layout = "BSND"
else:
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
input_layout = "BNSD"
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
input_layout = "BNSD"
if attn_metadata.attn_state in [
AscendAttentionState.SpecDecoding,
@@ -1346,14 +1338,9 @@ class AscendMLAImpl(MLAAttentionImpl):
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q
else:
if self.enable_kv_nz:
q_nope = q_nope.view(num_tokens, 1, self.num_heads,
-1).contiguous()
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
else:
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
-1).contiguous()
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
-1).contiguous()
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
sparse_mode = 0
spec_attn_mask = None

View File

@@ -345,7 +345,6 @@ class AscendSFAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
assert self.indexer is not None, "Indexer is required for DSA."
@@ -534,7 +533,7 @@ class AscendSFAImpl(MLAAttentionImpl):
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
cache_mode = "PA"
if self.enable_sfa_cp:
assert slots_cp is not None

View File

@@ -453,7 +453,6 @@ class KVCacheRecvingThread(threading.Thread):
def _cat_kv_cache(self, block_ids: list[list[int]]):
# Get necessary parameters
k_cache = list(self.kv_caches.values())[0][0]
kv_shape = k_cache.shape
dtype = k_cache.dtype
device = k_cache.device
head_dim = self.model_config.hf_config.head_dim
@@ -494,13 +493,6 @@ class KVCacheRecvingThread(threading.Thread):
# Process each layer in the KV cache
for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
if len(
k_cache_layer.shape
) == 3: # kv shape in torchair model is [num_block, block_size, num_kv_head*head_dim]
k_cache_layer = k_cache_layer.view(kv_shape[0], kv_shape[1],
num_kv_head, head_dim)
v_cache_layer = v_cache_layer.view(kv_shape[0], kv_shape[1],
num_kv_head, head_dim)
# Load cache data into buffers
torch_npu.atb.npu_paged_cache_load(
k_cache_layer,

View File

@@ -99,8 +99,6 @@ class MoECommMethod(ABC):
w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
# For TorchAir graph
is_torchair: bool = False,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
@@ -283,8 +281,6 @@ class FusedAlltoAllCommImpl(MoECommMethod):
w2_scale: Optional[torch.Tensor] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
# For TorchAir graph
is_torchair: bool = False,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,

View File

@@ -26,10 +26,7 @@ from vllm.platforms import Platform, PlatformEnum
# todo: please remove it once the CUDA hard-coding in vLLM is resolved
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
init_ascend_config)
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
delete_torchair_cache_file)
from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
from vllm_ascend.utils import refresh_block_size
# isort: off
@@ -204,25 +201,6 @@ class NPUPlatform(Platform):
compilation_config.mode)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# set CUDAGraphMode to NONE when torchair is enabled, no matter what compilation_config.level is.
if ascend_config.torchair_graph_config.enabled:
logger.info(
"Torchair compilation enabled on NPU. Setting CUDAGraphMode to NONE"
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# Note: We delete the torchair cache folder here to prevent runtime issues caused by dimension
# mismatches or configuration inconsistencies when users reuse cached computation graphs. Though
# this will increase graph compilation duration, it significantly enhances robustness and decreases
# graph launching time during inference.
if check_torchair_cache_exist(
) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
logger.warning(
"Torchair cache folder is deleted here to prevent runtime issues caused by dimension "
"mismatches or configuration inconsistencies when users reuse cached computation graphs. "
"In order to decrease torchair graph compilation time, users can enable both use_cached_graph "
"and use_cached_kv_cache_bytes in torchair_graph_config.")
delete_torchair_cache_file()
# set cudaprah sizes before extending `compilation_config.splitting_ops`
vllm_config._set_cudagraph_sizes()
# There are cases where default cudagraph_capture_sizes are not friendly
@@ -303,9 +281,7 @@ class NPUPlatform(Platform):
if parallel_config and parallel_config.worker_cls == "auto":
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
parallel_config.all2all_backend = "flashinfer_all2allv"
if ascend_config.torchair_graph_config.enabled:
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
elif ascend_config.xlite_graph_config.enabled:
if ascend_config.xlite_graph_config.enabled:
logger.info(
"Euler Xlite enabled. See: https://gitee.com/openeuler/GVirt/tree/master/xlite"
)
@@ -390,29 +366,14 @@ class NPUPlatform(Platform):
use_sparse=False,
attn_type: str | None = None,
):
ascend_config = get_ascend_config()
if use_mla and ascend_config.enable_shared_expert_dp:
if use_mla and use_sparse:
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
use_torchair = ascend_config.torchair_graph_config.enabled
# choose attention backend based on use_mla and use_torchair
# choose attention backend based on use_mla
backend_map = {
(True, False, True):
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
(True, False, False):
"vllm_ascend.attention.mla_v1.AscendMLABackend",
(False, False, True):
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
(False, False, False):
(True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
(False, False):
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
(True, True, False):
"vllm_ascend.attention.sfa_v1.AscendSFABackend",
(True, True, True):
"vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
(True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
}
return backend_map[(use_mla, use_sparse, use_torchair)]
return backend_map[(use_mla, use_sparse)]
@classmethod
def get_punica_wrapper(cls) -> str:
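For reference, the backend table above is now keyed only on (use_mla, use_sparse); note that (use_mla=False, use_sparse=True) has no entry in either the old or the new map, so that combination raises KeyError. A minimal sketch of the simplified lookup, with the class paths copied from the hunk:

backend_map = {
    (True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
    (False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
    (True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
}

def resolve_backend(use_mla: bool, use_sparse: bool) -> str:
    # Sparse attention is only wired up for MLA models, so (False, True)
    # is intentionally absent and falls through to a KeyError.
    return backend_map[(use_mla, use_sparse)]

assert resolve_backend(True, False).endswith("AscendMLABackend")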

View File

@@ -111,10 +111,9 @@ class AscendW8A8DynamicFusedMoEMethod:
vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
self.use_aclgraph = (
vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
and not vllm_config.model_config.enforce_eager
and not ascend_config.torchair_graph_config.enabled)
self.use_aclgraph = (vllm_config.compilation_config.mode
== CompilationMode.VLLM_COMPILE
and not vllm_config.model_config.enforce_eager)
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
self.in_dtype = vllm_config.model_config.dtype

View File

@@ -20,21 +20,14 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
def get_spec_decode_method(method,
vllm_config,
device,
runner,
is_torchair_graph=False):
def get_spec_decode_method(method, vllm_config, device, runner):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
elif method in ("eagle", "eagle3"):
return EagleProposer(vllm_config, device, runner)
elif method == "mtp":
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
return MtpProposer(vllm_config, device, runner)
elif method == 'suffix':
return SuffixDecodingProposer(vllm_config, device, runner)
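With TorchairMtpProposer and the is_torchair_graph switch gone, "mtp" resolves unconditionally to MtpProposer. A hedged usage sketch; vllm_config, device and runner are placeholders for whatever the caller already holds:

# illustrative only: vllm_config/device/runner are placeholders
proposer = get_spec_decode_method("mtp", vllm_config, device, runner)
assert type(proposer).__name__ == "MtpProposer"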

View File

@@ -1,357 +0,0 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from collections.abc import Iterable
from typing import Any, List, Optional, Union
import torch
import torch.nn.functional as F
import vllm
from torch import nn
from transformers import Qwen2Config
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Attention # noqa: F401
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM # noqa: F401
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model
from vllm.model_executor.models.utils import (AutoWeightsLoader,
PPMissingLayer, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import set_default_rope_theta
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
def all_gather_and_maybe_unpad(
hidden_states: torch.Tensor,
pad_size: int,
) -> torch.Tensor:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
if pad_size > 0:
return hidden_states[:-pad_size, :]
return hidden_states
def maybe_pad_and_reduce_scatter(
hidden_states: torch.Tensor,
pad_size: int,
) -> torch.Tensor:
if pad_size > 0:
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
return hidden_states
class CustomQwen2Attention(Qwen2Attention):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_parameters: Optional[dict[str, Any]] = None,
max_position: int = 4096 * 32,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_position=max_position,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
rope_parameters=rope_parameters)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.torchair_graph_enabled and attn_metadata is not None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
q, k = self.rotary_emb(positions,
q,
k,
is_prefill=False,
is_qwen_torchair=True)
forward_kwargs = {}
output_shape = q.shape
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
forward_kwargs['output'] = output
attn_output = self.attn.impl.forward(self.attn,
q,
k,
v,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
**forward_kwargs)
output, _ = self.o_proj(attn_output)
return output
else:
if type(self.rotary_emb) is RotaryEmbedding:
q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
else:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class CustomQwen2DecoderLayer(nn.Module):
def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
set_default_rope_theta(config, default_theta=1000000)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = CustomQwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_parameters=config.rope_parameters,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class CustomQwen2Model(Qwen2Model):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
decoder_layer_type=decoder_layer_type)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
kv_cache = kv_caches[i - self.start_layer] \
if kv_caches is not None else None
hidden_states, residual = layer(positions,
hidden_states,
residual,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# add `CustomQwen2Model` to init self.model
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = CustomQwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata=None, # type: ignore
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM

View File

@@ -1,527 +0,0 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.
from typing import Any, List, Optional, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationMode, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (MixtureOfExperts,
SupportsLoRA, SupportsPP)
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
Qwen3MoeDecoderLayer,
Qwen3MoeForCausalLM,
Qwen3MoeMLP, Qwen3MoeModel,
Qwen3MoeSparseMoeBlock)
from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp)
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.experts = TorchairAscendFusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)
self.top_k = config.num_experts_per_tok
self.dp_size = get_dp_group().world_size
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()
self.params_dtype = torch.get_default_dtype()
def forward(
self,
hidden_states,
attn_metadata=None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
):
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
enable_force_load_balance = get_forward_context().in_profile_run
is_prefill = get_forward_context().with_prefill
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
_metadata_for_padding=_metadata_for_padding,
)
return hidden_states
class CustomQwen3MoeAttention(Qwen3MoeAttention):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_parameters: dict[str, Any],
max_position_embeddings: int = 8192,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@staticmethod
def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int,
head_dim: int, q_norm, k_norm):
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
q_by_head = q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
k_by_head = k_norm(k_by_head)
k = k_by_head.view(k.shape)
return q, k, v
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size,
self.head_dim, self.q_norm, self.k_norm)
if (self.torchair_graph_enabled and attn_metadata is not None and
attn_metadata.attn_state == AscendAttentionState.DecodeOnly):
q, k = self.rotary_emb(positions,
q,
k,
is_prefill=False,
is_qwen_torchair=True)
forward_kwargs = {}
output_shape = q.shape
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
forward_kwargs['output'] = output
attn_output = self.attn.impl.forward(self.attn,
q,
k,
v,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
**forward_kwargs)
output, _ = self.o_proj(attn_output)
return output
else:
q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: Optional[VllmConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = CustomQwen3MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, 'attention_bias', False),
head_dim=getattr(config, 'head_dim', None),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
# `mlp_only_layers` in the config.
layer_idx = extract_layer_index(prefix)
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
self.use_aclgraph = (vllm_config is not None
and vllm_config.compilation_config.mode
== CompilationMode.VLLM_COMPILE
and not vllm_config.model_config.enforce_eager)
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
if not self.use_aclgraph:
# FIXME: custom sparse moe block doesn't work with aclgraph.
self.mlp = CustomSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.enable_sequence_parallelism = (
vllm_config.compilation_config.pass_config.enable_sp
if vllm_config is not None else False)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> torch.Tensor:
# To prevent precision issues during the decoder phase when only prefilling enables SP
if not self.enable_sequence_parallelism:
self.self_attn.o_proj.reduce_results = True
else:
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
# Self Attention
if residual is None:
residual = hidden_states
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
residual = _metadata_for_padding.padding_slice(residual)
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if not self.use_aclgraph:
hidden_states = self.mlp(
hidden_states, _metadata_for_padding=_metadata_for_padding)
else:
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: CustomQwen3MoeDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
residual,
kv_caches[i -
self.start_layer] if kv_caches is not None else None,
attn_metadata,
_metadata_for_padding=_metadata_for_padding)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)
return hidden_states
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
SupportsPP.__init__(self)
SupportsLoRA.__init__(self)
MixtureOfExperts.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sp
# Set MoE hyperparameters
self.expert_weights: list[torch.Tensor] = []
self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Qwen3MoeDecoderLayer)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
_metadata_for_padding = init_metadata_for_sp(
input_ids, self.enable_sequence_parallelism)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds, _metadata_for_padding)
return hidden_states

View File

@@ -1,221 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/deepseek_mtp.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
SharedHead)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.sequence import IntermediateTensors
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
TorchairDeepseekV2DecoderLayer
class TorchairDeepSeekShareHead(SharedHead):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
nn.Module.__init__(self)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "head"))
class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = TorchairDeepSeekShareHead(config=config,
quant_config=quant_config,
prefix=maybe_prefix(
prefix,
"shared_head"))
self.mtp_block = TorchairDeepseekV2DecoderLayer(
config, prefix, model_config, cache_config, quant_config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
torch.zeros_like(inputs_embeds),
inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
del inputs_embeds, previous_hidden_states
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
hidden_states, residual = self.mtp_block(
positions=positions,
hidden_states=hidden_states,
residual=None,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
replace_allreduce=replace_allreduce)
hidden_states = residual + hidden_states
return hidden_states
class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights
self.layers = torch.nn.ModuleDict({
str(idx):
TorchairDeepSeekMultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
# Note: torch._dynamo.exc.Unsupported: builtin: str
self.layers_list = [
self.layers[str(idx)]
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
]
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = (spec_step_idx % self.num_mtp_layers)
step_kv_cache = kv_caches[
current_step_idx] if kv_caches is not None else None
return self.layers_list[current_step_idx](
input_ids,
positions,
step_kv_cache,
attn_metadata,
previous_hidden_states,
inputs_embeds,
current_step_idx,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor:
current_step_idx = (spec_step_idx % self.num_mtp_layers)
mtp_layer = self.layers_list[current_step_idx]
logits = self.logits_processor(mtp_layer.shared_head.head,
mtp_layer.shared_head(hidden_states))
return logits
@support_torch_compile
class TorchairDeepSeekMTP(DeepSeekMTP):
# NOTE 1. For quantized DeepSeek models, the MTP layer itself is not quantized on the NPU;
# NOTE 2. The description file generated by the current msmodelslim tool does not
# contain MTP layer info. Please add it manually and set the value to FLOAT.
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config.model_config.hf_config
self.model = TorchairDeepSeekMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, hidden_states, inputs_embeds,
spec_step_idx)
return hidden_states

File diff suppressed because it is too large

View File

@@ -1,28 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
TorchairDeepseekV2ForCausalLM
class TorchairDeepseekV3ForCausalLM(TorchairDeepseekV2ForCausalLM):
pass

File diff suppressed because it is too large

View File

@@ -1,120 +0,0 @@
import torch
from torch.nn import functional as F
from vllm.distributed import (get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm_ascend.platform import NPUPlatform
class MetadataForPadding:
def __init__(self,
padding_flag=False,
lengths_sum_padding=0,
lengths_sum_unpadding=0,
pad_size=0,
not_dummy_and_is_prefill=False):
self.padding_flag = padding_flag
self.not_dummy_and_is_prefill = not_dummy_and_is_prefill
self.lengths_sum_padding = lengths_sum_padding
self.lengths_sum_unpadding = lengths_sum_unpadding
self.pad_size = pad_size
self.tp_size = get_tp_group().world_size
self.tp_rank_in_group = get_tp_group().rank_in_group
assert self.lengths_sum_padding % self.tp_size == 0
self.slice_size = self.lengths_sum_padding // self.tp_size
self.mc2_mask = torch.zeros(
self.lengths_sum_padding,
dtype=torch.bool,
device=NPUPlatform.device_type,
)
self.mc2_mask[:lengths_sum_unpadding] = True
def padding_aligned_reduce_scatter(self,
data: torch.Tensor) -> torch.Tensor:
if self.padding_flag:
pad_size = self.pad_size
padded_data = F.pad(data, (0, 0, 0, pad_size))
else:
padded_data = data
padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
padded_data, 0)
return padded_data_reduce_scatter
def allgather_unpadding_aligned(self,
padded_data: torch.Tensor) -> torch.Tensor:
padded_data_allgather = tensor_model_parallel_all_gather(
padded_data, 0)
if self.padding_flag:
lengths_sum_unpadding = self.lengths_sum_unpadding
unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
else:
unpadding_data = padded_data_allgather
return unpadding_data
def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
padded_data = F.pad(data, (0, 0, 0, self.pad_size))
start = self.tp_rank_in_group * self.slice_size
end = start + self.slice_size
slice_data = padded_data[start:end]
return slice_data
def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
if self.padding_flag:
pad_size = self.pad_size
padded_data = F.pad(data, (0, 0, 0, pad_size))
else:
padded_data = data
        data_chunks = torch.tensor_split(padded_data, self.tp_size, dim=0)
        local_chunk = data_chunks[self.tp_rank_in_group]
        return local_chunk
def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
if not enable_sequence_parallelism:
return MetadataForPadding(padding_flag=False,
not_dummy_and_is_prefill=False)
    is_prefill = 0
    attn_metadata = get_forward_context().attn_metadata
    tp_size = get_tensor_model_parallel_world_size()
    if attn_metadata is not None:
        if hasattr(attn_metadata,
                   'is_only_prefill') and attn_metadata.is_only_prefill:
            is_prefill = 1
        if hasattr(attn_metadata,
                   'num_prefills') and attn_metadata.num_prefills > 0:
            is_prefill = 1
        if is_prefill:
lengths_sum_unpadding = input_ids.shape[0]
lengths_sum_padding = (
(lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
if lengths_sum_unpadding == lengths_sum_padding:
padding_flag = False
else:
padding_flag = True
pad_size = lengths_sum_padding - lengths_sum_unpadding
_metadata_for_padding = MetadataForPadding(
lengths_sum_unpadding=lengths_sum_unpadding,
lengths_sum_padding=lengths_sum_padding,
padding_flag=padding_flag,
pad_size=pad_size,
not_dummy_and_is_prefill=True)
return _metadata_for_padding
return MetadataForPadding(padding_flag=False,
not_dummy_and_is_prefill=False)
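
As a quick illustration of the padding arithmetic performed by init_metadata_for_sp above, here is a self-contained sketch (tp_size and the token count are assumed values for illustration): the token dimension is rounded up to a multiple of the TP world size so that reduce-scatter hands every rank an equal slice.

import torch
import torch.nn.functional as F

tp_size = 4                  # assumed TP world size
num_tokens = 10              # unpadded token count (sum of lengths)
hidden = torch.randn(num_tokens, 8)

# Round up to a multiple of tp_size, as init_metadata_for_sp does.
padded = ((num_tokens + tp_size - 1) // tp_size) * tp_size
pad_size = padded - num_tokens                      # 2 extra rows
padded_hidden = F.pad(hidden, (0, 0, 0, pad_size))  # pad the token dim

slice_size = padded // tp_size                      # equal slice per rank
print(padded, pad_size, slice_size)                 # 12 2 3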

View File

@@ -1,245 +0,0 @@
from dataclasses import dataclass
from typing import Callable, Optional
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
def dispose_tensor(x: torch.Tensor):
x.set_(torch.empty([], device=x.device, dtype=x.dtype))
@dataclass
class LayerMetadata:
"""Metadata for a layer.
"""
layer: Optional[LinearBase] # The layer object.
post_method: Callable[[
torch.nn.Module
], None] # The `process_weights_after_loading` method from the quant method.
weight: torch.Tensor # The weight tensor.
window_idx: int # The index of the window.
@dataclass
class SharedWindowMetadata:
"""Metadata for a shared window.
"""
weight: torch.Tensor # The weight tensor to be shared by layers.
data_layer_idx: int # The index of the layer this window's weight is equal to.
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.
@dataclass
class SeriesMetadata:
"""Metadata for a weight shared series.
"""
group: GroupCoordinator
start_layer: int
end_layer: int
num_layers: int
prefetch_step: int
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
layers: list[LayerMetadata]
shared_windows: list[
SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
window_offset: int # The index of the window for the next coming layer.
def is_source(self, layer_idx) -> bool:
return layer_idx % self.group.world_size == self.group.rank_in_group
def post_process_after_loading(self):
# This method only needs to be called once per series.
if self.shared_windows:
return
for layer_idx in range(self.start_layer, self.end_layer):
layer = self.layers[layer_idx - self.start_layer]
is_source = self.is_source(layer_idx)
            # If the layer uses the dummy weight, make a temporary copy so the post-processing call won't affect other layers that also use it.
if not is_source:
layer.weight.set_(torch.empty_like(self.dummy_weight))
# Broadcast to get the true weight.
dist.broadcast(layer.weight,
src=self.group.ranks[layer_idx %
self.group.world_size],
group=self.group.device_group)
assert layer.layer is not None
# Call `process_weights_after_loading` from the quant method.
layer.post_method(layer.layer)
step = layer_idx - self.start_layer
if step < self.prefetch_step:
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
self.shared_windows.append(
SharedWindowMetadata(
weight=layer.weight.clone().detach(),
data_layer_idx=layer_idx,
work=None,
))
layer.window_idx = step
                # When the layer is not intended to be stored on this device, link to the corresponding window's tensor.
if not is_source:
layer.weight.set_(self.shared_windows[-1].weight)
else:
                # Build one more window for prefetch. The initial weight values are unused; only the shape matters.
if step == self.prefetch_step:
self.shared_windows.append(
SharedWindowMetadata(
weight=torch.empty_like(layer.weight),
data_layer_idx=-1,
work=None,
))
                # When the layer is not intended to be stored on this device, dispose the tensor.
if not is_source:
dispose_tensor(layer.weight)
dispose_tensor(self.dummy_weight)
def reach_layer(self, layer_idx: int):
# The index of the layer to be prefetched.
next_layer_idx = (layer_idx + self.prefetch_step
) % self.num_layers + self.start_layer
next_layer = self.layers[next_layer_idx - self.start_layer]
# The index of the window to store the weight for the coming layer.
next_layer.window_idx = self.window_offset
window = self.shared_windows[next_layer.window_idx]
        # When the layer is not intended to be stored on this device, link to the corresponding window's tensor.
if not self.is_source(next_layer_idx):
next_layer.weight.set_(window.weight)
# Update `window_offset` by rolling one step.
self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
1)
assert window.data_layer_idx != next_layer_idx
window.data_layer_idx = next_layer_idx
# Start asynchronous broadcast work.
window.work = dist.broadcast(
next_layer.weight,
src=self.group.ranks[next_layer_idx % self.group.world_size],
group=self.group.device_group,
async_op=True)
def wait_weight(self, layer_idx: int):
# Find the asynchronous broadcast work and wait for it.
assert self.shared_windows
window = self.shared_windows[self.layers[layer_idx -
self.start_layer].window_idx]
# Make sure the data in the corresponding shared window is for the current layer.
assert window.data_layer_idx == layer_idx
if window.work is not None:
window.work.wait()
window.work = None
@dataclass
class LayerExternalMetadata:
"""External metadata for a layer.
"""
series: SeriesMetadata
layer_idx: int
_series_dict: dict[str, SeriesMetadata] = {}
_layer_external_dict: dict[int, LayerExternalMetadata] = {}
def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
layer_idx: int) -> Callable:
def wrapped_forward(*args, **kwargs):
# Wait for the weight.
series.wait_weight(layer_idx)
return forward(*args, **kwargs)
return wrapped_forward
"""
Register linear layers into a shared storage series.
In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.
After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization.
During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
- total_layers = end_layer - start_layer
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series.
Arguments:
series_name: This name identifies which series this layer belongs to.
group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series.
start_layer: The index of the first layer in the series (inclusive).
end_layer: The index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer).
layer_idx: The index of the current layer.
layer: The linear layer object to register.
prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases.
"""
def register_layer_to_shared_weight_series(
series_name: str,
group: GroupCoordinator,
start_layer: int,
end_layer: int,
layer_idx: int,
layer: LinearBase,
prefetch_step: int = 1,
):
global _series_dict
if series_name not in _series_dict:
num_layers = end_layer - start_layer
assert num_layers > 0
assert prefetch_step >= 0 and prefetch_step <= num_layers - 2
_series_dict[series_name] = SeriesMetadata(
group=group,
start_layer=start_layer,
end_layer=end_layer,
num_layers=num_layers,
prefetch_step=prefetch_step,
dummy_weight=torch.empty_like(layer.weight),
layers=[
LayerMetadata(
layer=None,
post_method=lambda layer: None,
weight=torch.empty([]),
window_idx=-1,
) for _ in range(num_layers)
],
shared_windows=[],
window_offset=prefetch_step,
)
series = _series_dict[series_name]
assert layer.quant_method is not None
series.layers[layer_idx - start_layer] = LayerMetadata(
layer=layer,
post_method=layer.quant_method.process_weights_after_loading,
weight=layer.weight,
window_idx=-1,
)
# Discard the original `process_weights_after_loading` method such that it won't be called by others.
layer.quant_method.process_weights_after_loading = lambda layer: None
    # When the layer is not intended to be stored on this device, dispose the tensor and skip weight loading.
if not series.is_source(layer_idx):
dispose_tensor(layer.weight)
layer.weight.weight_loader = lambda *args, **kwargs: None
layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx)
global _layer_external_dict
_layer_external_dict[id(layer)] = LayerExternalMetadata(
series=series,
layer_idx=layer_idx,
)
def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
ext = _layer_external_dict[id(layer)]
ext.series.post_process_after_loading()
def reach_layer_for_shared_weight_series(layer: LinearBase):
ext = _layer_external_dict[id(layer)]
ext.series.reach_layer(ext.layer_idx)
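
The circular-buffer bookkeeping described in the docstring can be checked in isolation; this sketch mirrors the index arithmetic in reach_layer (the layer range and prefetch_step are assumed values for illustration):

start_layer, end_layer, prefetch_step = 0, 8, 1
num_layers = end_layer - start_layer
window_offset = prefetch_step  # initial value set at registration

for layer_idx in range(start_layer, end_layer):
    # reach_layer(): which layer gets prefetched, and into which of the
    # (prefetch_step + 1) rotating shared windows it lands.
    next_layer_idx = (layer_idx + prefetch_step) % num_layers + start_layer
    window_idx = window_offset
    window_offset = (window_offset + 1) % (prefetch_step + 1)
    print(f"at layer {layer_idx}: prefetch layer {next_layer_idx} "
          f"into window {window_idx}")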

View File

@@ -1,37 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
"""AscendSiluAndMul forward in torchair mode.
The key difference from the original implementation is the removal of operators
from the torch.ops.vllm class, as these operators only function in non-torchair
modes. Adding them back would cause the graph compilation to fail.
"""
import torch_npu
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
if get_ascend_device_type() == AscendDeviceType._310P:
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
else:
out = torch_npu.npu_swiglu(x)
return out
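
For reference, torch_npu.npu_swiglu fuses the split-SiLU-multiply pattern; a plain-PyTorch sketch of the assumed semantics (split the last dim in half, gate the second half with SiLU of the first) is:

import torch
import torch.nn.functional as F

def swiglu_reference(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for torch_npu.npu_swiglu under the assumption above.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(4, 16)
print(swiglu_reference(x).shape)  # torch.Size([4, 8])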

File diff suppressed because it is too large

View File

@@ -1,78 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Optional, Tuple, Union
import torch
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
_original_re_init = RMSNorm.__init__
def torchair_rmsnorm_init_(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
_original_re_init(self, hidden_size, eps, var_hidden_size, has_weight,
dtype)
vllm_config = get_current_vllm_config()
self.bias = None
    # quantization with anti_method m4 generates a non-zero norm bias
if vllm_config.quant_config is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
def torchair_rmsnorm_forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""AscendRMSNorm forward in torchair mode.
The key difference from the original implementation is the removal of operators
from the torch.ops.vllm class, as these operators only function in non-torchair
modes. Adding them back would cause the graph compilation to fail.
"""
import torch_npu
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
if residual is not None:
if get_ascend_device_type() == AscendDeviceType._310P:
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
if self.bias is not None:
x.add_(self.bias)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
if self.bias is not None:
x.add_(self.bias)
return x
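
The fused npu_add_rms_norm call above is equivalent, up to kernel numerics, to adding the residual first and normalizing the sum. A small reference sketch of that semantics:

import torch

def add_rms_norm_reference(x, residual, weight, eps=1e-6):
    # New residual is x + residual; the output is RMSNorm of that sum.
    residual = x + residual
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + eps)
    return normed.to(x.dtype) * weight, residual

x, res, w = torch.randn(2, 8), torch.randn(2, 8), torch.ones(8)
out, new_res = add_rms_norm_reference(x, res, w)
print(out.shape, new_res.shape)  # torch.Size([2, 8]) torch.Size([2, 8])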

View File

@@ -1,367 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torch_npu
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
get_ascend_device_type)
def custom_rotary_embedding_enabled(query, neox_style, head_size):
    return (query.dtype == torch.float16 and neox_style
            and head_size % 32 == 0 and enable_custom_op())
def rope_forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
is_qwen_torchair: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if get_ascend_config(
).torchair_graph_config.enabled and not is_qwen_torchair:
return self.forward_native(
positions,
query,
key,
offsets,
)
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
neox_style = self.is_neox_style
if is_neox_style_override is not None:
neox_style = is_neox_style_override
# adopt custom kernel path for rotary_embedding
if custom_rotary_embedding_enabled(
query, neox_style, self.head_size) and get_ascend_device_type(
) != AscendDeviceType._310P:
query, key = torch.ops._C_ascend.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
def native_rope_deepseek_forward(self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None):
if len(key.shape) == 2:
key = key[:, None, :]
    # Note: we implement the non-neox-style method by shuffling the last dim and then
    # applying the neox-style calculation, which is also more compute-friendly on Ascend hardware.
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
neox_style = True
if self.is_neox_style is False:
b, h_q, d = query.shape
query = query.view(b, h_q, d // 2, 2).transpose(3,
2).reshape(b, h_q, d)
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
neox_style)
return q_pe, k_pe
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(num_rotations,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error.
return (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
# Find dim range bounds based on rotations
def yarn_find_correction_range(low_rot,
high_rot,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error.
low = torch.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = torch.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
# Note: use torch instead of max/min to solve MTP compilation error.
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
def yarn_linear_ramp_mask(min_value, max_value, dim):
# Note: The if conditional branch is not used here
# to solve MTP compilation error.
max_value += (min_value == max_value).float() * 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) -
min_value) / (max_value - min_value)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids]
sin = sin[position_ids]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
if len(q.shape) == 3:
q = q[:, :, None, :]
if len(k.shape) == 2:
k = k[:, None, None, :]
elif len(k.shape) == 3:
k = k[:, :, None, :]
b, h_q, s, d = q.shape
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
b, h_k, s, d = k.shape
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.view(b, h_q, d)
k_embed = k_embed.view(b, h_k, d)
return q_embed, k_embed
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
dim = self.rotary_dim
freq_extra = 1.0 / (self.base**(
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
freq_inter = 1.0 / (self.scaling_factor * self.base**(
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
low, high = yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
dim,
self.base,
self.max_position_embeddings,
)
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
device=device, dtype=torch.float32)
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
cos_cached = cos_cached.to(dtype)
sin_cached = sin_cached.to(dtype)
cache = torch.cat([freqs.cos() * self.mscale,
freqs.sin() * self.mscale],
dim=-1).to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
def __set_cos_sin_cache(self, seq_len, device, dtype):
inv_freq = 1.0 / (self.base**(torch.arange(
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
(1 / self.rotary_dim)))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
self.embed = F.embedding
_original_re_init = RotaryEmbedding.__init__
def qwen_rope_init_func(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
_original_re_init(self, head_size, rotary_dim, max_position_embeddings,
base, is_neox_style, dtype)
if get_ascend_config().torchair_graph_config.enabled:
__set_cos_sin_cache(self,
seq_len=max_position_embeddings,
device="npu",
dtype=dtype)
def rope_forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
max_seq_len: Optional[int] = None,
is_prefill: Optional[bool] = True,
is_qwen_torchair: Optional[bool] = False,
):
if get_ascend_config().torchair_graph_config.enabled \
and is_qwen_torchair and not is_prefill:
if max_seq_len is not None and torch.gt(max_seq_len,
self.max_position_embeddings):
__set_cos_sin_cache(self,
seq_len=max_seq_len,
device=query.device,
dtype=torch.float32)
# bsnd/bnsd
if positions is not None:
cos = self.embed(positions, self.cos)
sin = self.embed(positions, self.sin)
self.cos_embed = cos
self.sin_embed = sin
else:
cos = self.cos_embed
sin = self.sin_embed
query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
cos = cos.unsqueeze(-2).unsqueeze(-2)
sin = sin.unsqueeze(-2).unsqueeze(-2)
query = query.unsqueeze(1)
key = key.unsqueeze(1)
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
query, key, cos, sin)
return q_embed.flatten(-2), k_embed.flatten(-2)
else:
return rope_forward_oot(self, positions, query, key, offsets,
is_neox_style_override,
is_qwen_torchair) # type: ignore
def deepseek_rope_init_func(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super(DeepseekScalingRotaryEmbedding,
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
# NOTE: For ascend friendly computing, reorder sin and cos cache
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
_set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu")
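
To make the YaRN pieces above concrete: with scaling_factor = 40 and mscale = 1.0, yarn_get_mscale returns 0.1 * 1.0 * ln(40) + 1 ≈ 1.369, and yarn_find_correction_range brackets which rotary dims get interpolated versus extrapolated. A scalar-form numeric check (parameter values are assumptions for illustration):

import math

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

print(round(yarn_get_mscale(40.0, 1.0), 3))  # 1.369

# Correction dims for beta_fast=32 / beta_slow=1 rotations, dim=64,
# base=10000, max_position_embeddings=2048 (same formula as above):
dim, base, mpe = 64, 10000.0, 2048
low = dim * math.log(mpe / (32 * 2 * math.pi)) / (2 * math.log(base))
high = dim * math.log(mpe / (1 * 2 * math.pi)) / (2 * math.log(base))
print(round(low, 2), round(high, 2))  # about 8.06 and 20.11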

View File

@@ -1,38 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.distributed import tensor_model_parallel_all_reduce
def vocab_embedding_forward(self, input_):
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = self._get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self, masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
    # Reduce across all the model parallel NPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
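
A toy CPU sketch of the mask-then-all-reduce pattern above, simulating two vocab shards (the all-reduce is replaced by a plain sum, which is what it computes; sizes are assumed for illustration):

import torch

vocab, hidden, tp = 8, 4, 2
full_weight = torch.randn(vocab, hidden)
shards = full_weight.chunk(tp, dim=0)       # each rank holds vocab/tp rows
input_ids = torch.tensor([0, 3, 5, 7])

partials = []
for rank, shard in enumerate(shards):
    start, end = rank * vocab // tp, (rank + 1) * vocab // tp
    mask = (input_ids < start) | (input_ids >= end)  # out-of-shard tokens
    local_ids = (input_ids - start).clamp(0, end - start - 1)
    out = shard[local_ids]
    out[mask] = 0                           # zero the masked rows
    partials.append(out)

output = sum(partials)                      # stands in for the all-reduce
assert torch.allclose(output, full_weight[input_ids])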

View File

@@ -1,501 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
import numpy as np
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
class TorchairAscendW4A8DynamicLinearMethod:
"""Linear method for Ascend W4A8_DYNAMIC
"""
def __init__(self):
self.transpose_weight = True
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256)
quant_version = vllm_config.quant_config.quant_description.get(
"version", "0")
self.new_quant_version = quant_version == "1.0.0"
from vllm.distributed import get_tensor_model_parallel_world_size
self.tp_size = get_tensor_model_parallel_world_size()
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
if self.new_quant_version:
pack_factor = 2
actual_output_size = output_size // pack_factor
params_dict["weight"] = torch.empty(actual_output_size,
input_size,
dtype=torch.int8)
params_dict["_packed_dim"] = 0
params_dict["_packed_factor"] = pack_factor
else:
params_dict["weight"] = torch.empty(output_size,
input_size,
dtype=torch.int8)
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def get_perchannel_param(output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_scale_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
params_dict["weight_offset_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
if self.new_quant_version:
scale_bias_dim = 16 if layer_type == "row" else 1
params_dict["scale_bias"] = torch.empty(output_size,
scale_bias_dim,
dtype=torch.float32)
return params_dict
@staticmethod
def process_scale_second(weight: torch.Tensor,
scale: torch.Tensor,
per_group_scale: torch.Tensor,
is_new_quant: bool = False):
k, n = weight.shape
group_num, n_scale = per_group_scale.shape
if is_new_quant:
n = n * 2
bias = None
if not is_new_quant:
weight_high = weight.to(torch.float32).reshape(
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
weight_high = weight_high.reshape(k, n)
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
return antiquant_scale.npu(), bias
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = None,
) -> torch.Tensor:
return torch_npu.npu_weight_quant_batchmatmul(
x,
layer.weight,
antiquant_scale=layer.weight_scale_second.to(x.dtype),
antiquant_group_size=self.group_size,
)
def process_weights_after_loading(self, layer: torch.nn.Module):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
layer.weight.data,
layer.weight_scale.data,
layer.weight_scale_second.data.transpose(0, 1).contiguous(),
is_new_quant=self.new_quant_version,
)
if self.new_quant_version:
if hasattr(layer, "scale_bias"):
if layer.scale_bias.data.shape[1] == 1:
layer.scale_bias.data = layer.scale_bias.data.flatten()
else:
layer.scale_bias.data = layer.scale_bias.data.contiguous()
else:
if scale_bias is not None:
param = torch.nn.Parameter(scale_bias, requires_grad=False)
layer.register_parameter("weight_scale_bias", param)
if self.new_quant_version:
assert layer.weight.data.shape[-1] % 4 == 0, \
f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
layer.weight.data = layer.weight.data.view(
torch.int32).contiguous()
else:
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
layer.weight.data.to(torch.int32))
class TorchairAscendW4A8DynamicFusedMoEMethod:
"""FusedMoe method for Ascend W4A8_DYNAMIC.
"""
def __init__(self):
self.transpose_weight = True
self.ep_group = get_ep_group()
ascend_config = get_ascend_config()
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256)
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
self.is_per_channel_weight = self.group_size == 0
quant_version = vllm_config.quant_config.quant_description.get(
"version", "0")
        # NOTE: in the new quantized weights, two int4 values are packed into one int8
self.new_quant_version = quant_version == "1.0.0"
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
if self.new_quant_version and self.tp_size > 16:
raise ValueError(
"The current weight does not support moe part tp>16.")
try:
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
except AttributeError:
self.moe_all_to_all_group_name = ""
def get_weight(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
if self.new_quant_version:
w13_output_size = intermediate_size_per_partition
w2_output_size = hidden_sizes // 2
else:
w13_output_size = 2 * intermediate_size_per_partition
w2_output_size = hidden_sizes
param_dict["w13_weight"] = torch.empty(num_experts,
w13_output_size,
hidden_sizes,
dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(num_experts,
w2_output_size,
intermediate_size_per_partition,
dtype=torch.int8)
return param_dict
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
param_dict["w13_weight_offset"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=torch.float32)
param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=torch.float32)
if not self.is_per_channel_weight:
param_dict["w13_weight_scale_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.float32)
param_dict["w13_weight_offset_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.float32)
param_dict["w2_weight_scale_second"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.float32)
param_dict["w2_weight_offset_second"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.float32)
if self.new_quant_version:
param_dict["w13_scale_bias"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
param_dict["w2_scale_bias"] = torch.empty(num_experts,
hidden_sizes,
16 // self.tp_size,
dtype=torch.float32)
return param_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = torchair_select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(quantized_x_for_share, router_logits)
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
        # This is a naive implementation of expert load balancing, meant to
        # avoid accumulating too many tokens on a single rank.
        # Currently it is only activated during profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
topk_weights = topk_weights.to(x.dtype)
if fused_moe_state == FusedMoEState.MC2:
return torchair_fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
quantized_x_for_share=shared_gate_up,
dynamic_scale_for_share=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None),
dynamic_eplb=self.dynamic_eplb)
else:
# The current implementation of deepseek moe splits hidden_states
# according to tp_size before they are feed into layers module.
# Therefore, all2all is needed no matter how dp/tp is set so as to
# dispatch/combine tokens.
return torchair_fused_experts_with_all2all(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
ep_group=self.ep_group,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
scale = scale.transpose(1, 2).contiguous()
if self.is_per_channel_weight:
scale_np = scale.cpu().numpy()
scale_np.dtype = np.uint32
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
np.int64)).npu()
return scale_uint64_tensor, None
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
group_num, k, n = weight.shape
        # in the new version, int4 packing halves the n dim of the weight, so restore it here
if self.new_quant_version:
n = n * 2
per_group_scale = per_group_scale.reshape(group_num, -1, n)
group_num, quantgroup_num, n = per_group_scale.shape
bias = None
if not self.new_quant_version:
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
weight_high = weight_high.reshape([group_num, k, n])
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
torch.float32)
scale_fp32_np = scale_fp32.cpu().numpy()
scale_fp32_np.dtype = np.uint32
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
dtype=np.uint32)
sscale_uint64[..., ::2] = scale_fp32_np
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
dtype=np.int64).copy()
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
group_num, quantgroup_num, n)
sscale_uint64_tensor = sscale_uint64_tensor.npu()
return sscale_uint64_tensor, bias
def update_bias(self, layer, w13_bias, w2_bias):
if self.new_quant_version:
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
else:
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.register_parameter("w13_scale_bias", w13_scale_bias)
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
layer.register_parameter("w2_scale_bias", w2_scale_bias)
def pack_to_int32(self, weight: torch.Tensor):
if self.new_quant_version:
            # pack four int8 values (each already holding two int4) into one int32, because PyTorch uses int32 to represent int4 data
assert weight.shape[
-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
return weight.view(torch.int32).contiguous()
else:
return torch_npu.npu_quantize(weight.to(torch.float32),
torch.tensor([1.]).npu(), None,
torch.quint4x2, -1, False)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
layer, "w13_weight_scale_second") else None
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
layer, "w2_weight_scale_second") else None
layer.w13_weight_scale.data, w13_bias = self.process_scale(
layer.w13_weight, layer.w13_weight_scale.data,
w13_weight_scale_second)
layer.w2_weight_scale.data, w2_bias = self.process_scale(
layer.w2_weight, layer.w2_weight_scale.data,
w2_weight_scale_second)
if hasattr(layer, "w13_weight_scale_second"):
# scale_second is no longer used, release this part of the memory
del layer.w13_weight_scale_second
del layer.w2_weight_scale_second
del layer.w13_weight_offset_second
del layer.w2_weight_offset_second
self.update_bias(layer, w13_bias, w2_bias)
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
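
The pack_to_int32 step above relies on byte reinterpretation: in the new quant format each int8 already holds two int4 values, so viewing four consecutive int8 bytes as one int32 only requires the last dim to be divisible by 4. A minimal CPU check of that reinterpretation (shapes are assumed for illustration):

import torch

w = torch.randint(-128, 127, (3, 8), dtype=torch.int8)
assert w.shape[-1] % 4 == 0, "last dim must be divisible by 4"
packed = w.view(torch.int32)        # same bytes, dtype reinterpreted
print(w.shape, "->", packed.shape)  # torch.Size([3, 8]) -> torch.Size([3, 2])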

File diff suppressed because it is too large

View File

@@ -1,457 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
import torch_npu
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
AttentionType)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionMetadataBuilder,
AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
aligned_16, get_ascend_device_type, nd_to_nz_2d)
class AscendAttentionTorchairBackend(AscendAttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ASCEND_TORCHAIR"
@staticmethod
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
return AscendAttentionTorchairBackendImpl
@staticmethod
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
return AscendAttentionTorchairMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)
@staticmethod
def get_bsh_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)
@dataclass
class AscendDecodeMetadata:
    # Input positions for rotary embeddings, since for MLA the rotary
    # position embeddings are applied inside the attention backend.
input_positions: torch.Tensor
block_table: torch.Tensor
seq_lens: torch.Tensor
max_seq_lens: int
seq_lens_list: list[int]
attn_mask: Optional[torch.Tensor] = None
@dataclass
class AscendTorchairMetadata(AscendMetadata):
decode: Optional[AscendDecodeMetadata] = None
class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
def __init__(
self,
kv_cache_spec,
layer_names,
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.max_num_blocks_per_req = cdiv(
self.model_config.max_model_len,
self.vllm_config.cache_config.block_size)
self.max_blocks = (self.model_config.max_model_len +
self.vllm_config.cache_config.block_size -
1) // self.vllm_config.cache_config.block_size
def _get_graph_runner_block_tables(
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
max_blocks = self.max_blocks
graph_block_tables = torch.zeros((num_seqs, max_blocks),
dtype=block_tables.dtype,
device=block_tables.device)
num_blocks = block_tables.size(1)
if num_blocks <= max_blocks:
graph_block_tables[:num_seqs, :
num_blocks] = block_tables[:num_seqs, :
num_blocks]
else:
graph_block_tables[:num_seqs, :
max_blocks] = block_tables[:num_seqs, :
max_blocks]
return graph_block_tables[:, :max_blocks]
def build_torchair_graph_dummy(
self, common_attn_metadata: TorchairCommonAttentionMetadata
) -> AscendTorchairMetadata:
device = self.device
num_reqs = common_attn_metadata.num_reqs
block_table = torch.zeros((num_reqs, self.max_blocks),
dtype=torch.int32,
device=device)
block_table = self._get_graph_runner_block_tables(
num_reqs, block_table)
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
input_positions = torch.zeros(num_reqs,
dtype=torch.int32,
device=device).long()
slot_mapping = torch.full((num_reqs, ),
PAD_SLOT_ID,
dtype=torch.int32,
device=device)
query_start_loc = torch.full((num_reqs, ),
-1,
dtype=torch.int32,
device=device)
decode_metadata = AscendDecodeMetadata(input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=1)
attn_metadata = AscendTorchairMetadata(
num_actual_tokens=common_attn_metadata.num_actual_tokens,
block_tables=block_table,
query_lens=0,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
slot_mapping=slot_mapping,
attn_state=AscendAttentionState.DecodeOnly,
decode=decode_metadata)
return attn_metadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: Optional[nn.Module] = None,
):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
block_table = common_attn_metadata.block_table_tensor
block_table[:num_reqs, :self.max_num_blocks_per_req] = (
block_table[:num_reqs])
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
if get_ascend_device_type(
) == AscendDeviceType._310P and attn_state == AscendAttentionState.PrefillNoCache:
mask_nz = nd_to_nz_2d(attn_mask)
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
query_start_loc = query_start_loc_cpu.to(self.device,
non_blocking=True)
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)
decode_metadata = None
graph_pad_size = common_attn_metadata.graph_pad_size
use_torchair_graph = graph_pad_size > -1
if common_attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly,
]:
max_seq_lens = seq_lens.max().item()
num_seqs = len(seq_lens)
if use_torchair_graph and common_attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly,
]:
num_reqs_pad_size = 0
num_token_pad_size = 0
if graph_pad_size != 0:
num_token_pad_size = graph_pad_size - num_actual_tokens
num_reqs_pad_size = (
graph_pad_size //
common_attn_metadata.decode_token_per_req - num_reqs)
pad_value = 1
padded_seq_lens = seq_lens.tolist() + [pad_value
] * num_reqs_pad_size
seq_lens = torch.from_numpy(
np.array(padded_seq_lens).astype(np.int32))
padding = torch.full((num_token_pad_size, ),
PAD_SLOT_ID,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
slot_mapping = torch.cat([slot_mapping, padding])
block_table_padding = torch.zeros(
(num_reqs_pad_size, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat([block_table, block_table_padding],
dim=0)
block_table = self._get_graph_runner_block_tables(
num_seqs + num_reqs_pad_size, block_table)
padding_0 = torch.zeros(num_token_pad_size,
dtype=input_positions.dtype,
device=input_positions.device)
input_positions = torch.cat([input_positions, padding_0])
decode_metadata = AscendDecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=max_seq_lens,
attn_mask=None)
attn_metadata = AscendTorchairMetadata(
decode=decode_metadata,
num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=common_attn_metadata.max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state)
return attn_metadata
class AscendAttentionTorchairBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.hidden_size = self.num_heads * self.head_size
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes,
dtype=torch.float32,
device="npu")
self.alibi_slopes = alibi_slopes
self.attn_type = attn_type
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None
self.value_cache = None
self.scale_tensor = torch.zeros((), device='npu', dtype=torch.int32)
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AscendTorchairMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size,
                        num_kv_heads * head_size]
            key_cache = [num_blocks, block_size,
                        num_kv_heads * head_size]
            value_cache = [num_blocks, block_size,
                        num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
use_kv_cache_quant = (kv_cache is not None and len(kv_cache) > 0
and kv_cache[0].numel() > 0
and kv_cache[0].dtype == torch.int8)
if output is None:
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
if hasattr(layer, 'quant_method') and use_kv_cache_quant:
output = layer.quant_method.apply(layer, query, key, value,
kv_cache, attn_metadata,
self.attn_type, self.scale,
output)
return output.view(num_tokens, self.hidden_size)
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size).fill_(0)
output = output.view(-1, self.num_heads, self.head_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"AscendAttentionTorchairBackendImpl")
if kv_cache is not None and kv_cache[0].numel() > 0:
key_cache, value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
block_size = self.scale_tensor + key_cache.shape[1]
slots_indices = slots.reshape(-1, 1)
block_indices = slots_indices // block_size
slots_indices = slots_indices % block_size
indices = torch.cat((block_indices, slots_indices), dim=1)
torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
self.key_cache = key_cache
self.value_cache = value_cache
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if get_ascend_device_type() == AscendDeviceType._310P:
# align q k v output tensors
query = aligned_16(query)
key = aligned_16(key)
value = aligned_16(value)
output = aligned_16(output)
# do reformat in case of broadcasted tensors
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
mask = torch_npu.npu_format_cast(mask.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
output = output[:num_tokens, :, :]
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
compress_mask = attn_metadata.attn_mask
batch_size = attn_metadata.query_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=block_table,
mask=compress_mask,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
decode_meta = attn_metadata.decode
assert decode_meta is not None
seq_lens = decode_meta.seq_lens_list
block_table = decode_meta.block_table
block_size = key_cache.shape[1]
query = query.view(num_tokens, 1,
self.num_heads * self.head_size).contiguous()
output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key_cache,
value=value_cache,
query_rope=None,
key_rope=None,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout='BSH',
atten_mask=decode_meta.attn_mask,
sparse_mode=0,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=seq_lens,
)
else:
raise NotImplementedError(
"Torchair graph mode with non-MLA attention backend is still experimental."
"v1 scheduler(chunked prefill) is not supported at this moment."
)
return output.view(num_tokens, self.hidden_size)
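The cache-update path above splits each flat slot id into a (block, offset) pair before scattering new keys and values into the paged cache. A minimal CPU sketch of that indexing, with illustrative shapes and plain advanced indexing standing in for torch_npu.npu_scatter_nd_update_:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
key = torch.randn(3, num_kv_heads, head_size)  # three new tokens
slots = torch.tensor([5, 17, 40])              # flat slot ids from slot_mapping

block_indices = slots // block_size            # which block each token lands in
offsets = slots % block_size                   # position inside that block
key_cache[block_indices, offsets] = key        # scatter by (block, offset)
assert torch.equal(key_cache[1, 1], key[1])    # slot 17 -> block 1, offset 1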

File diff suppressed because it is too large

View File

@@ -1,574 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# isort: skip_file
import math
import types
from typing import Any, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.torchair.utils import (
TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata,
check_torchair_cache_exist, converting_weight_acl_format,
register_torchair_model, torchair_ops_patch,
torchair_quant_method_register, write_kv_cache_bytes_to_file)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendDeviceType, get_ascend_device_type)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
class NPUTorchairModelRunner(NPUModelRunner):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.ascend_config = get_ascend_config()
self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
super().__init__(vllm_config, device)
if self.speculative_config:
self.actual_seq_lengths_q = list(
range(self.decode_token_per_req, self.max_num_tokens + 1,
self.decode_token_per_req))
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
None, None, vllm_config, device)
self.use_sparse = hasattr(self.model_config.hf_config, "index_topk")
register_torchair_model()
torchair_ops_patch()
torchair_quant_method_register()
if self.enable_shared_expert_dp:
return
self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph
self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes
if self.ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()
self.update_torchair_graph_batch_sizes()
torch._dynamo.cache_size.config.cache_size_limit += len(
self.torchair_graph_batch_sizes)
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
self._check_batch_sizes_consistency()
def _set_up_drafter(self):
super()._set_up_drafter()
if self.speculative_config:
# Torchair do not support disable_padded_drafter_batch
# Enforce to disable this feature
self.speculative_config.disable_padded_drafter_batch = True
def _get_drafter(self):
return get_spec_decode_method(self.speculative_config.method,
self.vllm_config,
self.device,
self,
is_torchair_graph=True)
def _may_pad_kv_consumer_num_seq(self):
# The PD disaggregation scenario needs redundant batch sizes so that no batch's seq_len exceeds 16 tokens.
# self.max_num_reqs here is greater than the actual maximum request number.
if self.decode_token_per_req > 1 and self.is_kv_consumer:
# applied only when speculative decoding is active
FIA_SEQ_LEN_LIMIT = 16
new_max_num_reqs = self.max_num_reqs + math.ceil(
self.max_num_reqs / FIA_SEQ_LEN_LIMIT) + math.ceil(
(self.max_num_reqs * self.decode_token_per_req) /
(FIA_SEQ_LEN_LIMIT**2))
if self.max_num_reqs < new_max_num_reqs:
logger.warning(
f"max_num_reqs is updated to {new_max_num_reqs}")
self.max_num_reqs = new_max_num_reqs
def _init_mc2_tokens_capacity(self):
# NOTE: To be clear, we need to make sure that during graph capture, the number of
# tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
# the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
tp_size = self.parallel_config.tensor_parallel_size
# Use integer arithmetic for ceiling division.
max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
max_num_tokens, tp_size)
self.mc2_tokens_capacity = max_graph_batch_size
if get_ascend_device_type(
) == AscendDeviceType._910_93 and self.mc2_tokens_capacity > 512:
logger.error(
f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
)
if get_ascend_device_type(
) == AscendDeviceType._910B and self.mc2_tokens_capacity > 256:
logger.error(
f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
)
def _sync_metadata_across_dp(
self, num_tokens: int,
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
"""Override from NPUModelRunner to pad num_tokens"""
if self.enable_shared_expert_dp:
# Padding is not required for shared_expert_dp cases in eager mode.
return num_tokens, None, with_prefill
if self.dp_size == 1:
if not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
num_tokens)
return maybe_padded_num_tokens, None, with_prefill
return num_tokens, None, with_prefill
num_tokens_across_dp = torch.zeros(self.dp_size + 1,
dtype=torch.int32,
device="npu")
num_tokens_across_dp[self.dp_rank] = num_tokens
num_tokens_across_dp[-1] = int(with_prefill)
dist.all_reduce(num_tokens_across_dp,
group=get_dp_group().device_group)
with_prefill = bool(num_tokens_across_dp[-1])
num_tokens_across_dp = num_tokens_across_dp[:-1]
if not with_prefill:
max_num_token = num_tokens_across_dp.max().item()
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
max_num_token)
num_tokens_across_dp = torch.full((self.dp_size, ),
maybe_padded_num_tokens,
dtype=torch.int32,
device="npu")
else:
maybe_padded_num_tokens = num_tokens
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill
def _build_dummy_attn_metadata(
self,
with_prefill: bool,
num_reqs: int,
num_tokens: int,
max_query_len: int,
num_scheduled_tokens: np.ndarray,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
) -> Optional[dict[str, Any]]:
# NOTE: If torchair graph mode and not with_prefill,
# we can't skip_attn, it will cause graph recompile.
if with_prefill or self.enable_shared_expert_dp:
attn_metadata = super()._build_dummy_attn_metadata(
with_prefill, num_reqs, num_tokens, max_query_len,
num_scheduled_tokens, aclgraph_runtime_mode, force_attention)
else:
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs,
num_actual_tokens=1,
actual_seq_lengths_q=self.actual_seq_lengths_q,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
decode_token_per_req=self.decode_token_per_req,
)
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
common_attn_metadata)
return attn_metadata
def _generate_dummy_run_hidden_states(self, with_prefill,
is_torchair_compile, input_ids,
positions, attn_metadata, num_tokens,
intermediate_tensors, inputs_embeds):
if with_prefill or self.enable_shared_expert_dp:
if get_ascend_device_type() == AscendDeviceType._310P:
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
hidden_states = super()._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
else:
# Only mark static while compiling
if is_torchair_compile:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
if self.speculative_config:
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
for kv in self.kv_caches:
assert isinstance(kv, tuple), "kv_cache must be a tuple"
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
if get_ascend_device_type() == AscendDeviceType._310P:
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
model_kwargs = {}
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)
return hidden_states
def _convert_torch_format(self, kv_cache):
if self.enable_shared_expert_dp:
return super()._convert_torch_format(kv_cache)
kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
return kv_cache
def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
# Trigger torchair graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens, is_torchair_compile=True)
self._dummy_run(num_tokens, is_torchair_compile=True)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
def _capture_model(self):
"""Override from NPUModelRunner to use torchair graph capture."""
if self.enable_shared_expert_dp:
return super()._capture_model()
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
# torchair graph capture can cause some issues, so now we just
# temporarily split the codepath for the two different graph patterns.
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
graph_num = len(torchair_graph_batch_sizes)
if self.use_cached_npu_graph and not check_torchair_cache_exist():
# If caching is enabled but does not exist (either
# use_cached_kv_cache_bytes is disabled or kv_cache_bytes are
# different), we will compile the model twice. The first time is
# used to generate the cache, and the second time is used to load the
# cache to skip the overhead caused by Dynamo guard mechanism.
logger.info(
"Cache compilation for torchair graph is enabled. Now we compile graph to genetate"
" torchair cache, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
NPUPlatform.synchronize()
# Note: We reset dynamo and then reload the torchair graph cache that was just
# compiled above. This reduces graph launch time by 2-4 ms and avoids runtime
# errors caused by configuration mismatches in graph mode.
torch._dynamo.reset()
self.torchair_compiled_models.clear()
if self.use_cached_npu_graph:
logger.info(
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
0.3 * graph_num, 0.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
else:
logger.info(
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0:
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
self.new_kv_cache_bytes)
def _use_aclgraph(self) -> bool:
if self.enable_shared_expert_dp:
return super()._use_aclgraph()
return False
def _check_batch_sizes_consistency(self) -> None:
if not dist.is_initialized():
return
local = torch.tensor(self.torchair_graph_batch_sizes,
device="cpu",
dtype=torch.int32)
gathered_graph_batch_size = local.clone()
dist.all_reduce(gathered_graph_batch_size,
group=get_dp_group().cpu_group)
expected = local * self.dp_size
if not torch.equal(gathered_graph_batch_size, expected):
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
as_tuple=False).flatten().tolist()
raise AssertionError(
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
)
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
if with_prefill or self.enable_shared_expert_dp:
super()._update_graph_pad_size(with_prefill, graph_pad_size)
else:
self.graph_pad_size = graph_pad_size
def _update_input_ids_and_positions(self, input_ids, positions,
num_input_tokens, with_prefill,
padded_num_tokens_across_dp):
"""Override from NPUModelRunner to update input_ids and positions"""
input_ids, positions = super()._update_input_ids_and_positions(
input_ids, positions, num_input_tokens, with_prefill,
padded_num_tokens_across_dp)
if with_prefill or self.enable_shared_expert_dp:
return input_ids, positions
else:
input_ids = self.input_ids[:padded_num_tokens_across_dp]
positions = self.positions[:padded_num_tokens_across_dp]
return input_ids, positions
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
padded_num_tokens_across_dp,
input_ids, positions,
intermediate_tensors,
inputs_embeds):
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
if self.enable_shared_expert_dp:
return super()._generate_process_reqs_hidden_states(
attn_metadata, with_prefill, padded_num_tokens_across_dp,
input_ids, positions, intermediate_tensors, inputs_embeds)
model_kwargs = {
"kv_caches": self.kv_caches,
"attn_metadata": attn_metadata
}
if not with_prefill:
if get_ascend_device_type() == AscendDeviceType._310P:
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(
padded_num_tokens_across_dp)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
else:
assert self.model is not None
if get_ascend_device_type() == AscendDeviceType._310P:
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
return hidden_states
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
)
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
patch_for_hcom()
if get_ascend_device_type() == AscendDeviceType._310P:
# On the 300I Duo platform we need to patch broadcast; however, that patch is
# overwritten by patch_for_hcom in torchair, so we re-apply it here.
from vllm_ascend.patch.platform.patch_distributed import \
communication_adaptation_310p
communication_adaptation_310p()
config = torchair.CompilerConfig()
if self.ascend_config.torchair_graph_config.mode:
config.mode = self.ascend_config.torchair_graph_config.mode
config.experimental_config.frozen_parameter = \
self.ascend_config.torchair_graph_config.enable_frozen_parameter
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
# disable it on 300I Duo platform now.
config.experimental_config.tiling_schedule_optimize = get_ascend_device_type(
) != AscendDeviceType._310P
config.experimental_config.enable_view_optimize = \
self.ascend_config.torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not self.use_sparse,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
# compilation cache caused by dynamo retracing
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=not self.use_sparse,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]
def init_torchair_graph_batch_sizes(self):
start_graph_batch_size = 4
tp_size = get_tensor_model_parallel_world_size()
# NOTE: When using all2all / mc2, we need to slice the `num_tokens` dimension into `tp_size` blocks
start_graph_batch_size = max(start_graph_batch_size, tp_size)
while (start_graph_batch_size <= self.max_num_reqs):
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
start_graph_batch_size *= 2
def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
tp_size):
cur_graph_batch_size = (old_graph_batch_size + tp_size -
1) // tp_size * tp_size
# MTP > 1: compute the LCM (least common multiple) of graph_batch_size and
# tp_size, which fits both the multi-dp adapter and the FIA operator.
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
cur_graph_batch_size = (tp_size * old_graph_batch_size) \
// math.gcd(tp_size, old_graph_batch_size)
return cur_graph_batch_size
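# Worked example with illustrative numbers: tp_size=4, old_graph_batch_size=6
# gives the ceil-rounded size (6 + 4 - 1) // 4 * 4 == 8; when
# num_speculative_tokens > 1, the LCM form gives
# (4 * 6) // math.gcd(4, 6) == 12, which both multi-dp and FIA accept.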
def select_torchair_padded_batch_size(self, batch_size: int):
for padded_batch_size in self.torchair_graph_batch_sizes:
if batch_size <= padded_batch_size:
# we treat batch_size as num of requests
return padded_batch_size
raise ValueError(
f"batch_size {batch_size} is invalid: it exceeds every entry in "
f"torchair_graph_batch_sizes {self.torchair_graph_batch_sizes}."
)
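# For example, with torchair_graph_batch_sizes == [4, 8, 16], a batch of
# 5 requests is padded up to 8 so it replays the captured size-8 graph;
# a batch of 17 raises the ValueError above.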
def update_torchair_graph_batch_sizes(self):
# Derive graph_batch_sizes from the max number of tokens;
# first pad according to the number of requests.
if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'mtp':
# The PD disaggregation scenario may miscalculate the batch size under MTP, so we force it to max_num_reqs.
self.torchair_graph_batch_sizes = [self.max_num_reqs]
logger.warning(
f"is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs] {[self.max_num_reqs]}"
)
elif len(self.torchair_graph_batch_sizes) == 0:
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
else:
self.torchair_graph_batch_sizes = sorted(
self.torchair_graph_batch_sizes)
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
self.torchair_graph_batch_sizes.pop()
if len(self.torchair_graph_batch_sizes) == 0:
logger.warning(
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
)
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
# padded max number of tokens = max_num_reqs * decode_token_per_req
self.torchair_graph_batch_sizes = [
graph_batch_size * self.decode_token_per_req
for graph_batch_size in self.torchair_graph_batch_sizes
]
# NOTE: when enable_expert_parallel is set on A3, we need to check that `graph_batch_size` is divisible by `tp_size`,
# because we use x_active_mask for the dispatch/combine ops on A3, which requires the input shape to be the same
# on all EP ranks.
if get_ascend_device_type(
) == AscendDeviceType._910_93 and self.parallel_config.enable_expert_parallel:
self._align_graph_size_divisible_by_tp_size()
def _align_graph_size_divisible_by_tp_size(self):
tp_size = self.parallel_config.tensor_parallel_size
new_graph_batch_sizes = []
for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
graph_batch_size, tp_size)
if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size)
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
and self.decode_token_per_req > 1:
logger.warning(
f"torchair graph batch size {cur_graph_batch_size} is bigger than "
f"max_num_batched_tokens {self.scheduler_config.max_num_batched_tokens}; "
"skipping this batch size.")
new_max_num_reqs = math.ceil(
max(new_graph_batch_sizes) / self.decode_token_per_req)
if self.max_num_reqs != new_max_num_reqs:
logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}")
self.max_num_reqs = new_max_num_reqs
if not (self.decode_token_per_req > 1 and self.is_kv_consumer):
# Do not update scheduler_config.max_num_seqs in KV consumer + MTP,
# since FIA needs extra space for padding.
# Enforce self.max_num_seqs > self.scheduler_config.max_num_seqs in KV consumer + MTP.
self.scheduler_config.max_num_seqs = new_max_num_reqs
if new_graph_batch_sizes != self.torchair_graph_batch_sizes:
logger.warning(
f"torchair_graph_batch_sizes are updated to {new_graph_batch_sizes}."
)
self.torchair_graph_batch_sizes = new_graph_batch_sizes
def _build_drafter_prepare_inputs_torchair_param(self):
if self.enable_shared_expert_dp:
return super()._build_drafter_prepare_inputs_torchair_param()
else:
return True
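The `_sync_metadata_across_dp` override above piggybacks the with_prefill flag on the same all_reduce as the per-rank token counts by appending one extra element to the buffer. A minimal single-process sketch of that packing, with dist.all_reduce simulated by summing the per-rank buffers:

import torch

dp_size = 4
per_rank_tokens = [7, 12, 3, 9]
per_rank_prefill = [False, True, False, False]

buffers = []
for rank in range(dp_size):
    buf = torch.zeros(dp_size + 1, dtype=torch.int32)
    buf[rank] = per_rank_tokens[rank]       # each rank fills only its own slot
    buf[-1] = int(per_rank_prefill[rank])   # last slot carries the prefill flag
    buffers.append(buf)

reduced = torch.stack(buffers).sum(dim=0)   # stands in for dist.all_reduce(SUM)
with_prefill = bool(reduced[-1])            # any rank prefilling makes it nonzero
num_tokens_across_dp = reduced[:-1]         # per-rank token counts: [7, 12, 3, 9]
assert with_prefill and num_tokens_across_dp.max().item() == 12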

View File

@@ -1,543 +0,0 @@
import types
import torch
import torch.nn as nn
import torchair
from torchair import patch_for_hcom
from vllm.config import (CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \
process_weights_after_loading
from vllm.utils.torch_utils import set_default_torch_dtype
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode import MtpProposer
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
TorchairDeepSeekMTP
from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
TorchairCommonAttentionMetadata)
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
PADDING_SLOT_ID = -1
class TorchairMtpProposer(MtpProposer):
def __init__(
self,
vllm_config: VllmConfig,
device,
runner,
):
super().__init__(vllm_config, device, runner)
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase).keys())
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
target_device = self.vllm_config.device_config.device
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() -
target_attn_layer_names)
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = list(draft_attn_layer_names)
self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
process_weights_after_loading(self.model, draft_model_config,
target_device)
@torch.inference_mode()
def dummy_run(self,
num_tokens: int,
with_prefill: bool = False,
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None) -> None:
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
if not with_prefill:
skip_attn = False
if skip_attn:
attn_metadata = None
else:
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs,
num_actual_tokens=1,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
decode_token_per_req=self.runner.decode_token_per_req,
)
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
common_attn_metadata)
input_ids = self.input_ids[:num_tokens]
positions = self.positions[:num_tokens]
previous_hidden_states = self.hidden_states[:num_tokens]
for _ in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_type=moe_comm_type,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if not with_prefill:
assert attn_metadata is not None
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(previous_hidden_states)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
torchair_compiled_model(
input_ids=input_ids,
positions=positions,
hidden_states=previous_hidden_states,
inputs_embeds=None,
intermediate_tensors=None,
attn_metadata=attn_metadata,
kv_caches=self.runner.kv_caches[-1:],
spec_step_idx=0)
else:
self.model(input_ids=input_ids,
positions=positions,
hidden_states=previous_hidden_states)
dummy_compute_logits(previous_hidden_states)
if with_prefill:
break
def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata = None,
scheduler_output: SchedulerOutput = None,
spec_decode_metadata: SpecDecodeMetadata = None,
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.runner.input_batch.req_ids[i]
req_state = self.runner.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
accepted_token_indices = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, accepted_token_indices, target_token_ids, \
target_positions, target_hidden_states, target_slot_mapping = self._torchair_prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
self.runner.input_ids[:num_scheduled_tokens],
positions[:num_scheduled_tokens],
hidden_states[:num_scheduled_tokens],
attn_metadata.slot_mapping[:num_scheduled_tokens],
)
draft_token_ids = self._propose_torchair(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
token_indices=accepted_token_indices)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
def _torchair_prepare_inputs(
self,
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
token_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
slot_mapping: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
cu_num_tokens = cu_target_query_lens
relative_index = query_len_per_req - num_rejected_tokens - 1
token_indices = cu_num_tokens[:-1] + relative_index
# The seq len of each batch is padded to 1 + num_speculative_tokens, so the input matches the main model's.
target_token_ids = token_ids
target_positions = positions
target_hidden_states = hidden_states
target_slot_mapping = slot_mapping
return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
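# Worked example: cu_target_query_lens = [0, 3, 5, 9] and
# num_rejected_tokens = [1, 0, 2] give query_len_per_req = [3, 2, 4],
# relative_index = [1, 1, 1], and token_indices = [1, 4, 6] -- the index
# of the last accepted token of each request in the flattened batch.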
def _propose_torchair(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
token_indices=None) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
if token_indices is not None:
last_token_indices = token_indices
self.input_ids[last_token_indices] = next_token_ids
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
# FIXME: reorder_batch() needs to be called before build()
# because fields of attn_metadata_builder need to be updated.
# However, reorder_batch() currently takes input_batch and
# scheduler_output as arguments; we should probably refactor
# the method to use new data structures that are independent
# of input_batch and scheduler_output.
# self.runner.attn_metadata_builder.reorder_batch(
# input_batch=self.runner.input_batch,
# scheduler_output=self.runner.scheduler_output,
# )
if not self.runner.with_prefill:
# Torchair graph mode, padding is same as the main model
num_input_tokens = self.runner.graph_pad_size
elif (self.runner.use_aclgraph
and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
# Acl graph mode, add padding to the batch size
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
# Eager mode, no padding needed
num_input_tokens = num_tokens
seq_lens = target_positions[last_token_indices] + 1
seq_lens = seq_lens.int()
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=cu_num_tokens[:batch_size + 1],
query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
seq_lens_cpu=seq_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping=target_slot_mapping,
positions=target_positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
graph_pad_size=self.runner.graph_pad_size,
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=None,
seq_lens=None)
attn_metadata = self.runner.attn_metadata_builder.build(
0, common_attn_metadata, self.runner.get_model())
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
# torchair mode can reuse self.runner.num_tokens_across_dp
num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill
moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens)
for step in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_type=moe_comm_type,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if not self.runner.with_prefill:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_input_tokens)
hidden_states = torchair_compiled_model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.
hidden_states[:num_input_tokens],
inputs_embeds=None,
intermediate_tensors=None,
spec_step_idx=0,
**model_kwargs)
else:
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens]
)
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
if not self.runner.with_prefill:
max_num_reqs_across_dp = num_input_tokens
else:
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
last_token_indices = nn.functional.pad(
last_token_indices,
(0, max_num_reqs_across_dp - num_indices))
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
if step == 0:
draft_token_ids_list = [draft_token_ids]
else:
draft_token_ids_list.append(draft_token_ids)
# prepare next mtp inputs
# mtp > 1: prefill exits after the first step; decode skips input prep on its last step
if with_prefill:
for _ in range(self.num_speculative_tokens - 1):
draft_token_ids_list.append(draft_token_ids)
if step == self.num_speculative_tokens - 1 or with_prefill:
break
attn_metadata_i = attn_metadata
if step == 0:
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
slot_mapping = attn_metadata_i.slot_mapping[last_token_indices]
attn_metadata_i.slot_mapping.fill_(-1)
attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
last_token_indices = self.arange[:batch_size]
if attn_metadata_i.num_decode_tokens != 0:
attn_metadata_i.num_decode_tokens = batch_size
if not self.runner.with_prefill:
attn_metadata_i.num_actual_tokens = batch_size
attn_metadata_i.query_lens = [1] * batch_size
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
attn_metadata_i.seq_lens[:batch_size] += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
attn_metadata_i.seq_lens.device, non_blocking=True)
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
exceeds_max_model_len_cpu, 1)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping += 1
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:hidden_states.shape[0]] = hidden_states
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
if attn_metadata_i.prefill is not None:
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
)
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.prefill.max_seq_lens += 1
attn_metadata_i.prefill.max_seq_lens = min(
attn_metadata_i.prefill.max_seq_lens,
self.runner.model_config.max_model_len)
if attn_metadata_i.decode is not None:
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
)
attn_metadata_i.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.decode.max_seq_lens += 1
attn_metadata_i.decode.max_seq_lens = min(
attn_metadata_i.decode.max_seq_lens,
self.runner.model_config.max_model_len)
# mtp>1: [batch_size, k]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
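# e.g. num_speculative_tokens == 3 with batch_size == 2 collects three
# [2]-shaped argmax tensors and stacks them into a [2, 3] draft matrix.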
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
-1]:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}"
)
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.runner.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
patch_for_hcom()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.runner.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not self.use_sparse,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
# compilation cache caused by dynamo retracing
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=not self.use_sparse,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]
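Both lazy-compile helpers above clone the model's forward into a per-batch-size proxy whose code object carries a distinct name, so dynamo retracing for one batch size cannot invalidate the compilation cache of another. A minimal torchair-free sketch of that renaming trick, with a toy Model class standing in for the real one:

import types

class Model:
    def forward(self, x):
        return x * 2

model = Model()
batch_size = 8
proxy_name = f"{type(model).__name__}_forward_with_batch_size_{batch_size}"

fn = Model.forward
code = fn.__code__.replace(co_name=proxy_name)  # rename the code object
proxy = types.FunctionType(code, fn.__globals__, name=proxy_name,
                           argdefs=fn.__defaults__)
model.__dict__[proxy_name] = proxy.__get__(model, Model)  # bind to instance

assert model.__dict__[proxy_name](3) == 6
assert model.__dict__[proxy_name].__func__.__code__.co_name == proxy_name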

File diff suppressed because it is too large

View File

@@ -1,63 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist,
delete_torchair_cache_file,
read_kv_cache_bytes_from_file)
from vllm_ascend.worker.worker_v1 import NPUWorker
class NPUTorchairWorker(NPUWorker):
"""Torchair worker bases on NPUWorker. Only torchair specified code should be added in this class."""
def determine_available_memory(self) -> int:
"""Override determine_available_memory to use cached torchair kv_cache_bytes."""
available_kv_cache_memory = super().determine_available_memory()
ascend_config = get_ascend_config()
if ascend_config.enable_shared_expert_dp:
return available_kv_cache_memory
if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
if check_kv_cache_bytes_cache_exist():
old_kv_cache_bytes = read_kv_cache_bytes_from_file(
torch.distributed.get_rank())
if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
logger.info(
f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
)
self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
return old_kv_cache_bytes
else:
logger.info(
"Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
)
delete_torchair_cache_file()
bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
available_kv_cache_memory -= bytes_floating_tolerance
logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
return available_kv_cache_memory
def init_device(self):
"""Override init_device to init torchair model runner"""
device = self._init_device()
# Init ModelRunner here, so that we have access to self.device.
self.model_runner = NPUTorchairModelRunner(self.vllm_config, device)

View File

@@ -1,275 +0,0 @@
import fcntl
import os
import shutil
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
import torch
import torch_npu
from torchair.scope import super_kernel as _super_kernel
try:
# Recent release of torchair has moved these ops to `.scope`.
from torchair.scope import npu_stream_switch as _npu_stream_switch
from torchair.scope import npu_wait_tensor as _npu_wait_tensor
except ImportError:
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
TORCHAIR_CACHE_DIR = os.path.join(
os.getenv('TORCHAIR_CACHE_HOME', os.getcwd()), TORCHAIR_CACHE_PATH_NAME)
@dataclass
class TorchairCommonAttentionMetadata:
"""
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.
For many of the tensors we keep both device (NPU) and CPU versions.
"""
num_reqs: int
"""Number of requests"""
num_actual_tokens: int
"""Total number of tokens in batch"""
decode_token_per_req: int
actual_seq_lengths_q: list[int]
attn_mask: torch.Tensor = None
spec_attn_mask: torch.Tensor = None
graph_pad_size: int = -1
@contextmanager
def _file_lock(file_descriptor, lock_type):
fcntl.flock(file_descriptor, lock_type)
try:
yield
finally:
fcntl.flock(file_descriptor, fcntl.LOCK_UN)
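# Usage sketch: take a shared lock while reading and an exclusive lock
# while writing, as the read/write helpers below do:
#   with open(path, "r", encoding="utf-8") as f:
#       with _file_lock(f, fcntl.LOCK_SH):
#           data = f.readline()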
def _get_torchair_current_work_dir(file_name=None):
if file_name is None:
return TORCHAIR_CACHE_DIR
return os.path.join(TORCHAIR_CACHE_DIR, file_name)
def check_torchair_cache_exist():
res = False
torch_air_abs_path = _get_torchair_current_work_dir()
if os.path.exists(torch_air_abs_path):
file_list = os.listdir(torch_air_abs_path)
if len(file_list) != 0:
res = True
return res
def check_kv_cache_bytes_cache_exist():
res = False
kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
KV_CACHE_BYTES_CACHE_PATH_NAME)
if os.path.exists(kv_cache_bytes_cache_abs_path):
file_list = os.listdir(kv_cache_bytes_cache_abs_path)
if len(file_list) != 0:
res = True
return res
def read_kv_cache_bytes_from_file(rank) -> int:
kv_cache_bytes = -1
kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
KV_CACHE_BYTES_CACHE_PATH_NAME)
kv_cache_bytes_file = os.path.join(
kv_cache_bytes_cache_abs_path,
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
with _file_lock(f, fcntl.LOCK_SH):
kv_cache_bytes = int(f.readline())
return kv_cache_bytes
def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
KV_CACHE_BYTES_CACHE_PATH_NAME)
os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
kv_cache_bytes_file = os.path.join(
kv_cache_bytes_cache_abs_path,
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
with _file_lock(f, fcntl.LOCK_EX):
f.write(f"{kv_cache_bytes}")
def delete_torchair_cache_file():
torch_air_abs_path = _get_torchair_current_work_dir()
try:
shutil.rmtree(torch_air_abs_path)
except FileNotFoundError:
pass
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
return _npu_stream_switch(tag, priority) if enabled else nullcontext()
def npu_wait_tensor(self: torch.Tensor,
dependency: torch.Tensor,
*,
enabled: bool = True):
return _npu_wait_tensor(self, dependency) if enabled else self
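# Usage sketch (illustrative tag/priority values): run a side computation on
# a secondary stream and order it behind `dep` before consuming the result:
#   with npu_stream_switch("moe_secondary", 0):
#       out = npu_wait_tensor(out, dep)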
def converting_weight_acl_format(model, format):
# Currently some operations do not support ACL_FORMAT_FRACTAL_NZ in eager mode
# but do support it in torchair graph mode. Since ACL_FORMAT_FRACTAL_NZ is much
# preferred over ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this conversion when
# using torchair graph mode on the 300I Duo platform.
# TODO: we will remove this conversion once npu_quant_grouped_matmul_dequant
# accepts weights in ACL_FORMAT_FRACTAL_NZ format in eager mode.
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
for module in model.modules():
if isinstance(module, FusedMoE):
if torch_npu.get_npu_format(module.w13_weight.data) == format:
return
if format == ACL_FORMAT_FRACTAL_NZ \
and not is_enable_nz():
return
module.w13_weight.data = torch_npu.npu_format_cast(
module.w13_weight.data, format)
module.w2_weight.data = torch_npu.npu_format_cast(
module.w2_weight.data, format)
def register_torchair_model():
from vllm import ModelRegistry
ModelRegistry.register_model(
"DeepSeekMTPModel",
"vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
)
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
)
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
)
ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
)
ModelRegistry.register_model(
"Qwen2ForCausalLM",
"vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM")
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")
ModelRegistry.register_model(
"PanguProMoEForCausalLM",
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)
def torchair_quant_method_register():
from vllm_ascend.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
TorchairAscendW4A8DynamicFusedMoEMethod,
TorchairAscendW4A8DynamicLinearMethod)
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
TorchairAscendW8A8DynamicFusedMoEMethod,
TorchairAscendW8A8DynamicLinearMethod)
ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][
"linear"] = TorchairAscendW8A8DynamicLinearMethod
ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][
"moe"] = TorchairAscendW8A8DynamicFusedMoEMethod
ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][
"linear"] = TorchairAscendW4A8DynamicLinearMethod
ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][
"moe"] = TorchairAscendW4A8DynamicFusedMoEMethod
def torchair_ops_patch():
from vllm_ascend.ops.activation import AscendSiluAndMul
from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding
from vllm_ascend.torchair.ops import (torchair_activation,
torchair_layernorm)
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
deepseek_rope_init_func, native_rope_deepseek_forward,
qwen_rope_init_func, rope_forward)
from vllm_ascend.torchair.ops.torchair_vocab_parallel_embedding import \
vocab_embedding_forward
AscendRotaryEmbedding.__init__ = qwen_rope_init_func # type: ignore[method-assign]
AscendRotaryEmbedding.forward_oot = rope_forward # type: ignore[method-assign]
AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func # type: ignore[method-assign]
AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward # type: ignore[method-assign]
AscendRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_ # type: ignore[method-assign]
AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign]
AscendQuantRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_ # type: ignore[method-assign]
AscendQuantRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign]
AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign]
AscendVocabParallelEmbedding.forward = vocab_embedding_forward # type: ignore[method-assign]
def super_kernel(prefix: str, option: str, enabled: bool = True):
return _super_kernel(prefix, option) if enabled else nullcontext()
# TODO(ttanzhiqiang): rm_router_logits
# Triggered when dp > 1.
# In theory this optimization only applies to AllGather and AllGatherEP: in the dp
# scenario the previous order was gate + two communications, and it is now one
# communication + gate, which saves some communication time. All MoE
# AllGather/AllGatherEP paths could follow this logic, but the dp paths of other
# MoE models (e.g. qwen3-235b) have not been adapted yet, so a switch controls it
# to prevent errors.
def get_rm_router_logits_state(ep_size: int, dp_size: int,
is_deepseek_v3_r1: bool):
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
# only supports deepseek v3/r1
if dp_size > 1:
if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
and is_deepseek_v3_r1):
return True
elif ep_size == 1 and is_deepseek_v3_r1:
return True
return False
# TODO(ttanzhiqiang): all_reduce merge
# With all_reduce_merge, shared_experts skips the all_reduce in the MLP and waits
# until shared_experts + router_experts have finished before doing a single all_reduce.
# Currently all_reduce_merge is enabled by default in the AllGather, AllGatherEP and
# NaiveMulticast scenarios of the deepseek model.
def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
# only supports deepseek v3/r1
if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
and is_deepseek_v3_r1):
return True
elif ep_size == 1 and is_deepseek_v3_r1:
return True
return False
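For illustration, a standalone restatement of the first switch above, with the VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP flag passed in explicitly instead of read from vllm_ascend.envs:

def rm_router_logits(ep_size: int, dp_size: int, is_deepseek_v3_r1: bool,
                     allgather_ep_fused: bool) -> bool:
    # Only the dp > 1 path can move the gate after the communication.
    if dp_size > 1:
        if allgather_ep_fused and ep_size > 1 and is_deepseek_v3_r1:
            return True
        if ep_size == 1 and is_deepseek_v3_r1:
            return True
    return False

assert rm_router_logits(8, 2, True, allgather_ep_fused=True)       # AllGatherEP
assert not rm_router_logits(8, 1, True, allgather_ep_fused=True)   # dp == 1
assert not rm_router_logits(8, 2, False, allgather_ep_fused=True)  # not v3/r1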

View File

@@ -143,7 +143,6 @@ from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendDeviceType, ProfileExecuteDuration,
enable_sp, get_ascend_device_type, is_enable_nz,
@@ -638,7 +637,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# Set up speculative decoding.
self.spec_attn_mask = None
self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
TorchairMtpProposer,
SuffixDecodingProposer]] = None
self.actual_seq_lengths_q: list[int] = []
self.decode_token_per_req = 1
@@ -2917,8 +2915,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
return attn_metadata
def _generate_dummy_run_hidden_states(self, with_prefill,
is_torchair_compile, input_ids,
def _generate_dummy_run_hidden_states(self, with_prefill, input_ids,
positions, attn_metadata, num_tokens,
intermediate_tensors, inputs_embeds):
hidden_states = self.model(input_ids=input_ids,
@@ -2960,7 +2957,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
self,
num_tokens: int,
with_prefill: bool = False,
is_torchair_compile: bool = False,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
uniform_decode: bool = False,
@@ -3136,9 +3132,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
model_instance=self.model,
weight_prefetch_method=self.weight_prefetch_method):
hidden_states = self._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens_padded, intermediate_tensors,
inputs_embeds)
with_prefill, input_ids, positions, attn_metadata,
num_tokens_padded, intermediate_tensors, inputs_embeds)
dummy_compute_logits(hidden_states)
if self.drafter:
@@ -4262,9 +4257,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
return list(model.pooler.get_supported_tasks())
def _build_drafter_prepare_inputs_torchair_param(self):
return False
def _update_tokens_for_pcp(self, tokens):
num_reqs = self.input_batch.num_reqs
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]