Drop torchair (#4814)
ACLGraph is stable and fast now, so let's drop the torchair graph mode.
TODO: some logic that was added to adapt to torchair still needs to be cleaned up; we'll do that in a follow-up PR.
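For reference, serving no longer needs any torchair-specific settings; a minimal offline sketch (illustrative only, not part of this diff, the model path is a placeholder) of running with the default graph mode:

```python
from vllm import LLM

# Sketch: no "torchair_graph_config" entry in additional_config is needed any more;
# the default ACLGraph mode is used unless eager mode is explicitly forced.
llm = LLM(model="path/to/DeepSeek-R1-0528")  # placeholder model path
outputs = llm.generate("Hello, how are you?")
```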
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
.github/workflows/_e2e_test.yaml (vendored, 6 changes)
@@ -107,7 +107,6 @@ jobs:
# ------------------------------------ v1 spec decode test ------------------------------------ #
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
e2e-2-cards:
@@ -170,10 +169,6 @@ jobs:
if: ${{ inputs.type == 'light' }}
run: |
pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_qwen3_moe_with_torchair
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv2lite_with_torchair
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv2lite_with_torchair_v1scheduler
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv2lite_with_nz
- name: Run vllm-project/vllm-ascend test (full)
env:
@@ -183,7 +178,6 @@ jobs:
run: |
pytest -sv tests/e2e/multicard/test_quantization.py
pytest -sv tests/e2e/multicard/test_aclgraph_capture_replay.py
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
pytest -sv tests/e2e/multicard/test_data_parallel.py
pytest -sv tests/e2e/multicard/test_expert_parallel.py
@@ -127,9 +127,6 @@ jobs:
- name: multi-node-deepseek-dp
config_file_path: DeepSeek-R1-W8A8-A2.yaml
size: 2
- name: multi-node-deepseek-dp-torchair
config_file_path: DeepSeek-R1-W8A8-A2-torchair.yaml
size: 2
uses: ./.github/workflows/_e2e_nightly_multi_node.yaml
with:
soc_version: a2
@@ -134,8 +134,6 @@ jobs:
run: |
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/arm64-linux/devlib
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \
--ignore tests/ut/torchair/models/test_torchair_deepseek_mtp.py \
--ignore tests/ut/torchair/models/test_torchair_deepseek_v2.py \
--ignore tests/ut/model_loader/netloader/test_netloader_elastic.py \
--ignore tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
--ignore tests/ut/kv_connector/test_remote_decode_lifecycle.py \
@@ -12,7 +12,7 @@ repos:
- id: codespell
args: [
--toml, pyproject.toml,
'--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
'--skip', 'csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
'-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND,ND'
]
additional_dependencies:
@@ -251,7 +251,6 @@ This will reproduce the E2E test. See [vllm_ascend_test.yaml](https://github.com
- Offline test example: [`tests/e2e/singlecard/test_offline_inference.py`](https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/test_offline_inference.py)
- Online test examples: [`tests/e2e/singlecard/test_prompt_embedding.py`](https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/test_prompt_embedding.py)
- Correctness test example: [`tests/e2e/singlecard/test_aclgraph.py`](https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/test_aclgraph.py)
- Reduced Layer model test example: [test_torchair_graph_mode.py - DeepSeek-V3-Pruning](https://github.com/vllm-project/vllm-ascend/blob/20767a043cccb3764214930d4695e53941de87ec/tests/e2e/multicard/test_torchair_graph_mode.py#L48)

The CI resource is limited, and you might need to reduce the number of layers of a model. Below is an example of how to generate a reduced layer model:
1. Fork the original model repo in modelscope. All the files in the repo except for weights are required.
@@ -6,7 +6,7 @@ MTP boosts inference performance by parallelizing the prediction of multiple tok
## How to Use MTP
To enable MTP for DeepSeek-V3 models, add the following parameter when starting the service:

--speculative_config ' {"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} '
--speculative_config ' {"method": "mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} '

- `num_speculative_tokens`: The number of speculative tokens which enable model to predict multiple tokens at once, if provided. It will default to the number in the draft model config if present, otherwise, it is required.
- `disable_padded_drafter_batch`: Disable input padding for speculative decoding. If set to True, speculative input batches can contain sequences of different lengths, which may only be supported by certain attention backends. This currently only affects the MTP method of speculation, default is False.
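For illustration (not part of this diff), the same setting can also be passed through the offline Python API; a minimal sketch, assuming an MTP-capable DeepSeek checkpoint at a placeholder path:

```python
from vllm import LLM, SamplingParams

# Sketch of the flag above in offline mode; the model path is a placeholder.
llm = LLM(
    model="/weights/DeepSeek-V3.1_w8a8mix_mtp",
    speculative_config={
        "method": "mtp",
        "num_speculative_tokens": 1,
        "disable_padded_drafter_batch": False,
    },
)
params = SamplingParams(temperature=0, max_tokens=32)
print(llm.generate(["Hello, my name is"], params))
```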
@@ -74,21 +74,18 @@ If the bonus token is accepted, the MTP model performs inference for (num_specul

### Method Validation

- Currently, the spec_decode scenario only supports methods such as ngram, eagle, eagle3, and deepseek_mtp. If an incorrect parameter is passed for the method, the code will raise an error to alert the user that an incorrect method was provided.
- Currently, the spec_decode scenario only supports methods such as ngram, eagle, eagle3, and mtp. If an incorrect parameter is passed for the method, the code will raise an error to alert the user that an incorrect method was provided.

```
def get_spec_decode_method(method,
vllm_config,
device,
runner,
is_torchair_graph=False):
runner):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
elif method in ["eagle", "eagle3"]:
return EagleProposer(vllm_config, device, runner)
elif method == 'deepseek_mtp':
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
elif method == 'mtp':
return MtpProposer(vllm_config, device, runner)
else:
raise ValueError("Unknown speculative decoding method: "
@@ -128,9 +128,8 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \
--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \
--compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \
--additional-config '{"torchair_graph_config":{"enabled":false}}'
```

### Multi-node Deployment
@@ -190,9 +189,8 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.94 \
--speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \
--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \
--compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \
--additional-config '{"torchair_graph_config":{"enabled":false}}'
```

**Node 1**
@@ -247,9 +245,8 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.94 \
--speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \
--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \
--compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \
--additional-config '{"torchair_graph_config":{"enabled":false}}'
```

### Prefill-Decode Disaggregation
@@ -421,7 +418,7 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \
--gpu-memory-utilization 0.9 \
--quantization ascend \
--no-enable-prefix-caching \
--speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \
--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \
--additional-config '{"recompute_scheduler_enable":true,"enable_shared_expert_dp": true}' \
--kv-transfer-config \
'{"kv_connector": "MooncakeConnector",
@@ -173,8 +173,7 @@ vllm serve vllm-ascend/DeepSeek-V3.2-Exp-W8A8 \
--enable-expert-parallel \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
--gpu-memory-utilization 0.92
```

### Multi-node Deployment
@@ -225,8 +224,7 @@ vllm serve /root/.cache/Modelers_Park/DeepSeek-V3.2-Exp \
--max-num-batched-tokens 17450 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
--gpu-memory-utilization 0.9
```

**Node 1**
@@ -269,8 +267,7 @@ vllm serve /root/.cache/Modelers_Park/DeepSeek-V3.2-Exp \
--enable-expert-parallel \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
--gpu-memory-utilization 0.92
```

::::
@@ -316,8 +313,7 @@ vllm serve vllm-ascend/DeepSeek-V3.2-Exp-W8A8 \
--trust-remote-code \
--quantization ascend \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
--gpu-memory-utilization 0.9
```

**Node 1**
@@ -362,8 +358,7 @@ vllm serve vllm-ascend/DeepSeek-V3.2-Exp-W8A8 \
--trust-remote-code \
--quantization ascend \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
--gpu-memory-utilization 0.92
```

::::
@@ -136,8 +136,7 @@ vllm serve vllm-ascend/DeepSeek-V3.1-W8A8 \
--max-num-batched-tokens 8192 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--additional-config '{"torchair_graph_config":{"enabled":true}}'
--gpu-memory-utilization 0.9
```

**Node 1**
@@ -181,8 +180,7 @@ vllm serve vllm-ascend/DeepSeek-V3.1-W8A8 \
--enable-expert-parallel \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--additional-config '{"torchair_graph_config":{"enabled":true}}'
--gpu-memory-utilization 0.92
```

The deployment view looks like:
@@ -92,8 +92,7 @@ vllm serve /home/cache/weights/Kimi-K2-Instruct-W8A8 \
--max-num-batched-tokens 8192 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--additional-config '{"torchair_graph_config":{"enabled":true}}'
--gpu-memory-utilization 0.9
```

**Node 1**
@@ -136,8 +135,7 @@ vllm serve /home/cache/weights/Kimi-K2-Instruct-W8A8 \
--enable-expert-parallel \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.92 \
--additional-config '{"torchair_graph_config":{"enabled":true}}'
--gpu-memory-utilization 0.92
```

The deployment view looks like:
@@ -153,12 +153,7 @@ if __name__ == "__main__":
enable_expert_parallel=True,
distributed_executor_backend="mp",
max_model_len=1024,
trust_remote_code=True,
additional_config={
'torchair_graph_config': {
'enabled': True,
}
})
trust_remote_code=True)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
@@ -27,7 +27,6 @@ The following table lists additional configuration options available in vLLM Asc
| Name | Type | Default | Description |
|-------------------------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------|
| `xlite_graph_config` | dict | `{}` | Configuration options for xlite graph mode |
| `torchair_graph_config` | dict | `{}` | Configuration options for torchair graph mode |
| `weight_prefetch_config` | dict | `{}` | Configuration options for weight prefetch |
| `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. |
| `expert_map_path` | str | `None` | When using expert load balancing for an MoE model, an expert map path needs to be passed in. |
@@ -52,21 +51,6 @@ The details of each configuration option are as follows:
| `enabled` | bool | `False` | Whether to enable xlite graph mode. Currently only Llama or Qwen dense series models are supported. |
| `full_mode` | bool | `False` | Whether to enable xlite for both the prefill and decode stages. By default, xlite is only enabled for the decode stage. |

**torchair_graph_config**

| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE are supported. |
| `mode` | str | `None` | When using reduce-overhead mode for torchair, it needs to be set. |
| `enable_multistream_mla`| bool | `False` | Whether to put vector operators of MLA to another stream. This option only takes effect on models using MLA (for example, DeepSeek). |
| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization. |
| `enable_frozen_parameter` | bool | `True` | Whether to fix the memory address of weights during inference to reduce the input address refresh time during graph execution. |
| `use_cached_graph` | bool | `False` | Whether to use cached graph. |
| `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache. |
| `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty. |
| `enable_kv_nz`| bool | `False` | Whether to enable KV Cache NZ layout. This option only takes effect on models using MLA (for example, DeepSeek). |
| `enable_super_kernel` | bool | `False` | Whether to enable super kernel to fuse operators in deepseek moe layers. This option only takes effects on moe models using dynamic w8a8 quantization.|

**weight_prefetch_config**

| Name | Type | Default | Description |
@@ -80,13 +64,6 @@ An example of additional configuration is as follows:

```
{
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": True,
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": False,
"enable_kv_nz": False
},
"weight_prefetch_config": {
"enabled": True,
"prefetch_ratio": {
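For illustration (not part of this diff), the remaining options can also be set programmatically; a minimal sketch with a placeholder model name and only the `weight_prefetch_config` section enabled:

```python
from vllm import LLM

# Sketch: pass the surviving additional_config options through the Python API.
llm = LLM(
    model="Qwen/Qwen2-7B-Instruct",  # placeholder model
    additional_config={
        "weight_prefetch_config": {
            "enabled": True,
        },
    },
)
```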
@@ -10,9 +10,8 @@ This guide provides instructions for using Ascend Graph Mode with vLLM Ascend. P

From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by default to keep the same behavior with vLLM. If you hit any issues, please feel free to open an issue on GitHub and fallback to the eager mode temporarily by setting `enforce_eager=True` when initializing the model.

There are three kinds for graph mode supported by vLLM Ascend:
There are two kinds for graph mode supported by vLLM Ascend:
- **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, Qwen and Deepseek series models are well tested.
- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported.
- **XliteGraph**: This is the euler xlite graph mode. In v0.11.0, only Llama and Qwen dense serise models are supported.

## Using ACLGraph
@@ -35,29 +34,6 @@ Online example:
vllm serve Qwen/Qwen2-7B-Instruct
```

## Using TorchAirGraph

If you want to run DeepSeek series models with the graph mode, you should use [TorchAirGraph](https://www.hiascend.com/document/detail/zh/Pytorch/700/modthirdparty/torchairuseguide/torchair_0002.html). In this case, additional configuration is required.

Offline example:

```python
import os
from vllm import LLM

# TorchAirGraph only works without chunked-prefill now
model = LLM(model="path/to/DeepSeek-R1-0528", additional_config={"torchair_graph_config": {"enabled": True}})
outputs = model.generate("Hello, how are you?")
```

Online example:

```shell
vllm serve path/to/DeepSeek-R1-0528 --additional-config='{"torchair_graph_config": {"enabled": true}}'
```

You can find more details about additional configuration [here](../configuration/additional_config.md).

## Using XliteGraph

If you want to run Llama or Qwen dense series models with xlite graph mode, please install xlite, and set xlite_graph_config.
@@ -87,7 +63,7 @@ You can find more details abort xlite [here](https://gitee.com/openeuler/GVirt/b

## Fallback to the Eager Mode

If `ACLGraph`, `TorchAirGraph` and `XliteGraph` all fail to run, you should fallback to the eager mode.
If `ACLGraph` and `XliteGraph` all fail to run, you should fallback to the eager mode.

Offline example:
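The offline example referenced above is cut off by this hunk; a minimal sketch of falling back to eager mode (illustrative only, placeholder model name):

```python
from vllm import LLM

# Sketch: disable all graph modes and run eagerly.
llm = LLM(model="Qwen/Qwen2-7B-Instruct", enforce_eager=True)
outputs = llm.generate("Hello, how are you?")
```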
@@ -104,22 +104,3 @@ First, make sure you specify `ascend` as the quantization method. Second, check
### 2. How to solve the error "Could not locate the configuration_deepseek.py"?

Please convert DeepSeek series models using `br_release_MindStudio_8.1.RC2_TR5_20260624` ModelSlim, where the missing configuration_deepseek.py error has been fixed.

### 3. What should be considered when converting DeepSeek series models with ModelSlim?

When the MLA portion of the weights used the `W8A8_DYNAMIC` quantization with the torchair graph mode enabled, modify the configuration file in the CANN package to prevent incorrect inference results.

The operation steps are as follows:

1. Search in the CANN package directory, for example:
find /usr/local/Ascend/ -name fusion_config.json

2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` and `"MultiAddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows:

```bash
{
"Switch":{
"GraphFusion":{
"AddRmsNormDynamicQuantFusionPass":"off",
"MultiAddRmsNormDynamicQuantFusionPass":"off",
```
@@ -27,5 +27,4 @@ vllm serve Qwen/Qwen1.5-MoE-A2.7B \
--max-num-batched-tokens 4096 \
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--enforce-eager \
--additional-config '{"torchair_graph_config":{"enabled":false, "use_cached_graph":false}}'
--enforce-eager
@@ -1,59 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import pytest
import vllm # noqa: F401

import vllm_ascend # noqa: F401
from tests.e2e.conftest import VllmRunner

# Pangu local model path
MODELS = [
"IntervitensInc/pangu-pro-moe-model",
]
# set additional config for ascend scheduler and torchair graph
ADDITIONAL_CONFIG = [{
"additional_config": {
"torchair_graph_config": {
"enabled": True
}
}
}]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enfore_eager", [True, False])
@pytest.mark.parametrize("additional_config", ADDITIONAL_CONFIG)
def test_pangu_model(model: str, dtype: str, max_tokens: int,
enfore_eager: bool, additional_config: dict) -> None:
if enfore_eager:
additional_config = {}
example_prompts = [
"Hello, my name is",
"The future of AI is",
]

with VllmRunner(model,
tensor_parallel_size=4,
dtype=dtype,
max_model_len=1024,
enforce_eager=True,
enable_expert_parallel=True,
additional_config=additional_config,
distributed_executor_backend="mp") as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
@@ -78,9 +78,6 @@ def test_models_distributed_DeepSeek_multistream_moe():
tensor_parallel_size=2,
distributed_executor_backend="mp",
additional_config={
"torchair_graph_config": {
"enabled": True,
},
"enable_multistream_moe": True,
"refresh": True,
},
@@ -144,17 +141,12 @@ def test_models_distributed_DeepSeek_W4A8DYNAMIC(model):
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download(model),
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
enforce_eager=True,
enable_expert_parallel=True,
additional_config={"torchair_graph_config": {
"enabled": False,
}},
) as vllm_model:
with VllmRunner(snapshot_download(model),
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
enforce_eager=True,
enable_expert_parallel=True) as vllm_model:
vllm_model.generate_greedy(prompts, max_tokens)
@@ -1,290 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/multicard/test_torchair_graph_mode.py`.
"""
import os
from typing import Dict

import pytest

from tests.e2e.conftest import VllmRunner

os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"


def _deepseek_torchair_test_fixture(
additional_config: Dict,
*,
tensor_parallel_size=2,
use_v1_schduler=False,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
kwargs = {}
if not use_v1_schduler:
kwargs = {
"refresh": True,
}
additional_config.update(**kwargs)

with VllmRunner(
"vllm-ascend/DeepSeek-V3-Pruning",
dtype="half",
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="mp",
additional_config=additional_config,
) as vllm_model:
# use greedy sampler to make sure the generated results are fix
vllm_output = vllm_model.generate_greedy(example_prompts, 5)

# NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
# DeepSeek-V3 with 2 hidden layers, thus the golden results seems
# inaccurate. This will only change if accuracy improves with the
# official weights of DeepSeek-V3.
golden_results = [
'Hello, my name is下载早点向前很有่อง',
'The president of the United States isSender)## physiological Albany',
'The capital of France is Rocky转角 hospitalizedinterval sparked',
'The future of AI is её asegο BIOS一扫',
]

assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
assert golden_results[i] == vllm_output[i][1]
print(f"Generated text: {vllm_output[i][1]!r}")


def test_e2e_deepseekv3_with_torchair():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_torchair_test_fixture(additional_config)


def test_e2e_deepseekv3_with_torchair_ms_mla():
additional_config = {
"torchair_graph_config": {
"enabled": True,
"enable_multistream_mla": True,
},
}
_deepseek_torchair_test_fixture(additional_config)


@pytest.mark.skip("accuracy test failed. Fix me")
def test_e2e_deepseekv3_with_torchair_v1scheduler():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_torchair_test_fixture(additional_config, use_v1_schduler=True)


def _pangu_torchair_test_fixture(
additional_config: Dict,
*,
tensor_parallel_size=2,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# torchair is only work without chunked-prefill now
kwargs = {
"refresh": True,
}
additional_config.update(**kwargs)

with VllmRunner(
"vllm-ascend/pangu-pro-moe-pruing",
dtype="half",
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="mp",
additional_config=additional_config,
enable_expert_parallel=True,
) as vllm_model:
# use greedy sampler to make sure the generated results are fix
vllm_output = vllm_model.generate_greedy(example_prompts, 5)

# NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
# with 2 hidden layers, thus the golden results seems inaccurate.
# This will only change if accuracy changes with the official weights
# of PanguProMoE.
golden_results = [
'Hello, my name is Remempondeprecatedmiot忱',
'The president of the United States is Remem下的一个 rever ceremoni Segnali',
'The capital of France is Rememvoud administrativ Remem投',
'The future of AI isotope Segnali Zoeken精细化 supus',
]

assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
assert golden_results[i] == vllm_output[i][1]
print(f"Generated text: {vllm_output[i][1]!r}")


@pytest.mark.skip("skipping test_e2e_pangu_with_torchair")
def test_e2e_pangu_with_torchair():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_pangu_torchair_test_fixture(additional_config)


def _qwen_torchair_test_fixture(
model,
tp,
enable_expert_parallel,
):
# The current access control does not support 16 cards,
# so the MC2 operator in Qwen's graph mode cannot run.
# Once 16-card support is available,
# this e2e can be switched to graph mode.
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"refresh": True,
}

with VllmRunner(
model,
dtype="half",
tensor_parallel_size=tp,
distributed_executor_backend="mp",
enforce_eager=True,
additional_config=additional_config,
enable_expert_parallel=enable_expert_parallel,
) as vllm_model:
# use greedy sampler to make sure the generated results are fix
vllm_output = vllm_model.generate_greedy(example_prompts, 5)

# NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
# with 2 hidden layers, thus the golden results seems inaccurate.
# This will only change if accuracy changes with the official weights
# of PanguProMoE.
golden_results = [
'Hello, my name is Remempondeprecatedmiot忱',
'The president of the United States is Remem下的一个 rever ceremoni Segnali',
'The capital of France is Rememvoud administrativ Remem投',
'The future of AI isotope Segnali Zoeken精细化 supus',
]

assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
print(f"Generated text: {vllm_output[i][1]!r}")


def test_e2e_qwen2_with_torchair():
_qwen_torchair_test_fixture("Qwen/Qwen2.5-0.5B-Instruct", 2, False)


def test_e2e_qwen3_moe_with_torchair():
_qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True)


# test deepseek-v2-lite
def _deepseek_v2_lite_torchair_test_fixure(
additional_config: Dict,
*,
tensor_parallel_size=2,
use_v1_schduler=False,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

kwargs = {}
if not use_v1_schduler:
kwargs = {
"refresh": True,
}
additional_config.update(**kwargs)

with VllmRunner(
"deepseek-ai/DeepSeek-V2-Lite",
dtype="half",
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="mp",
additional_config=additional_config,
) as vllm_model:
vllm_output = vllm_model.generate_greedy(example_prompts, 5)

# NOTE: deepseek-ai/DeepSeek-V2-Lite is a random weight of
# DeepSeek-V2-Lite with 2 hidden layers, thus the golden results seems
# inaccurate. This will only change if accuracy improves with the
# official weights of DeepSeek-V2-Lite.

for i in range(len(vllm_output)):
generated_text = vllm_output[i][1]
assert len(
generated_text.strip()) > 0, f"The {i}-th output is null, failed"


def test_e2e_deepseekv2lite_with_torchair():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_v2_lite_torchair_test_fixure(additional_config)


def test_e2e_deepseekv2lite_with_torchair_v1scheduler():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_deepseek_v2_lite_torchair_test_fixure(additional_config,
use_v1_schduler=True)


# kv_cache enable e2e test
def test_e2e_deepseekv2lite_with_nz():
additional_config = {
"torchair_graph_config": {
"enabled": True,
"enable_kv_nz": True,
},
}
_deepseek_v2_lite_torchair_test_fixure(additional_config)
@@ -73,7 +73,6 @@ async def test_models(model: str, mode: str) -> None:
"VLLM_RPC_TIMEOUT": "3600000",
"VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS": "3600000"
}
additional_config: dict[str, Any] = {}
speculative_config = {"num_speculative_tokens": 2, "method": "mtp"}
compilation_config = {
"cudagraph_capture_sizes": [56],
@@ -104,7 +103,6 @@ async def test_models(model: str, mode: str) -> None:
["--speculative-config",
json.dumps(speculative_config)])
server_args.extend(["--gpu-memory-utilization", "0.92"])
additional_config["torchair_graph_config"] = {"enabled": True}
aisbench_cases = aisbench_gsm8k
if mode == "mtp3":
env_dict["HCCL_OP_EXPANSION_MODE"] = "AIV"
@@ -117,9 +115,7 @@ async def test_models(model: str, mode: str) -> None:
server_args.extend(
["--compilation-config",
json.dumps(compilation_config)])
additional_config["torchair_graph_config"] = {"enabled": False}
aisbench_cases = aisbench_aime
server_args.extend(["--additional-config", json.dumps(additional_config)])
request_keyword_args: dict[str, Any] = {
**api_keyword_args,
}
@@ -74,13 +74,6 @@ async def test_models(model: str) -> None:
"PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True",
}
additional_config = {
"torchair_graph_config": {
"enabled": True,
"enable_multistream_moe": False,
"enable_multistream_mla": True,
"graph_batch_size": [16],
"use_cached_graph": True
},
"chunked_prefill_for_mla": True,
"enable_weight_nz_layout": True
}
@@ -29,7 +29,6 @@ MODELS = [
]

MODES = [
"torchair",
"single",
"aclgraph",
"aclgraph_mlapo",
@@ -78,13 +77,6 @@ async def test_models(model: str, mode: str) -> None:
}
speculative_config = {"num_speculative_tokens": 1, "method": "mtp"}
additional_config = {
"torchair_graph_config": {
"enabled": True,
"enable_multistream_moe": False,
"enable_multistream_mla": True,
"graph_batch_sizes": [16],
"use_cached_graph": True
},
"chunked_prefill_for_mla": True,
"enable_weight_nz_layout": True
}
@@ -99,12 +91,8 @@ async def test_models(model: str, mode: str) -> None:
]
if mode == "single":
server_args.append("--enforce-eager")
additional_config["torchair_graph_config"] = {"enabled": False}
if mode == "aclgraph":
additional_config["torchair_graph_config"] = {"enabled": False}
if mode == "aclgraph_mlapo":
env_dict["VLLM_ASCEND_ENABLE_MLAPO"] = "1"
additional_config["torchair_graph_config"] = {"enabled": False}
server_args.extend(["--additional-config", json.dumps(additional_config)])
request_keyword_args: dict[str, Any] = {
**api_keyword_args,
@@ -68,9 +68,6 @@ async def test_models(model: str) -> None:
"cudagraph_mode": "FULL_DECODE_ONLY"
}
additional_config: dict[str, Any] = {
"torchair_graph_config": {
"enabled": True
},
"enable_shared_expert_dp": False,
"multistream_overlap_shared_expert": False,
"dynamic_eplb": True,
@@ -72,27 +72,13 @@ async def test_models(model: str, tp_size: int, dp_size: int,
port = get_open_port()
env_dict = {"HCCL_BUFFSIZE": "1024", "VLLM_ASCEND_ENABLE_MLAPO": "0"}
server_args = [
"--no-enable-prefix-caching",
"--enable-expert-parallel",
"--no-enable-prefix-caching", "--enable-expert-parallel",
"--tensor-parallel-size",
str(tp_size),
"--data-parallel-size",
str(dp_size),
"--port",
str(port),
"--max-model-len",
"16384",
"--max-num-batched-tokens",
"16384",
"--block-size",
"16",
"--trust-remote-code",
"--quantization",
"ascend",
"--gpu-memory-utilization",
"0.9",
"--additional-config",
'{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}',
str(tp_size), "--data-parallel-size",
str(dp_size), "--port",
str(port), "--max-model-len", "16384", "--max-num-batched-tokens",
"16384", "--block-size", "16", "--trust-remote-code", "--quantization",
"ascend", "--gpu-memory-utilization", "0.9"
]
if full_graph:
server_args += [
@@ -1,64 +0,0 @@
test_name: "test DeepSeek-R1-W8A8 torchair on A2"
model: "vllm-ascend/DeepSeek-R1-0528-W8A8"
num_nodes: 2
npu_per_node: 8
env_common:
VLLM_USE_MODELSCOPE: true
HCCL_BUFFSIZE: 1024
SERVER_PORT: 8080
OMP_PROC_BIND: false
OMP_NUM_THREADS: 10


deployment:
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
--host 0.0.0.0
--port $SERVER_PORT
--data-parallel-size 4
--data-parallel-size-local 2
--data-parallel-address $LOCAL_IP
--data-parallel-rpc-port 13399
--no-enable-prefix-caching
--max-num-seqs 16
--tensor-parallel-size 4
--max-model-len 36864
--max-num-batched-tokens 6000
--enable-expert-parallel
--trust-remote-code
--quantization ascend
--gpu-memory-utilization 0.9
--speculative-config '{"num_speculative_tokens": 1, "method":"mtp"}'
--additional-config '{"torchair_graph_config":{"enabled":true,"enable_multistream_moe":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}'

-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
--headless
--data-parallel-size 4
--data-parallel-rpc-port 13399
--data-parallel-size-local 2
--data-parallel-start-rank 2
--data-parallel-address $MASTER_IP
--no-enable-prefix-caching
--max-num-seqs 16
--tensor-parallel-size 4
--max-model-len 36864
--max-num-batched-tokens 6000
--enable-expert-parallel
--trust-remote-code
--quantization ascend
--gpu-memory-utilization 0.9
--speculative-config '{"num_speculative_tokens": 1, "method":"mtp"}'
--additional-config '{"torchair_graph_config":{"enabled":true,"enable_multistream_moe":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}'
benchmarks:
acc:
case_type: accuracy
dataset_path: vllm-ascend/gsm8k
request_conf: vllm_api_general_chat
dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
max_out_len: 32768
batch_size: 512
baseline: 95
threshold: 5
@@ -58,7 +58,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
'{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'

-
server_cmd: >
@@ -96,7 +96,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
'{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
@@ -135,7 +135,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
'{"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
@@ -173,7 +173,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
'{"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}'
benchmarks:
perf:
case_type: performance

@@ -57,7 +57,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}'
'{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}'

-
server_cmd: >
@@ -95,7 +95,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}'
'{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}'
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
@@ -134,7 +134,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true}'
'{"multistream_overlap_shared_expert":true}'
-
server_cmd: >
vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8
@@ -172,7 +172,7 @@ deployment:
}
}'
--additional-config
'{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true}'
'{"multistream_overlap_shared_expert":true}'
benchmarks:
perf:
case_type: performance
@@ -82,7 +82,6 @@ deployment:
--trust-remote-code
--no-enable-prefix-caching
--gpu-memory-utilization 0.9
--additional-config '{"torchair_graph_config":{"enabled":true}}'
--kv-transfer-config
'{"kv_connector": "MooncakeConnector",
"kv_role": "kv_consumer",

@@ -29,7 +29,6 @@ deployment:
--trust-remote-code
--no-enable-prefix-caching
--gpu-memory-utilization 0.9
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'

-
server_cmd: >
@@ -49,5 +48,4 @@ deployment:
--trust-remote-code
--no-enable-prefix-caching
--gpu-memory-utilization 0.92
--additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}'
benchmarks:
@@ -1,106 +0,0 @@
from __future__ import annotations

import pytest
from vllm import SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode

from tests.e2e.conftest import VllmRunner


@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)


@pytest.fixture
def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"


def mtp_torchair_correctness(
sampling_config: SamplingParams,
model_name: str,
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using mtp speculative decoding.
'''
with VllmRunner(model_name,
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
enforce_eager=False,
additional_config={
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4],
},
"multistream_overlap_shared_expert": "True"
}) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config)

graph_mode_str = "PIECEWISE"
if graph_mode == CUDAGraphMode.FULL:
graph_mode_str = "FULL"

with VllmRunner(model_name,
tensor_parallel_size=1,
max_num_seqs=256,
gpu_memory_utilization=0.7,
distributed_executor_backend="mp",
enable_expert_parallel=True,
speculative_config={
"method": "mtp",
"num_speculative_tokens": 1,
},
enforce_eager=False,
max_model_len=2000,
compilation_config=CompilationConfig(
cudagraph_mode=graph_mode_str),
additional_config={
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4],
},
"multistream_overlap_shared_expert": "True"
}) as spec_llm:
spec_outputs = spec_llm.generate(example_prompts, sampling_config)

matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")

# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))


def test_mtp_torchair_correctness_piecewise(
sampling_config: SamplingParams,
model_name: str,
):
mtp_torchair_correctness(sampling_config, model_name)


def test_mtp_torchair_correctness_full(
sampling_config: SamplingParams,
model_name: str,
):
mtp_torchair_correctness(sampling_config, model_name, CUDAGraphMode.FULL)
@@ -118,7 +118,6 @@ def mock_dist_env(mocker: MockerFixture):
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
expert_map_path=None
)), \
@@ -110,11 +110,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled,
mock_soc_version, mock__c):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False

# Setup mock for custom kernel path

mock__c.rotary_embedding.return_value = self.query, self.key
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
@@ -139,9 +134,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False

# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
@@ -165,9 +157,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_with_offsets(self):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False

# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
@@ -190,9 +179,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False

# Test neox_style override
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
@@ -219,9 +205,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_rotary_dim_less_than_head_size(
self, mock_npu_rotary, mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False

# test case when rotary_dim < head_size
org_rotary_dim = self.layer.rotary_dim
self.layer.rotary_dim = self.layer.head_size // 2
@@ -415,7 +398,6 @@ class TestAscendMRotaryEmbedding(unittest.TestCase):
mrope_section=self.mrope_section)

self.mock_config = MagicMock()
self.mock_config.torchair_graph_config.enabled = False

def _create_vllm_config(self):
vllm_config = VllmConfig()
@@ -33,7 +33,6 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
mock_get_ep_group.return_value = mock_ep_group
mock_ascend_config = Mock()

mock_ascend_config.torchair_graph_config = Mock(enabled=False)
mock_ascend_config.enable_chunked_prefill = False
mock_get_ascend_config.return_value = mock_ascend_config
mock_mc2_group = Mock(device_group=0)
@@ -13,15 +13,10 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_config import (_check_torchair_supported,
|
||||
check_ascend_config,
|
||||
clear_ascend_config, get_ascend_config,
|
||||
from vllm_ascend.ascend_config import (clear_ascend_config, get_ascend_config,
|
||||
init_ascend_config)
|
||||
|
||||
|
||||
@@ -45,17 +40,6 @@ class TestAscendConfig(TestBase):
|
||||
self.assertIsNone(ascend_config.expert_map_path)
|
||||
self.assertFalse(ascend_config.multistream_overlap_shared_expert)
|
||||
|
||||
torchair_graph_config = ascend_config.torchair_graph_config
|
||||
self.assertFalse(torchair_graph_config.enabled)
|
||||
self.assertEqual(torchair_graph_config.mode, '')
|
||||
self.assertFalse(torchair_graph_config.use_cached_graph)
|
||||
self.assertEqual(torchair_graph_config.graph_batch_sizes, [])
|
||||
self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
|
||||
self.assertFalse(torchair_graph_config.enable_multistream_mla)
|
||||
self.assertTrue(torchair_graph_config.enable_view_optimize)
|
||||
self.assertTrue(torchair_graph_config.enable_frozen_parameter)
|
||||
self.assertFalse(torchair_graph_config.enable_kv_nz)
|
||||
|
||||
ascend_compilation_config = ascend_config.ascend_compilation_config
|
||||
self.assertTrue(ascend_compilation_config.enable_quantization_fusion)
|
||||
|
||||
@@ -63,16 +47,6 @@ class TestAscendConfig(TestBase):
|
||||
def test_init_ascend_config_with_additional_config(self):
|
||||
test_vllm_config = VllmConfig()
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
"use_cached_graph": True,
|
||||
"graph_batch_sizes": [1, 2, 4],
|
||||
"graph_batch_sizes_init": False,
|
||||
"enable_multistream_mla": True,
|
||||
"enable_view_optimize": True,
|
||||
"enable_frozen_parameter": True,
|
||||
"enable_kv_nz": True
|
||||
},
|
||||
"ascend_compilation_config": {
|
||||
"enable_quantization_fusion": False,
|
||||
},
|
||||
@@ -84,65 +58,9 @@ class TestAscendConfig(TestBase):
|
||||
self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
|
||||
self.assertTrue(ascend_config.multistream_overlap_shared_expert)
|
||||
|
||||
torchair_graph_config = ascend_config.torchair_graph_config
|
||||
self.assertTrue(torchair_graph_config.enabled)
|
||||
self.assertTrue(torchair_graph_config.use_cached_graph)
|
||||
self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4])
|
||||
self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
|
||||
self.assertTrue(torchair_graph_config.enable_multistream_mla)
|
||||
self.assertTrue(torchair_graph_config.enable_view_optimize)
|
||||
self.assertTrue(torchair_graph_config.enable_frozen_parameter)
|
||||
self.assertTrue(torchair_graph_config.enable_kv_nz)
|
||||
ascend_compilation_config = ascend_config.ascend_compilation_config
|
||||
self.assertFalse(ascend_compilation_config.enable_quantization_fusion)
|
||||
|
||||
@_clean_up_ascend_config
|
||||
def test_init_ascend_config_with_refresh(self):
|
||||
test_vllm_config = VllmConfig()
|
||||
ascend_config = init_ascend_config(test_vllm_config)
|
||||
self.assertFalse(ascend_config.torchair_graph_config.enabled)
|
||||
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
ascend_config = init_ascend_config(test_vllm_config)
|
||||
self.assertFalse(ascend_config.torchair_graph_config.enabled)
|
||||
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"refresh": True,
|
||||
}
|
||||
ascend_config = init_ascend_config(test_vllm_config)
|
||||
self.assertTrue(ascend_config.torchair_graph_config.enabled)
|
||||
|
||||
@_clean_up_ascend_config
def test_init_ascend_config_with_wrong_input(self):
test_vllm_config = VllmConfig()
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"graph_batch_sizes": "fake_size",
},
"refresh": True,
}
with self.assertRaises(TypeError):
init_ascend_config(test_vllm_config)

test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": True,
},
"refresh": True,
}
with self.assertRaises(ValueError):
init_ascend_config(test_vllm_config)

@_clean_up_ascend_config
def test_get_ascend_config(self):
test_vllm_config = VllmConfig()
@@ -162,203 +80,3 @@ class TestAscendConfig(TestBase):
clear_ascend_config()
with self.assertRaises(RuntimeError):
get_ascend_config()

@_clean_up_ascend_config
def test_check_ascend_config_pass(self):
test_vllm_config = VllmConfig()
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)

test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)

test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)

@_clean_up_ascend_config
def test_check_ascend_config_wrong_case(self):
test_vllm_config = VllmConfig()

# torchair + eager mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
enforce_eager = True
check_ascend_config(test_vllm_config, enforce_eager)
# torchair + non deepseek model
with self.assertRaises(NotImplementedError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
fake_model_config = ModelConfig(model=model_path)
fake_model_config.hf_config = PretrainedConfig()
fake_model_config.hf_config.model_type = "llama"
test_vllm_config.model_config = fake_model_config
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)

def test_check_torchair_supported(self):
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
('qwen', True), ('llama', False)]
for model_type, expected_output in test_cases:
self.assertEqual(_check_torchair_supported(model_type),
expected_output)

@_clean_up_ascend_config
def test_ascend_config_load_error(self):
test_vllm_config = VllmConfig()
# graph_batch_sizes should be list.
with self.assertRaises(TypeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"graph_batch_sizes": "fake_size",
},
"refresh": True
}
init_ascend_config(test_vllm_config)

# use_cached_graph should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"use_cached_graph": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)

# use_cached_kv_cache_bytes should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"use_cached_kv_cache_bytes": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)

# graph_batch_sizes should not be set without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes": [1, 2, 4],
},
"refresh": True
}
init_ascend_config(test_vllm_config)

# use_cached_kv_cache_bytes is valid only when torchair graph mode and use_cached_graph are enabled
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"use_cached_kv_cache_bytes": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)

# graph_batch_sizes_init should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes_init": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)

# enable_multistream_mla should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"enable_multistream_mla": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)

# mode should not be configured without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"mode": 'max-autotune',
},
"refresh": True
}
init_ascend_config(test_vllm_config)

# enable_kv_nz should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"enable_kv_nz": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)

with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"lmhead_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=2)
init_ascend_config(test_vllm_config)

with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"oproj_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=2)
init_ascend_config(test_vllm_config)

with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"oproj_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=1)
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
test_vllm_config.model_config = ModelConfig(model=model_path,
enforce_eager=True)
init_ascend_config(test_vllm_config)

@@ -31,7 +31,6 @@ class TestNPUPlatform(TestBase):
@staticmethod
def mock_vllm_ascend_config():
mock_ascend_config = MagicMock()
mock_ascend_config.torchair_graph_config.enabled = False
mock_ascend_config.xlite_graph_config.enabled = False
mock_ascend_config.enable_shared_expert_dp = False
return mock_ascend_config
@@ -403,47 +402,6 @@ class TestNPUPlatform(TestBase):
CUDAGraphMode.NONE,
)

@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
@patch("vllm_ascend.utils.update_default_aclgraph_sizes")
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_torchair_enabled_compilation(
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
mock_update_default, mock_soc_version):
mock_update_default.return_value = MagicMock()
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
mock_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = mock_ascend_config
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
vllm_config.scheduler_config = MagicMock()

vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE

with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform

importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertTrue("Torchair compilation enabled" in cm.output[0])

self.assertEqual(
vllm_config.compilation_config.mode,
CompilationMode.NONE,
)
self.assertEqual(
vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)

@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@@ -503,16 +461,6 @@ class TestNPUPlatform(TestBase):
"vllm_ascend.worker.worker_v1.NPUWorker",
)

test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
test_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = test_ascend_config
vllm_config.parallel_config.worker_cls = "auto"
self.platform.check_and_update_config(vllm_config)
self.assertEqual(
vllm_config.parallel_config.worker_cls,
"vllm_ascend.torchair.torchair_worker.NPUTorchairWorker",
)

test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
test_ascend_config.xlite_graph_config.enabled = True
mock_init_ascend.return_value = test_ascend_config
@@ -550,14 +498,7 @@ class TestNPUPlatform(TestBase):
self.platform.check_and_update_config(vllm_config)
self.assertEqual(vllm_config.compilation_config.custom_ops, [])

@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_config.enable_shared_expert_dp = False

mock_get_ascend_config.return_value = mock_config

def test_get_attn_backend_cls_use_v1_and_mla(self):
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
@@ -570,56 +511,7 @@ class TestNPUPlatform(TestBase):
self.assertEqual(result,
"vllm_ascend.attention.mla_v1.AscendMLABackend")

@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_mla_and_torchair(
self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True

mock_get_ascend_config.return_value = mock_config

result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
#use_sfa=False,
use_mla=True,
)
self.assertEqual(
result,
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend")

@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_and_torchair(self,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True

mock_get_ascend_config.return_value = mock_config

result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
#use_sfa=False,
use_mla=False,
)
self.assertEqual(
result,
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
)

@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False

mock_get_ascend_config.return_value = mock_config

def test_get_attn_backend_cls_use_v1_only(self):
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,

@@ -1,61 +0,0 @@
from unittest.mock import Mock

import pytest
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.distributed.parallel_state import GroupCoordinator

from tests.ut.base import PytestBase
from vllm_ascend.torchair.models.qwen3_moe import CustomSparseMoeBlock


class TestCustomSparseMoeBlock(PytestBase):

@pytest.fixture
def setup_csmb(self, mocker: MockerFixture):
config = PretrainedConfig(num_experts=64,
hidden_size=2048,
num_experts_per_tok=2,
moe_intermediate_size=1408,
norm_topk_prob=True)
mocker.patch(
'vllm_ascend.torchair.models.qwen3_moe.get_tensor_model_parallel_world_size',
return_value=10)
mocker.patch(
'vllm.model_executor.layers.linear.ReplicatedLinear.__init__',
return_value=None)
mocker.patch(
'vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__init__',
return_value=None)

tp_group = Mock(spec=GroupCoordinator)
tp_group.rank_in_group = 0
tp_group.world_size = 1
tp_group.device_group = Mock()

dp_group = Mock(spec=GroupCoordinator)
dp_group.rank_in_group = 0
dp_group.world_size = 1

ep_group = Mock(spec=GroupCoordinator)
ep_group.rank_in_group = 0
ep_group.world_size = 1

mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_tp_group',
return_value=tp_group)
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_dp_group',
return_value=dp_group)
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_ep_group',
return_value=ep_group)
ascend_config = mocker.MagicMock()
ascend_config.max_num_batched_tokens = 2048
ascend_config.max_model_len = 1024
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=ascend_config)

custom_moe_block = CustomSparseMoeBlock(config, None, "")
return custom_moe_block

def test_init(self, mocker: MockerFixture, setup_csmb):
custom_moe_block = setup_csmb
assert isinstance(custom_moe_block, CustomSparseMoeBlock)
@@ -1,206 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_mtp import (
|
||||
TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor,
|
||||
TorchairDeepSeekMultiTokenPredictorLayer)
|
||||
|
||||
|
||||
class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mtp_layer(self, mocker: MockerFixture):
|
||||
config = PretrainedConfig(vocab_size=1000,
|
||||
hidden_size=768,
|
||||
rms_norm_eps=1e-5)
|
||||
mocker.patch(
|
||||
'vllm_ascend.torchair.models.torchair_deepseek_mtp.get_tensor_model_parallel_world_size',
|
||||
return_value=1)
|
||||
mocker.patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekShareHead.__init__",
|
||||
return_value=None)
|
||||
mocker_deepseek_v2_decode_layer = mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
ascend_config = mocker.MagicMock()
|
||||
ascend_config.max_num_batched_tokens = 2048
|
||||
ascend_config.max_model_len = 1024
|
||||
mocker.patch("vllm_ascend.utils.get_ascend_config",
|
||||
return_value=ascend_config)
|
||||
|
||||
mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
|
||||
mocker_deepseek_v2_decode_layer.assert_called_once()
|
||||
return mtp_layer
|
||||
|
||||
def test_init(self, mocker: MockerFixture, setup_mtp_layer):
|
||||
mtp_layer = setup_mtp_layer
|
||||
assert isinstance(mtp_layer, TorchairDeepSeekMultiTokenPredictorLayer)
|
||||
|
||||
def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
|
||||
mtp_layer = setup_mtp_layer
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
mocker.patch.object(mtp_layer,
|
||||
'eh_proj',
|
||||
return_value=torch.randn(2, 3, 768))
|
||||
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
|
||||
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
|
||||
torch.randn(2, 3, 768))
|
||||
mtp_layer.enorm.return_value = torch.randn(2, 3, 768)
|
||||
mtp_layer.hnorm.return_value = torch.randn(2, 3, 768)
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
|
||||
positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
|
||||
kv_cache = torch.randn(2, 3, 768)
|
||||
previous_hidden_states = torch.randn(2, 3, 768)
|
||||
inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
|
||||
output = mtp_layer(input_ids, positions, kv_cache, None,
|
||||
previous_hidden_states, inputs_embeds, 0)
|
||||
assert output.shape == (3, 768)
|
||||
|
||||
|
||||
class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
|
||||
|
||||
@pytest.fixture
|
||||
def setup_predictor(self, mocker: MockerFixture):
|
||||
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
|
||||
mock_model_config = mocker.MagicMock(spec=ModelConfig)
|
||||
mock_hf_config = mocker.MagicMock()
|
||||
mock_hf_config.num_hidden_layers = 12
|
||||
mock_hf_config.num_nextn_predict_layers = 3
|
||||
mock_hf_config.vocab_size = 30000
|
||||
mock_model_config.hf_config = mock_hf_config
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.cache_config = CacheConfig()
|
||||
mock_vllm_config.quant_config = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
ascend_config = mocker.MagicMock()
|
||||
ascend_config.max_num_batched_tokens = 2048
|
||||
ascend_config.max_model_len = 1024
|
||||
mocker.patch("vllm_ascend.utils.get_ascend_config",
|
||||
return_value=ascend_config)
|
||||
|
||||
predictor = TorchairDeepSeekMultiTokenPredictor(
|
||||
vllm_config=mock_vllm_config)
|
||||
return predictor
|
||||
|
||||
def test_init(self, mocker: MockerFixture, setup_predictor):
|
||||
predictor = setup_predictor
|
||||
assert predictor.num_mtp_layers == 3
|
||||
assert isinstance(predictor, TorchairDeepSeekMultiTokenPredictor)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'kv_caches, inputs_embeds',
|
||||
[(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
|
||||
def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
|
||||
inputs_embeds):
|
||||
predictor = setup_predictor
|
||||
mock_layer = mocker.MagicMock()
|
||||
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
|
||||
predictor.layers_list = [mock_layer]
|
||||
|
||||
# todo: need or not?
|
||||
# predictor.num_mtp_layers = 1
|
||||
input_ids = torch.tensor([[1, 2, 3]])
|
||||
positions = torch.tensor([[0, 1, 2]])
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
|
||||
return_value=torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
output = predictor.forward(input_ids, positions, kv_caches, None, None,
|
||||
inputs_embeds, 0)
|
||||
mock_layer.assert_called_once()
|
||||
assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
|
||||
hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
|
||||
predictor = setup_predictor
|
||||
|
||||
mock_layer = mocker.MagicMock()
|
||||
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
|
||||
predictor.layers_list = [mock_layer]
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
mocker.patch(
|
||||
"vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
|
||||
return_value=None)
|
||||
predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
|
||||
|
||||
result_logits = predictor.compute_logits(hidden_states=hidden_states)
|
||||
predictor.logits_processor.assert_called_once()
|
||||
assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
|
||||
class TestTorchairDeepSeekMTP(PytestBase):
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mtp(self, mocker: MockerFixture):
|
||||
vllm_config = mocker.MagicMock()
|
||||
vllm_config.model_config.hf_config.num_hidden_layers = 12
|
||||
vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
|
||||
vllm_config.cache_config = mocker.MagicMock()
|
||||
vllm_config.quant_config = mocker.MagicMock()
|
||||
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
mocker.patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
ascend_config = mocker.MagicMock()
|
||||
ascend_config.max_num_batched_tokens = 2048
|
||||
ascend_config.max_model_len = 1024
|
||||
mocker.patch("vllm_ascend.utils.get_ascend_config",
|
||||
return_value=ascend_config)
|
||||
|
||||
mtp = TorchairDeepSeekMTP(vllm_config=vllm_config)
|
||||
return mtp
|
||||
|
||||
def test_init(self, mocker: MockerFixture, setup_mtp):
|
||||
mtp = setup_mtp
|
||||
assert isinstance(mtp, TorchairDeepSeekMTP)
|
||||
|
||||
def test_forward(self, mocker: MockerFixture, setup_mtp):
|
||||
input_ids = torch.tensor([[1, 2, 3]])
|
||||
positions = torch.tensor([[0, 1, 2]])
|
||||
kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
|
||||
previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
|
||||
inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
|
||||
spec_step_idx = 0
|
||||
setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
|
||||
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
|
||||
previous_hidden_states, inputs_embeds,
|
||||
spec_step_idx)
|
||||
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
@@ -1,366 +0,0 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.transformers_utils.config import patch_rope_parameters
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import (
|
||||
TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM,
|
||||
TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention,
|
||||
TorchairDeepseekV2MLP, TorchairDeepseekV2MoE,
|
||||
TorchairDeepseekV2RowParallelLinear,
|
||||
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
|
||||
TorchairDeepseekV2SiluAndMul)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_config():
|
||||
config = PretrainedConfig(
|
||||
hidden_size=128,
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=2,
|
||||
intermediate_size=256,
|
||||
hidden_act="silu",
|
||||
rms_norm_eps=1e-6,
|
||||
rope_theta=10000.0,
|
||||
max_position_embeddings=2048,
|
||||
n_routed_experts=4,
|
||||
n_shared_experts=1,
|
||||
moe_intermediate_size=256,
|
||||
num_experts_per_tok=2,
|
||||
routed_scaling_factor=1.0,
|
||||
first_k_dense_replace=0,
|
||||
moe_layer_freq=1,
|
||||
kv_lora_rank=16,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=32,
|
||||
topk_method="noaux_tc",
|
||||
scoring_func="softmax",
|
||||
norm_topk_prob=True,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
vocab_size=10000,
|
||||
)
|
||||
patch_rope_parameters(config)
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_config(base_config):
|
||||
model_config = SimpleNamespace(
|
||||
hf_config=base_config,
|
||||
tensor_parallel_size=1,
|
||||
dtype=torch.float32,
|
||||
use_mla=False,
|
||||
quant_config=None,
|
||||
max_model_len=2048,
|
||||
)
|
||||
|
||||
cache_config = CacheConfig()
|
||||
vllm_config = Mock()
|
||||
vllm_config.model_config = model_config
|
||||
vllm_config.cache_config = cache_config
|
||||
vllm_config.quant_config = None
|
||||
return vllm_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_distributed():
|
||||
tp_group = Mock(spec=GroupCoordinator)
|
||||
tp_group.rank_in_group = 0
|
||||
tp_group.world_size = 1
|
||||
tp_group.device_group = Mock()
|
||||
|
||||
dp_group = Mock(spec=GroupCoordinator)
|
||||
dp_group.rank_in_group = 0
|
||||
dp_group.world_size = 1
|
||||
|
||||
ep_group = Mock(spec=GroupCoordinator)
|
||||
ep_group.rank_in_group = 0
|
||||
ep_group.world_size = 1
|
||||
|
||||
pp_group = Mock(spec=GroupCoordinator)
|
||||
pp_group.rank_in_group = 0
|
||||
pp_group.world_size = 1
|
||||
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 1
|
||||
dcp_group.device_group = MagicMock()
|
||||
|
||||
mlp_tp_group = Mock(spec=GroupCoordinator)
|
||||
mlp_tp_group.rank_in_group = 0
|
||||
mlp_tp_group.world_size = 1
|
||||
mlp_tp_group.all_gather = Mock(return_value=torch.randn(2, 4, 128))
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
|
||||
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
|
||||
|
||||
with patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tp_group", return_value=tp_group), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_ep_group", return_value=ep_group), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_dp_group", return_value=dp_group), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
|
||||
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
||||
patch('vllm.distributed.parallel_state.get_dcp_group', return_value=dcp_group), \
|
||||
patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)), \
|
||||
patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1),\
|
||||
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||
_PP=pp_group), \
|
||||
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_forward_context():
|
||||
forward_context = Mock(in_profile_run=False, with_prefill=False)
|
||||
with patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2.get_forward_context",
|
||||
return_value=forward_context):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_attention_init():
|
||||
try:
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
DeepseekV2Attention
|
||||
original_init = DeepseekV2Attention.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
kwargs.pop("decoder_layer", None)
|
||||
if 'vllm_config' not in kwargs:
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.model_config = Mock()
|
||||
mock_vllm_config.model_config.hf_config = Mock()
|
||||
mock_vllm_config.model_config.hf_config.hidden_size = 128
|
||||
mock_vllm_config.model_config.dtype = torch.float32
|
||||
mock_vllm_config.model_config.quant_config = None
|
||||
mock_vllm_config.cache_config = CacheConfig()
|
||||
kwargs['vllm_config'] = mock_vllm_config
|
||||
return original_init(self, *args, **kwargs)
|
||||
|
||||
DeepseekV2Attention.__init__ = patched_init
|
||||
yield
|
||||
DeepseekV2Attention.__init__ = original_init
|
||||
except ImportError:
|
||||
yield
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_silu_and_mul():
|
||||
torch.set_default_device("cpu")
|
||||
|
||||
silu = TorchairDeepseekV2SiluAndMul()
|
||||
assert silu.weight_scale is None
|
||||
|
||||
x = torch.randn(2, 4)
|
||||
output = silu.forward_oot(x)
|
||||
assert output.shape == (2, 2)
|
||||
|
||||
weight_scale = Mock(return_value=torch.tensor(0.1))
|
||||
silu = TorchairDeepseekV2SiluAndMul(weight_scale=weight_scale)
|
||||
quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32)
|
||||
dynamic_scale = torch.randn(2, 1)
|
||||
with patch("torch_npu.npu_dequant_swiglu_quant",
|
||||
return_value=torch.randn(2, 4)):
|
||||
output = silu.forward_oot((quant_x, dynamic_scale))
|
||||
assert output.shape == (2, 4)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_merged_replicated_linear(mock_distributed):
|
||||
linear = TorchairDeepseekV2MergedReplicatedLinear(input_size=128,
|
||||
output_sizes=[64, 64],
|
||||
bias=False,
|
||||
quant_config=None)
|
||||
assert linear.output_sizes == [64, 64]
|
||||
|
||||
param = Mock()
|
||||
param.data = torch.zeros(128, 128)
|
||||
param.output_dim = 1
|
||||
param.is_gguf_weight = False
|
||||
param.is_gguf_weight_type = False
|
||||
loaded_weight = torch.randn(128, 64)
|
||||
linear.weight_loader(param, loaded_weight, loaded_shard_id=0)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", [
|
||||
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
|
||||
TorchairDeepseekV2RowParallelLinear
|
||||
])
|
||||
def test_row_parallel_linear(cls, mock_distributed, mock_forward_context):
|
||||
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
|
||||
linear.quant_method = Mock()
|
||||
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
|
||||
|
||||
input_ = torch.randn(2, 4, 128)
|
||||
with patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2.split_tensor_along_last_dim",
|
||||
return_value=[torch.randn(2, 4, 64)]):
|
||||
linear.input_is_parallel = False
|
||||
output = linear(input_, is_prefill=True)
|
||||
assert output[0].shape == (2, 4, 64)
|
||||
|
||||
linear.input_is_parallel = True
|
||||
output = linear(input_, is_prefill=False)
|
||||
assert output[0].shape == (2, 4, 64)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_mlp(mock_distributed, base_config):
|
||||
mlp = TorchairDeepseekV2MLP(hidden_size=128,
|
||||
intermediate_size=256,
|
||||
hidden_act="silu",
|
||||
quant_config=None)
|
||||
assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul)
|
||||
with patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig"
|
||||
) as mock_quant_config:
|
||||
mock_quant_config.name = "w8a8dynamic"
|
||||
with pytest.raises(NotImplementedError):
|
||||
TorchairDeepseekV2MLP(hidden_size=128,
|
||||
intermediate_size=256,
|
||||
hidden_act="silu",
|
||||
quant_config=mock_quant_config,
|
||||
force_replicate=False)
|
||||
with pytest.raises(ValueError):
|
||||
TorchairDeepseekV2MLP(hidden_size=128,
|
||||
intermediate_size=256,
|
||||
hidden_act="relu",
|
||||
quant_config=None)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_moe(mock_distributed, base_config,
|
||||
mock_forward_context):
|
||||
base_config.n_shared_experts = 1
|
||||
moe = TorchairDeepseekV2MoE(config=base_config,
|
||||
quant_config=None,
|
||||
prefix="mlp")
|
||||
assert moe.top_k == 2
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
attn_metadata = Mock(num_prefills=1)
|
||||
with patch(
|
||||
"vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__",
|
||||
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
|
||||
output = moe(x, attn_metadata)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
base_config):
|
||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||
|
||||
attn = TorchairDeepseekV2MLAAttention(config=base_config,
|
||||
hidden_size=128,
|
||||
num_heads=8,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=32,
|
||||
q_lora_rank=16,
|
||||
kv_lora_rank=16,
|
||||
cache_config=CacheConfig(),
|
||||
quant_config=None,
|
||||
prefix="layers.0.self_attn")
|
||||
assert attn.debug_layer_idx == 0
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
positions = torch.arange(4).repeat(2, 1)
|
||||
with patch.object(attn.mla_attn,
|
||||
"__call__",
|
||||
return_value=torch.randn(2, 4, 128)):
|
||||
with pytest.raises(AssertionError):
|
||||
attn(positions, x)
|
||||
|
||||
attn = TorchairDeepseekV2MLAAttention(config=base_config,
|
||||
hidden_size=128,
|
||||
num_heads=8,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=32,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=16,
|
||||
prefix="layers.1.self_attn")
|
||||
assert hasattr(attn, "q_proj")
|
||||
|
||||
|
||||
@patch("torch_npu.npu_add_rms_norm")
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
|
||||
mock_rms_norm, mock_add_norm,
|
||||
mock_distributed, base_config,
|
||||
vllm_config, mock_forward_context,
|
||||
patch_attention_init):
|
||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
|
||||
torch.randn(2, 128))
|
||||
base_config.n_routed_experts = 4
|
||||
layer = TorchairDeepseekV2DecoderLayer(
|
||||
config=base_config,
|
||||
prefix="layers.0",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=CacheConfig(),
|
||||
quant_config=None)
|
||||
assert isinstance(layer.mlp, TorchairDeepseekV2MoE)
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
positions = torch.arange(4).repeat(2, 1)
|
||||
|
||||
with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \
|
||||
patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))):
|
||||
hidden_states, residual = layer(positions, x, None)
|
||||
assert hidden_states.shape == (2, 4, 128)
|
||||
|
||||
base_config.n_routed_experts = None
|
||||
layer = TorchairDeepseekV2DecoderLayer(
|
||||
config=base_config,
|
||||
prefix="layers.0",
|
||||
model_config=vllm_config.model_config,
|
||||
quant_config=None)
|
||||
assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config,
|
||||
patch_attention_init):
|
||||
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
|
||||
|
||||
input_ids = torch.randint(0, 10000, (2, 4))
|
||||
positions = torch.arange(4).repeat(2, 1)
|
||||
with patch.object(model.model,
|
||||
"forward",
|
||||
return_value=torch.randn(2, 4, 128)):
|
||||
output = model(input_ids, positions)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
|
||||
with patch(
|
||||
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
|
||||
):
|
||||
loaded = model.load_weights(weights)
|
||||
assert loaded is not None
|
||||
@@ -1,423 +0,0 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import List, TypedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
import vllm_ascend
|
||||
from vllm_ascend.ascend_forward_context import get_fused_moe_state
|
||||
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import (
|
||||
TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.utils import adapt_patch # noqa E402
|
||||
from vllm_ascend.utils import AscendDeviceType
|
||||
|
||||
adapt_patch(True)
|
||||
|
||||
|
||||
def mock_ep_and_mc2_group(mocker):
|
||||
mock_group = mocker.MagicMock()
|
||||
mock_group.rank_in_group = 0
|
||||
mock_group.rank = 0
|
||||
mock_group.world_size = 4
|
||||
mock_group.device_group = "mock_group_ep"
|
||||
mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
|
||||
return mock_group
|
||||
|
||||
|
||||
def mock_dp_and_tp_group(mocker):
|
||||
mock_group = mocker.MagicMock()
|
||||
mock_group.rank_in_group = 0
|
||||
mock_group.world_size = 2
|
||||
mock_group.device_group = "mock_group"
|
||||
mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
|
||||
return mock_group
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dist_env(mocker: MockerFixture):
|
||||
# init dist env patch
|
||||
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
|
||||
|
||||
with patch('torch.npu.is_available', return_value=True), \
|
||||
patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
|
||||
patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
|
||||
patch('vllm_ascend.torchair.ops.torchair_fused_moe.tensor_model_parallel_all_reduce',
|
||||
return_value=torch.randn(5, 32)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||
return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config',
|
||||
return_value=MagicMock(
|
||||
torchair_graph_config=MagicMock(enabled=False),
|
||||
enable_multistream_moe=False,
|
||||
enable_shared_expert_dp=False,
|
||||
expert_map_path=None,
|
||||
init_redundancy_expert=2,
|
||||
)), \
|
||||
patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map',
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context',
|
||||
return_value=MagicMock(
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=dp_metadata,
|
||||
)), \
|
||||
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config',
|
||||
return_value=MagicMock(
|
||||
parallel_config=MagicMock(tensor_parallel_size=2),
|
||||
scheduler_config=MagicMock(max_num_seqs=4),
|
||||
model_config=MagicMock(max_model_len=2048)
|
||||
)):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_moe_env(mocker: MockerFixture):
|
||||
# init moe env patch
|
||||
|
||||
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
|
||||
torch.randn(8, 2),
|
||||
torch.randint(0, 8, (8, 2)),
|
||||
None
|
||||
)), \
|
||||
patch('torch_npu.npu_moe_init_routing', return_value=(
|
||||
torch.randn(8, 2),
|
||||
torch.randint(0, 8, (8, 2)),
|
||||
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
|
||||
)), \
|
||||
patch("torch_npu.npu_moe_compute_expert_tokens", return_value=(
|
||||
torch.randn(8, 2)
|
||||
)), \
|
||||
patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
|
||||
torch.randn(16, 2)
|
||||
)), \
|
||||
patch("torch_npu.npu_moe_distribute_combine", return_value=(
|
||||
torch.randn(16, 2)
|
||||
)), \
|
||||
patch("torch_npu.npu_grouped_matmul", return_value=(
|
||||
[torch.randn(16, 2)]
|
||||
)), \
|
||||
patch("torch_npu.npu_swiglu", return_value=(
|
||||
torch.randn(16, 2)
|
||||
)), \
|
||||
patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
|
||||
torch.randn(8, 2),
|
||||
torch.randint(0, 8, (8, 2)),
|
||||
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
|
||||
)), \
|
||||
patch("torch_npu.npu_moe_finalize_routing", return_value=(
|
||||
torch.randn(16, 2)
|
||||
)):
|
||||
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
|
||||
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
|
||||
torch.randn(16, 2))), \
|
||||
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
|
||||
torch.randn(16, 2))):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_moe_config():
|
||||
"""default moe config"""
|
||||
return {
|
||||
'num_experts': 8,
|
||||
'top_k': 2,
|
||||
'hidden_size': 512,
|
||||
'intermediate_size': 1024
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def moe_method(mock_dist_env):
|
||||
moe = MagicMock()
|
||||
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
|
||||
moe.moe_parallel_config.use_ep = False
|
||||
moe.moe_parallel_config.dp_size = 1
|
||||
return TorchairAscendUnquantizedFusedMoEMethod(moe)
|
||||
|
||||
|
||||
class Device(TypedDict):
|
||||
device_id: int
|
||||
device_expert: List[int]
|
||||
|
||||
|
||||
class Layer(TypedDict):
|
||||
layer_id: int
|
||||
device_count: int
|
||||
device_list: List[Device]
|
||||
|
||||
|
||||
class MockData(TypedDict):
|
||||
moe_layer_count: int
|
||||
layer_list: List[Layer]
|
||||
|
||||
|
||||
class MockQuantMethod(nn.Module):
|
||||
|
||||
def __init__(self, shared_experts, num_tokens):
|
||||
super().__init__()
|
||||
if shared_experts:
|
||||
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
|
||||
torch.randn(num_tokens, 10)))
|
||||
else:
|
||||
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
|
||||
|
||||
|
||||
class MockFusedMoEMethod(FusedMoEMethodBase):
|
||||
moe = MagicMock()
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self.moe)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
pass
|
||||
|
||||
def apply(self, hidden_states: torch.Tensor,
|
||||
expert_weights: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
class TestTorchairAscendFusedMoe:
|
||||
|
||||
@pytest.fixture
|
||||
def test_init_no_quant(self, mock_dist_env, default_moe_config):
|
||||
layer = TorchairAscendFusedMoE(**default_moe_config)
|
||||
|
||||
layer.w13_weight = nn.Parameter(
|
||||
torch.randn(default_moe_config['num_experts'],
|
||||
default_moe_config['intermediate_size'] * 2,
|
||||
default_moe_config['hidden_size']))
|
||||
layer.w2_weight = nn.Parameter(
|
||||
torch.randn(default_moe_config['num_experts'],
|
||||
default_moe_config['hidden_size'],
|
||||
default_moe_config['intermediate_size']))
|
||||
|
||||
assert layer.num_experts == default_moe_config['num_experts']
|
||||
assert layer.top_k == default_moe_config['top_k']
|
||||
assert hasattr(layer, 'w13_weight')
|
||||
assert hasattr(layer, 'w2_weight')
|
||||
|
||||
# check group_topk
|
||||
with pytest.raises(AssertionError):
|
||||
error_config = default_moe_config.copy()
|
||||
error_config['use_grouped_topk'] = True
|
||||
layer = TorchairAscendFusedMoE(**error_config)
|
||||
|
||||
# check scoring_func
|
||||
with pytest.raises(ValueError):
|
||||
error_config = default_moe_config.copy()
|
||||
error_config['scoring_func'] = "random"
|
||||
layer = TorchairAscendFusedMoE(**error_config)
|
||||
|
||||
@pytest.fixture
|
||||
def test_init_with_quant(self, mock_dist_env, default_moe_config):
|
||||
mock_quant_config = MagicMock()
|
||||
mock_quant_method = MockFusedMoEMethod()
|
||||
mock_quant_config.get_quant_method.return_value = mock_quant_method
|
||||
mock_quant_config.is_layer_skipped_ascend.return_value = False
|
||||
with patch("vllm_ascend.quantization.quant_config.get_quant_method"):
|
||||
moe = TorchairAscendFusedMoE(**default_moe_config,
|
||||
quant_config=mock_quant_config)
|
||||
assert moe.quant_method is not None
|
||||
assert isinstance(moe.quant_method, AscendFusedMoEMethod)
|
||||
|
||||
@pytest.fixture
|
||||
def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config):
|
||||
mock_quant_config = MagicMock()
|
||||
mock_quant_method = MockFusedMoEMethod()
|
||||
mock_quant_config.get_quant_method.return_value = mock_quant_method
|
||||
mock_quant_config.is_layer_skipped_ascend.return_value = True
|
||||
|
||||
moe = TorchairAscendFusedMoE(**default_moe_config,
|
||||
quant_config=mock_quant_config)
|
||||
|
||||
assert moe.quant_method is not None
|
||||
assert isinstance(moe.quant_method,
|
||||
TorchairAscendUnquantizedFusedMoEMethod)
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.mark.parametrize(
|
||||
"others_param",
|
||||
[[None,
|
||||
MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
|
||||
[2, None, False, 5, None], [None, None, True, 5, None],
|
||||
[None, None, False, 1, None], [None, None, True, 5, 1],
|
||||
[None, None, False, 5, 1]])
|
||||
def test_forward(self, mock_dist_env, default_moe_config, others_param):
|
||||
"""
|
||||
1 test has shared_experts
|
||||
2 test has top_k
|
||||
3 test is_prefill is true
|
||||
4 test single num_tokens(decode)
|
||||
5 test ep_size is 1 and is_prefill is true
|
||||
6 test ep_size is 1 and is_prefill is False
|
||||
"""
|
||||
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
|
||||
inputs = torch.randn(num_tokens, 32)
|
||||
router_logits = torch.randn(num_tokens, 8)
|
||||
moe = TorchairAscendFusedMoE(**default_moe_config)
|
||||
|
||||
if ep_size == 1:
|
||||
moe.moe_parallel_config.ep_size = 1
|
||||
|
||||
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
|
||||
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
|
||||
dtype=torch.bool),
|
||||
padded_num_tokens=num_tokens)
|
||||
with patch(
|
||||
"vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
output = moe.forward(inputs,
|
||||
router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=top_k,
|
||||
shared_experts=shared_experts)
|
||||
|
||||
moe.quant_method.apply.assert_called_once()
|
||||
|
||||
if shared_experts:
|
||||
assert output[0].shape == (num_tokens, 32)
|
||||
assert output[1].shape == (num_tokens, 10)
|
||||
else:
|
||||
assert output.shape == (num_tokens, 32)
|
||||
|
||||
@pytest.fixture
|
||||
def test_forward_ms_fused_moe_comp(self, mock_dist_env,
|
||||
default_moe_config):
|
||||
inputs = torch.randn(5, 32)
|
||||
router_logits = torch.randn(5, 8)
|
||||
moe = TorchairAscendFusedMoE(**default_moe_config)
|
||||
|
||||
moe.quant_method = MockQuantMethod(None, 5)
|
||||
output = moe._forward_ms_fused_moe_comp(inputs,
|
||||
router_logits,
|
||||
is_prefill=False,
|
||||
real_top_k=1)
|
||||
|
||||
moe.quant_method.apply.assert_called_once()
|
||||
|
||||
assert output.shape == (5, 32)
|
||||
|
||||
|
||||
class TestTorchairAscendUnquantizedFusedMoEMethod:
|
||||
|
||||
def test_process_weights_after_loading(self, moe_method, mock_dist_env):
|
||||
layer = MagicMock()
|
||||
layer.w13_weight.data = torch.randn(16, 32)
|
||||
layer.w2_weight.data = torch.randn(16, 32)
|
||||
|
||||
moe_method.process_weights_after_loading(layer)
|
||||
|
||||
assert isinstance(layer.w13_weight, torch.nn.Parameter)
|
||||
assert isinstance(layer.w2_weight, torch.nn.Parameter)
|
||||
assert not layer.w13_weight.requires_grad
|
||||
assert not layer.w2_weight.requires_grad
|
||||
|
||||
@pytest.mark.parametrize("others_param",
|
||||
[[256, 4], [128, 1], [128, 1], [128, 4]])
|
||||
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
"""
|
||||
1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
|
||||
2 test use_select_experts and fused_experts
|
||||
3 test use select_gating_topk_softmax_experts and fused_experts
|
||||
4 test use select_experts and fused_experts_with_all2all_buffer
|
||||
"""
|
||||
global_num_experts, ep_size = others_param
|
||||
is_prefill = False
|
||||
global_redundant_expert_num = vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config(
|
||||
).init_redundancy_expert
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
|
||||
ep_size, is_prefill, is_deepseek_v3_r1))
|
||||
with patch(
|
||||
"vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
router_logits = torch.randn(8, 8)
|
||||
layer = MagicMock()
|
||||
layer.w13_weight = torch.randn(8, 16, 1)
|
||||
layer.w2_weight = torch.randn(16, 8, 1)
|
||||
result = moe_method.apply(layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=2,
|
||||
renormalize=True,
|
||||
global_num_experts=global_num_experts,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if ep_size == 1:
|
||||
assert result.shape == (16, 2)
|
||||
else:
|
||||
assert result.shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize("others_param", [16, 1, 4])
|
||||
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
"""
|
||||
1 test use_select_experts and use fused_expters_with_mc2
|
||||
2 test use_select_experts and fused_experts_with_all2all_buffer
|
||||
3 test use_select_experts and fused_experts_with_all2all
|
||||
4 test use_select_experts and fused_experts
|
||||
"""
|
||||
ep_size = others_param
|
||||
is_prefill = False
|
||||
forward_context = MagicMock(
|
||||
fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
|
||||
with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
|
||||
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_device_type", return_value=AscendDeviceType._910_93):
|
||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
if ep_size == 1:
|
||||
x = x.view(-1, 2)
|
||||
router_logits = torch.randn(8, 8)
|
||||
layer = MagicMock()
|
||||
layer.w13_weight = torch.randn(8, 16, 1)
|
||||
layer.w2_weight = torch.randn(16, 8, 1)
|
||||
result = moe_method.apply(layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=2,
|
||||
renormalize=True,
|
||||
global_num_experts=128,
|
||||
expert_map=expert_map,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if ep_size == 16 or ep_size == 1:
|
||||
assert result.shape == (16, 2)
|
||||
else:
|
||||
assert result.shape == x.shape
|
||||
@@ -1,333 +0,0 @@
|
||||
import math
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
|
||||
_set_cos_sin_cache, custom_rotary_embedding_enabled,
|
||||
native_rope_deepseek_forward, rope_forward_oot, rotate_half,
|
||||
yarn_find_correction_dim, yarn_get_mscale)
|
||||
from vllm_ascend.utils import AscendDeviceType
|
||||
|
||||
|
||||
class TestCustomRotaryEmbeddingEnabled(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.cos_sin_cache = torch.randn(3, 4)
|
||||
|
||||
# Mock self object for rope_forward_oot
|
||||
self.mock_self = MagicMock()
|
||||
self.mock_self.head_size = self.head_size
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = True
|
||||
self.mock_self.forward_native.return_value = (self.query, self.key)
|
||||
|
||||
def test_custom_rotary_embedding_enabled(self):
|
||||
# Test when all conditions are True
|
||||
with patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Test when dtype is not float16
|
||||
with patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
query = self.query.to(torch.float32)
|
||||
result = custom_rotary_embedding_enabled(query, True,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when neox_style is False
|
||||
with patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, False,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when head_size is not divisible by 32
|
||||
with patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size + 1)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when custom op is disabled
|
||||
with patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
|
||||
return_value=False):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestRopeForwardOot(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.cos_sin_cache = torch.randn(3, 4)
|
||||
|
||||
# Mock self object for rope_forward_oot
|
||||
self.mock_self = MagicMock()
|
||||
self.mock_self.head_size = self.head_size
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = True
|
||||
self.mock_self.forward_native.return_value = (self.query, self.key)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
|
||||
def test_rope_forward_oot_torchair_enabled_base(self,
|
||||
mock_get_ascend_config):
|
||||
# Setup mock for torchair enabled
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = True
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
self.query, self.key)
|
||||
|
||||
self.mock_self.forward_native.assert_called_once_with(
|
||||
self.positions, self.query, self.key, None)
|
||||
self.assertTrue(torch.equal(result_q, self.query))
|
||||
self.assertTrue(torch.equal(result_k, self.key))
|
||||
|
||||
@patch('torch.ops._C_ascend')
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.utils.get_ascend_device_type',
|
||||
return_value=AscendDeviceType._910_93)
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=True)
|
||||
@patch('torch.ops._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
|
||||
mock_custom_enabled,
|
||||
mock_soc_version,
|
||||
mock_get_ascend_config, mock__c):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Setup mock for custom kernel path
|
||||
|
||||
mock__c.rotary_embedding.return_value = self.query, self.key
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
self.query, self.key)
|
||||
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
self.assertEqual(result_k.shape, self.key.shape)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
|
||||
mock_custom_enabled,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test contiguous path when custom is disabled
|
||||
non_contig_query = self.query.transpose(0, 1)
|
||||
non_contig_key = self.key.transpose(0, 1)
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
non_contig_query, non_contig_key)
|
||||
|
||||
mock_npu_rotary.assert_called_once()
|
||||
self.assertEqual(result_q.shape, non_contig_query.shape)
|
||||
self.assertEqual(result_k.shape, non_contig_key.shape)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
|
||||
def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test that NotImplementedError is raised when offsets is provided
|
||||
offsets = torch.tensor([1, 2, 3])
|
||||
with self.assertRaises(NotImplementedError):
|
||||
rope_forward_oot(self.mock_self, self.positions, self.query,
|
||||
self.key, offsets)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
|
||||
mock_custom_enabled,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test neox_style override
|
||||
result_q, result_k = rope_forward_oot(self.mock_self,
|
||||
self.positions,
|
||||
self.query,
|
||||
self.key,
|
||||
is_neox_style_override=False)
|
||||
|
||||
# Check that neox_style=False was passed to the NPU function
|
||||
args, kwargs = mock_npu_rotary.call_args
|
||||
self.assertFalse(args[-1])
|
||||
|
||||
|
||||
class MockRopeModule:
|
||||
|
||||
def __init__(self, max_seq_len=2048, is_neox_style=True):
|
||||
self.max_seq_len = max_seq_len
|
||||
self.is_neox_style = is_neox_style
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
self.rotary_dim = 1
|
||||
self.base = 1
|
||||
self.beta_fast = 32
|
||||
self.beta_slow = 1
|
||||
self.max_position_embeddings = 4096
|
||||
self.mscale = 1.0
|
||||
self.scaling_factor = 40
|
||||
|
||||
def register_buffer(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestSetSinCosCache(TestBase):
|
||||
|
||||
def test_set_cos_sin_cache(self):
|
||||
module = MockRopeModule()
|
||||
|
||||
with patch.object(module, "register_buffer") as mock_register_buffer:
|
||||
_set_cos_sin_cache(module,
|
||||
1024,
|
||||
device="cpu",
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
mock_register_buffer.assert_called()
|
||||
|
||||
|
||||
class TestNativeRopeDeepseekForward(TestBase):
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
|
||||
module = MockRopeModule()
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_key_reshaping(
|
||||
self, mock_rope_forward_oot):
|
||||
module = MockRopeModule()
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == (1, 128)
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_non_neox_style(
|
||||
self, mock_rope_forward_oot):
|
||||
module = MockRopeModule(is_neox_style=False)
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape


class TestRotateHalf(TestBase):

    def test_rotate_half_even_dim(self):
        # Test with even dimension
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
        result = rotate_half(x)
        self.assertTrue(torch.allclose(result, expected))


class TestYarnFindCorrectionDim(TestBase):

    def test_basic_case(self):
        # Test with standard values
        num_rotations = 100
        dim = 512
        base = 10000
        max_position_embeddings = 2048

        result = yarn_find_correction_dim(num_rotations, dim, base,
                                          max_position_embeddings)

        # Calculate expected value manually
        expected = (dim * torch.log(
            torch.tensor(max_position_embeddings) /
            (num_rotations * 2 * torch.pi))) / (2 *
                                                torch.log(torch.tensor(base)))

        self.assertTrue(torch.allclose(result, expected))


class TestYarnGetMscale(TestBase):

    def test_scale_less_than_or_equal_1(self):
        self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
        self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
        self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)

    def test_scale_greater_than_1(self):
        test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
                      (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
                      (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
                      (math.e, 1.0, 1.0 + 0.1)]

        for scale, mscale, expected in test_cases:
            result = yarn_get_mscale(scale, mscale)
            self.assertAlmostEqual(
                result,
                expected,
                places=6,
                msg=f"Failed for scale={scale}, mscale={mscale}")
@@ -1,296 +0,0 @@
from unittest.mock import Mock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
    TorchairAscendW4A8DynamicFusedMoEMethod,
    TorchairAscendW4A8DynamicLinearMethod)


class TestAscendW4A8DynamicLinearMethod(TestBase):
|
||||
|
||||
@patch('vllm.distributed.get_tensor_model_parallel_world_size')
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
|
||||
)
|
||||
def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
|
||||
mock_get_tp_world_size.return_value = 1
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(
|
||||
quant_description={"group_size": 256})
|
||||
mock_get_current_vllm_config.return_value = mock_vllm_config
|
||||
self.method = TorchairAscendW4A8DynamicLinearMethod()
|
||||
self.method.group_size = 8
|
||||
|
||||
def test_get_weight(self):
|
||||
weight = self.method.get_weight(8, 32, torch.bfloat16)
|
||||
self.assertEqual(weight["weight"].dtype, torch.int8)
|
||||
self.assertEqual(weight["weight"].shape, (32, 8))
|
||||
# new quant version weight
|
||||
self.method.new_quant_version = True
|
||||
weight = self.method.get_weight(8, 32, torch.bfloat16)
|
||||
self.assertEqual(weight["weight"].dtype, torch.int8)
|
||||
self.assertEqual(weight["weight"].shape, (16, 8))
|
||||
self.assertEqual(weight["_packed_dim"], 0)
|
||||
self.assertEqual(weight["_packed_factor"], 2)
|
||||
|
||||
def test_get_pergroup_param(self):
|
||||
params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
|
||||
self.assertEqual(params["weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_scale"].shape, (32, 1))
|
||||
self.assertEqual(params["weight_offset"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_offset"].shape, (32, 1))
|
||||
self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
|
||||
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
|
||||
# new quant version weight
|
||||
self.method.new_quant_version = True
|
||||
params = self.method.get_pergroup_param(8,
|
||||
32,
|
||||
torch.bfloat16,
|
||||
layer_type="column")
|
||||
self.assertEqual(params["scale_bias"].dtype, torch.float32)
|
||||
self.assertEqual(params["scale_bias"].shape, (32, 1))
|
||||
params = self.method.get_pergroup_param(8,
|
||||
32,
|
||||
torch.bfloat16,
|
||||
layer_type="row")
|
||||
self.assertEqual(params["scale_bias"].dtype, torch.float32)
|
||||
self.assertEqual(params["scale_bias"].shape, (32, 16))
|
||||
|
||||
@patch('torch_npu.npu_convert_weight_to_int4pack')
|
||||
@patch('torch.Tensor.npu')
|
||||
def test_process_weights_after_loading(self, mock_npu,
|
||||
mock_npu_convert_weight):
|
||||
mock_npu.side_effect = lambda: torch.zeros(
|
||||
(1, 32), dtype=torch.float32)
|
||||
mock_npu_convert_weight.return_value = torch.zeros((32, 4),
|
||||
dtype=torch.int32)
|
||||
# old quant version weight
|
||||
layer = torch.nn.Module()
|
||||
layer.weight = torch.nn.Parameter(torch.zeros((32, 8),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(32, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.weight_offset = torch.nn.Parameter(torch.empty_like(
|
||||
layer.weight_scale.data),
|
||||
requires_grad=False)
|
||||
layer.weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(32, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.weight_offset_second = torch.nn.Parameter(torch.empty_like(
|
||||
layer.weight_scale_second.data),
|
||||
requires_grad=False)
|
||||
self.method.process_weights_after_loading(layer)
|
||||
self.assertTrue(hasattr(layer, "weight_scale_bias"))
|
||||
self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
|
||||
self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32)
|
||||
# new quant version weight
|
||||
self.method.new_quant_version = True
|
||||
new_layer = torch.nn.Module()
|
||||
new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
new_layer.weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(32, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
new_layer.weight_offset = torch.nn.Parameter(torch.empty_like(
|
||||
new_layer.weight_scale.data),
|
||||
requires_grad=False)
|
||||
new_layer.weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(32, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
new_layer.weight_offset_second = torch.nn.Parameter(
|
||||
torch.empty_like(new_layer.weight_scale_second.data),
|
||||
requires_grad=False)
|
||||
new_layer.scale_bias = torch.nn.Parameter(torch.zeros(
|
||||
(32, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.method.process_weights_after_loading(new_layer)
|
||||
self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
|
||||
self.assertTrue(hasattr(new_layer, "weight_scale_second"))
|
||||
self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32))
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
experts = 8
|
||||
input_size = 16
|
||||
output_size = 56
|
||||
group_size = 2
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group')
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ascend_config'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group'
|
||||
)
|
||||
@patch('torch.distributed.get_rank', return_value=0)
|
||||
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
|
||||
mock_get_ep_group, get_current_vllm_config):
|
||||
mock_ascend_config = Mock()
|
||||
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
|
||||
mock_get_ascend_config.return_value = mock_ascend_config
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(quant_description={
|
||||
"group_size": self.group_size,
|
||||
"version": "0.0.0"
|
||||
})
|
||||
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
|
||||
get_current_vllm_config.return_value = mock_vllm_config
|
||||
self.quant_method = TorchairAscendW4A8DynamicFusedMoEMethod()
|
||||
|
||||
def test_get_weight(self):
|
||||
# old quant version w4a8 weight
|
||||
param_dict = self.quant_method.get_weight(self.experts,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
|
||||
self.assertEqual(param_dict["w13_weight"].shape,
|
||||
(self.experts, 2 * self.input_size, self.output_size))
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
param_dict = self.quant_method.get_weight(self.experts,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
|
||||
self.assertEqual(param_dict["w13_weight"].shape,
|
||||
(self.experts, self.input_size, self.output_size))
|
||||
|
||||
def test_get_dynamic_quant_param(self):
|
||||
# old quant version weight
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].shape,
|
||||
(self.experts, 2 * self.input_size, 1))
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
|
||||
torch.float32)
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size))
|
||||
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
|
||||
self.assertEqual(param_dict["w2_weight_scale"].shape,
|
||||
(self.experts, self.output_size, 1))
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
|
||||
torch.float32)
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size))
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
|
||||
self.assertEqual(
|
||||
param_dict["w2_scale_bias"].shape,
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
|
||||
# per-channel weight
|
||||
self.quant_method.is_per_channel_weight = True
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
pergroup_param = [
|
||||
"w13_weight_scale_second", "w13_weight_offset_second",
|
||||
"w2_weight_scale_second", "w2_weight_offset_second"
|
||||
]
|
||||
is_contains = any(key in param_dict for key in pergroup_param)
|
||||
self.assertFalse(is_contains)
|
||||
|
||||
def build_layer(self,
|
||||
is_new_quant_version=True,
|
||||
is_per_channel_weight=False):
|
||||
layer = torch.nn.Module()
|
||||
if is_new_quant_version:
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size // 2, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
w13_scale_bias = torch.zeros(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.float32)
|
||||
layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
|
||||
requires_grad=False)
|
||||
w2_scale_bias = torch.zeros((self.experts, self.output_size,
|
||||
16 // self.quant_method.tp_size),
|
||||
dtype=torch.float32)
|
||||
layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, 2 * self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
if not is_per_channel_weight:
|
||||
layer.w13_weight_scale_second = torch.nn.Parameter(
|
||||
torch.ones((self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_offset_second = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w13_weight_scale_second.data),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale_second = torch.nn.Parameter(
|
||||
torch.ones((self.experts, self.output_size,
|
||||
self.input_size // self.group_size),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_offset_second = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w2_weight_scale_second.data),
|
||||
requires_grad=False)
|
||||
return layer
|
||||
|
||||
@patch('torch_npu.npu_quantize')
|
||||
@patch('torch.Tensor.npu')
|
||||
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
|
||||
mock_npu.return_value = torch.Tensor()
|
||||
mock_npu_quantize.return_value = torch.Tensor()
|
||||
# old quant version weight
|
||||
layer = self.build_layer(is_new_quant_version=False)
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
self.assertTrue(hasattr(layer, "w13_scale_bias"))
|
||||
self.assertEqual(layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
|
||||
self.assertTrue(hasattr(layer, "w2_scale_bias"))
|
||||
self.assertEqual(layer.w2_scale_bias.data.shape,
|
||||
(self.experts, self.output_size))
|
||||
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
new_layer = self.build_layer(is_new_quant_version=True)
|
||||
self.quant_method.process_weights_after_loading(new_layer)
|
||||
self.assertEqual(new_layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
self.assertEqual(new_layer.w2_scale_bias.data.shape,
|
||||
(self.experts, self.output_size))
|
||||
self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
|
||||
# per-channel weight
|
||||
self.quant_method.is_per_channel_weight = True
|
||||
per_channel_layer = self.build_layer(is_new_quant_version=True,
|
||||
is_per_channel_weight=True)
|
||||
self.quant_method.process_weights_after_loading(per_channel_layer)
|
||||
self.assertEqual(new_layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
@@ -1,129 +0,0 @@
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
    torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
from vllm_ascend.utils import AscendDeviceType


class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.hidden_size = 128
|
||||
self.num_tokens = 128
|
||||
self.placeholder = torch.randn(self.num_tokens,
|
||||
self.hidden_size,
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
@patch("torch.distributed.all_to_all_single")
|
||||
@patch("torch_npu.npu_moe_re_routing")
|
||||
@patch("torch_npu.npu_grouped_matmul")
|
||||
@patch("torch_npu.npu_swiglu")
|
||||
@patch("torch_npu.npu_dynamic_quant")
|
||||
@patch("torch_npu.npu_moe_finalize_routing")
|
||||
@patch("torch_npu.npu_moe_init_routing_quant")
|
||||
def test_torchair_fused_experts_with_all2all(
|
||||
self, mock_npu_moe_init_routing_quant, mock_moe_finalize_routing,
|
||||
mock_dynamic_quant, mock_swiglu, mock_grouped_matmul,
|
||||
mock_moe_re_routing, mock_all_to_all_single):
|
||||
|
||||
expert_map = MagicMock()
|
||||
ep_group = MagicMock()
|
||||
placeholder_int8 = torch.randint(0,
|
||||
100,
|
||||
(self.num_tokens, self.hidden_size),
|
||||
dtype=torch.int8)
|
||||
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
|
||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||
input)
|
||||
mock_npu_moe_init_routing_quant.return_value = (
|
||||
placeholder_int8, placeholder_ones, placeholder_ones,
|
||||
torch.bincount(placeholder_ones, minlength=len(expert_map)),
|
||||
torch.randn(self.num_tokens))
|
||||
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
|
||||
torch.randint(0,
|
||||
100,
|
||||
(self.num_tokens, ),
|
||||
dtype=torch.int32),
|
||||
self.placeholder)
|
||||
mock_grouped_matmul.return_value = self.placeholder
|
||||
mock_swiglu.return_value = self.placeholder
|
||||
mock_dynamic_quant.return_value = (
|
||||
placeholder_int8,
|
||||
torch.randn(self.num_tokens),
|
||||
)
|
||||
mock_moe_finalize_routing.return_value = self.placeholder
|
||||
|
||||
result = torchair_fused_experts_with_all2all(
|
||||
hidden_states=self.placeholder,
|
||||
w1=self.placeholder,
|
||||
w1_scale=self.placeholder,
|
||||
w2=self.placeholder,
|
||||
w2_scale=self.placeholder,
|
||||
topk_weights=self.placeholder,
|
||||
topk_ids=self.placeholder,
|
||||
top_k=8,
|
||||
expert_map=expert_map,
|
||||
ep_group=ep_group,
|
||||
log2phy=None,
|
||||
global_redundant_expert_num=256,
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
self.assertEqual(result.shape, (128, 128))
|
||||
|
||||
@patch.dict('os.environ', {
|
||||
'HCCL_INTRA_ROCE_ENABLE': '0',
|
||||
'HCCL_INTRA_PCIE_ENABLE': '1'
|
||||
})
|
||||
@patch(
|
||||
"vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_ascend_device_type"
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_mc2_group'
|
||||
)
|
||||
@patch('torch_npu.npu_moe_distribute_combine_v2')
|
||||
@patch('torch_npu.npu_moe_distribute_dispatch_v2')
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.torchair_apply_mlp_decode'
|
||||
)
|
||||
def test_torchair_fused_experts_with_mc2_a2_optimization(
|
||||
self, mock_mlp_decode, mock_dispatch, mock_combine, mock_get_group,
|
||||
mock_ascend_soc_version):
|
||||
"""Test expert_scales is passed in A2 SOC version with mc2 optimization"""
|
||||
# Setup mocks
|
||||
mock_ascend_soc_version.return_value = AscendDeviceType._910B
|
||||
|
||||
mock_group = MagicMock()
|
||||
mock_group.rank_in_group = 0
|
||||
mock_group.world_size = 4
|
||||
mock_get_group.return_value = mock_group
|
||||
|
||||
mock_combine.return_value = self.placeholder
|
||||
|
||||
mock_dispatch.return_value = (torch.randn(32, 1024), torch.randn(1),
|
||||
torch.randint(0, 32, (32, )),
|
||||
torch.randint(1, 5, (8, )),
|
||||
torch.randint(1, 5, (4, )), None,
|
||||
torch.randn(32))
|
||||
mock_mlp_decode.return_value = self.placeholder
|
||||
|
||||
result = torchair_fused_experts_with_mc2(
|
||||
hidden_states=self.placeholder,
|
||||
w1=self.placeholder,
|
||||
w2=self.placeholder,
|
||||
w1_scale=self.placeholder,
|
||||
w2_scale=self.placeholder,
|
||||
topk_weights=self.placeholder,
|
||||
topk_ids=self.placeholder,
|
||||
top_k=2,
|
||||
mc2_mask=self.placeholder)
|
||||
|
||||
# Check that expert_scales was passed to dispatch
|
||||
call_args = mock_dispatch.call_args[1]
|
||||
self.assertIn('expert_scales', call_args)
|
||||
|
||||
self.assertIsInstance(result, torch.Tensor)
|
||||
self.assertEqual(result.shape, self.placeholder.shape)
|
||||
@@ -1,95 +0,0 @@
from unittest.mock import MagicMock, patch

import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import GroupCoordinator

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.torchair.torchair_attention import \
    AscendAttentionTorchairBackendImpl


class TestAscendAttentionTorchairBackendImpl(TestBase):
|
||||
|
||||
@patch("torch.zeros")
|
||||
@patch('vllm.distributed.parallel_state._TP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator)) # TODO
|
||||
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
|
||||
return_value=2) # TODO
|
||||
@patch("vllm.config.get_current_vllm_config") # TODO
|
||||
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") # TODO
|
||||
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp,
|
||||
mock_zeros):
|
||||
mock_tp.world_size = 2 # TODO
|
||||
ascend_config.torchair_graph_config.enabled = True # TODO
|
||||
ascend_config.torchair_graph_config.enable_kv_nz = False # TODO
|
||||
speculative_config = MagicMock()
|
||||
speculative_config.num_speculative_tokens = 4
|
||||
vllm_config.speculative_config = speculative_config
|
||||
|
||||
num_heads = 32
|
||||
head_size = 128 # TODO
|
||||
scale = 0.1 # TODO
|
||||
num_kv_heads = 4
|
||||
kv_cache_dtype = "auto"
|
||||
attn_type = AttentionType.DECODER
|
||||
mock_zeros.return_value = torch.ones((),
|
||||
device='cpu',
|
||||
dtype=torch.int32)
|
||||
|
||||
self.impl = AscendAttentionTorchairBackendImpl(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
blocksparse_params=None,
|
||||
logits_soft_cap=None,
|
||||
attn_type=attn_type,
|
||||
kv_sharing_target_layer_name=None)
|
||||
|
||||
@patch("torch_npu.npu_scatter_nd_update_")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_forward_with_decode_only(self, mock_fused, _):
|
||||
layer = MagicMock()
|
||||
layer._k_scale_float = 1.0
|
||||
layer._v_scale_float = 1.0
|
||||
|
||||
seq_len = 1
|
||||
num_tokens = 100
|
||||
num_blocks = 256
|
||||
block_size = 4
|
||||
|
||||
query = torch.randn(num_tokens, seq_len,
|
||||
self.impl.num_heads * self.impl.head_size)
|
||||
key = torch.randn(num_tokens, seq_len,
|
||||
self.impl.num_kv_heads * self.impl.head_size)
|
||||
value = torch.randn(num_tokens, seq_len,
|
||||
self.impl.num_kv_heads * self.impl.head_size)
|
||||
kv_cache = (torch.randn(num_blocks, block_size,
|
||||
self.impl.num_heads * self.impl.head_size),
|
||||
torch.randn(num_blocks, block_size,
|
||||
self.impl.num_heads * self.impl.head_size))
|
||||
output = torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.head_size)
|
||||
|
||||
decode = MagicMock() # TODO
|
||||
decode.seq_lens_list = [2] * num_tokens
|
||||
decode.block_table = torch.ones(num_tokens, 8, dtype=torch.int32)
|
||||
decode.attn_mask = None
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.slot_mapping = torch.arange(num_tokens, dtype=torch.int32)
|
||||
metadata.decode = decode
|
||||
|
||||
mock_fused.return_value = (torch.ones(num_tokens, self.impl.num_heads,
|
||||
self.impl.head_size),
|
||||
torch.ones(1))
|
||||
|
||||
result = self.impl.forward(layer, query, key, value, kv_cache,
|
||||
metadata, output)
|
||||
self.assertEqual(result.shape[0], num_tokens)
|
||||
@@ -1,887 +0,0 @@
from unittest.mock import MagicMock, patch

import pytest
import torch
from torch import nn
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.torchair.torchair_mla import (
    AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata,
    AscendMLATorchairImpl, AscendMLATorchairMetadata,
    AscendMLATorchairMetadataBuilder, AscendMLATorchairPrefillMetadata)
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata


class TestAscendMLATorchairBackend(TestBase):
|
||||
|
||||
def test_get_name(self):
|
||||
self.assertEqual(AscendMLATorchairBackend.get_name(),
|
||||
"ASCEND_MLA_TORCHAIR")
|
||||
|
||||
def test_get_builder_cls(self):
|
||||
self.assertEqual(AscendMLATorchairBackend.get_builder_cls(),
|
||||
AscendMLATorchairMetadataBuilder)
|
||||
|
||||
def test_get_kv_cache_shape(self):
|
||||
result = AscendMLATorchairBackend.get_kv_cache_shape(2, 4, 8, 128)
|
||||
self.assertEqual(result, (2, 4, 8, 128))
|
||||
|
||||
def test_get_impl_cls(self):
|
||||
result = AscendMLATorchairBackend.get_impl_cls()
|
||||
self.assertEqual(result, AscendMLATorchairImpl)
|
||||
|
||||
|
||||
class TestAscendMLATorchairPrefillMetadata(TestBase):
|
||||
|
||||
def test_ascend_mla_prefill_metadata_default(self):
|
||||
attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
|
||||
query_lens = [1, 2]
|
||||
seq_lens = [2, 2]
|
||||
context_lens = torch.tensor([1, 2])
|
||||
input_positions = torch.tensor([0, 1, 0, 1])
|
||||
query_start_loc = torch.tensor([0, 1, 3])
|
||||
block_table = torch.tensor([[0, 1], [2, 3]])
|
||||
max_query_len = 2
|
||||
max_seq_lens = 2
|
||||
|
||||
metadata = AscendMLATorchairPrefillMetadata(
|
||||
attn_mask=attn_mask,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
context_lens=context_lens,
|
||||
input_positions=input_positions,
|
||||
query_start_loc=query_start_loc,
|
||||
block_table=block_table,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_lens=max_seq_lens)
|
||||
self.assertIs(metadata.attn_mask, attn_mask)
|
||||
self.assertEqual(metadata.query_lens, query_lens)
|
||||
self.assertEqual(metadata.seq_lens, seq_lens)
|
||||
self.assertIs(metadata.context_lens, context_lens)
|
||||
self.assertIs(metadata.input_positions, input_positions)
|
||||
self.assertIs(metadata.query_start_loc, query_start_loc)
|
||||
self.assertIs(metadata.block_table, block_table)
|
||||
self.assertEqual(metadata.max_query_len, max_query_len)
|
||||
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
|
||||
self.assertIsNone(metadata.chunked_context)
|
||||
|
||||
def test_ascend_mla_prefill_metadata_with_chunked_context(self):
|
||||
cu_seq_lens = torch.tensor([0, 2, 4])
|
||||
starts = torch.tensor([0, 2])
|
||||
seq_tot = [2, 2]
|
||||
max_seq_lens = [2, 2]
|
||||
workspace = torch.randn(2, 4)
|
||||
chunk_seq_lens = torch.tensor([2, 2])
|
||||
|
||||
chunked_context = AscendMLATorchairPrefillMetadata.TorchairChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
starts=starts,
|
||||
seq_tot=seq_tot,
|
||||
max_seq_lens=max_seq_lens,
|
||||
workspace=workspace,
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens)
|
||||
|
||||
metadata = AscendMLATorchairPrefillMetadata(
|
||||
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
|
||||
query_lens=[1, 2],
|
||||
seq_lens=[2, 2],
|
||||
context_lens=torch.tensor([1, 2]),
|
||||
input_positions=torch.tensor([0, 1, 0, 1]),
|
||||
query_start_loc=torch.tensor([0, 1, 3]),
|
||||
block_table=torch.tensor([[0, 1], [2, 3]]),
|
||||
max_query_len=2,
|
||||
max_seq_lens=2,
|
||||
chunked_context=chunked_context)
|
||||
|
||||
self.assertIsNotNone(metadata.chunked_context)
|
||||
self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
|
||||
self.assertIs(metadata.chunked_context.starts, starts)
|
||||
self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
|
||||
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
|
||||
self.assertIs(metadata.chunked_context.workspace, workspace)
|
||||
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
|
||||
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
|
||||
chunk_seq_lens)
|
||||
|
||||
|
||||
class TestAscendMLATorchairDecodeMetadata(TestBase):
|
||||
|
||||
def test_ascend_mla_decode_metadata_default(self):
|
||||
input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
|
||||
block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
|
||||
seq_lens = torch.tensor([[2], [3]])
|
||||
max_seq_lens = 4
|
||||
seq_lens_list = [2, 3]
|
||||
attn_mask = None
|
||||
|
||||
metadata = AscendMLATorchairDecodeMetadata(input_positions,
|
||||
block_table, seq_lens,
|
||||
max_seq_lens, seq_lens_list,
|
||||
attn_mask)
|
||||
|
||||
self.assertIs(metadata.input_positions, input_positions)
|
||||
self.assertIs(metadata.block_table, block_table)
|
||||
self.assertIs(metadata.seq_lens, seq_lens)
|
||||
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
|
||||
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
|
||||
self.assertIsNone(attn_mask)
|
||||
|
||||
|
||||
class TestAscendMLATorchairMetadata(TestBase):
|
||||
|
||||
def test_ascend_mla_metadata_default(self):
|
||||
num_actual_tokens = 100
|
||||
slot_mapping = torch.randn(100, 4, 1024)
|
||||
query_start_loc = torch.tensor([1, 2, 3, 4])
|
||||
seq_lens = [30, 50]
|
||||
block_tables = torch.randint(0, 100, (100, 4))
|
||||
|
||||
num_decodes = 4
|
||||
num_decode_tokens = 8
|
||||
num_prefills = 8
|
||||
|
||||
num_input_tokens = 2
|
||||
|
||||
query_lens = None
|
||||
head_dim = None
|
||||
attn_mask = None
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
decode = None
|
||||
prefill = None
|
||||
|
||||
metadata = AscendMLATorchairMetadata(
|
||||
num_actual_tokens, slot_mapping, query_start_loc, seq_lens,
|
||||
block_tables, num_decodes, num_decode_tokens, num_prefills,
|
||||
num_input_tokens, query_lens, head_dim, attn_mask, attn_state,
|
||||
decode, prefill)
|
||||
|
||||
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
|
||||
self.assertIs(metadata.slot_mapping, slot_mapping)
|
||||
self.assertIs(metadata.query_start_loc, query_start_loc)
|
||||
self.assertEqual(metadata.seq_lens, seq_lens)
|
||||
self.assertIs(metadata.block_tables, block_tables)
|
||||
self.assertEqual(metadata.num_decodes, num_decodes)
|
||||
self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
|
||||
self.assertEqual(metadata.num_prefills, num_prefills)
|
||||
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
|
||||
self.assertEqual(metadata.query_lens, query_lens)
|
||||
self.assertEqual(metadata.head_dim, head_dim)
|
||||
self.assertEqual(metadata.attn_mask, attn_mask)
|
||||
self.assertEqual(metadata.attn_state, attn_state)
|
||||
self.assertEqual(metadata.decode, decode)
|
||||
self.assertEqual(metadata.prefill, prefill)
|
||||
|
||||
|
||||
class TestAscendMLATorchairMetadataBuilder(TestBase):
|
||||
|
||||
def test_ascend_mla_metadata_builder_default(self):
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.max_model_len = 1024
|
||||
mock_model_config.get_head_size.return_value = 64
|
||||
mock_model_config.dtype = torch.float16
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.cache_config = MagicMock(block_size=16)
|
||||
mock_vllm_config.scheduler_config = MagicMock(
|
||||
max_num_seqs=4, enable_chunked_prefill=False)
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_device = torch.device('cpu')
|
||||
ascend_config = MagicMock()
|
||||
ascend_config.torchair_graph_config = MagicMock()
|
||||
ascend_config.torchair_graph_config.enabled = True
|
||||
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
|
||||
return_value=ascend_config):
|
||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||
mock_vllm_config,
|
||||
mock_device)
|
||||
|
||||
self.assertEqual(builder.block_size,
|
||||
mock_vllm_config.cache_config.block_size)
|
||||
self.assertEqual(
|
||||
builder.chunked_prefill_enabled,
|
||||
mock_vllm_config.scheduler_config.enable_chunked_prefill)
|
||||
self.assertEqual(builder.torchair_graph_enabled, True)
|
||||
|
||||
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
|
||||
def test_reorder_batch_with_torchair_graph(self, ascend_config):
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.max_model_len = 1024
|
||||
mock_model_config.get_head_size.return_value = 64
|
||||
mock_model_config.dtype = torch.float16
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.cache_config = MagicMock(block_size=16)
|
||||
mock_vllm_config.scheduler_config = MagicMock(
|
||||
max_num_seqs=4, enable_chunked_prefill=False)
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_device = torch.device('cpu')
|
||||
|
||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||
mock_vllm_config,
|
||||
mock_device)
|
||||
|
||||
input_batch = MagicMock()
|
||||
input_batch.req_ids = [0, 1, 2, 3]
|
||||
|
||||
scheduler_output = MagicMock()
|
||||
scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
|
||||
scheduler_output.scheduled_spec_decode_tokens = {
|
||||
0: [1],
|
||||
1: [],
|
||||
2: [1, 1],
|
||||
3: []
|
||||
}
|
||||
|
||||
input_batch.swap_states = MagicMock()
|
||||
|
||||
modified = builder.reorder_batch(input_batch, scheduler_output)
|
||||
|
||||
self.assertFalse(modified)
|
||||
input_batch.swap_states.assert_not_called()
|
||||
|
||||
def test_reorder_batch_without_torchair_graph(self):
|
||||
ascend_config = MagicMock()
|
||||
ascend_config.torchair_graph_config = MagicMock()
|
||||
ascend_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.max_model_len = 1024
|
||||
mock_model_config.get_head_size.return_value = 64
|
||||
mock_model_config.dtype = torch.float16
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.cache_config = MagicMock(block_size=16)
|
||||
mock_vllm_config.scheduler_config = MagicMock(
|
||||
max_num_seqs=4, enable_chunked_prefill=False)
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_device = torch.device('cpu')
|
||||
|
||||
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
|
||||
return_value=ascend_config):
|
||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||
mock_vllm_config,
|
||||
mock_device)
|
||||
|
||||
input_batch = MagicMock()
|
||||
input_batch.req_ids = [0, 1, 2, 3]
|
||||
|
||||
scheduler_output = MagicMock()
|
||||
scheduler_output.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2}
|
||||
scheduler_output.scheduled_spec_decode_tokens = {
|
||||
0: [],
|
||||
1: [1],
|
||||
2: [],
|
||||
3: []
|
||||
}
|
||||
|
||||
input_batch.swap_states = MagicMock()
|
||||
|
||||
modified = builder.reorder_batch(input_batch, scheduler_output)
|
||||
|
||||
self.assertTrue(modified)
|
||||
input_batch.swap_states.assert_called_once_with(1, 2)
|
||||
|
||||
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
|
||||
def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
|
||||
ascend_config = MagicMock()
|
||||
mock_ascend_config.return_value = ascend_config
|
||||
ascend_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.max_model_len = 1024
|
||||
mock_model_config.get_head_size.return_value = 64
|
||||
mock_model_config.dtype = torch.float16
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.cache_config = MagicMock(block_size=16)
|
||||
mock_vllm_config.scheduler_config = MagicMock(
|
||||
max_num_seqs=4, enable_chunked_prefill=False)
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_device = torch.device('cpu')
|
||||
|
||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||
mock_vllm_config,
|
||||
mock_device)
|
||||
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
|
||||
|
||||
result = builder._get_graph_runner_block_tables(3, block_tables)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
self.assertEqual(result.shape[1], 64)
|
||||
self.assertTrue(torch.equal(result[:, :10], block_tables))
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test temporarily.")
|
||||
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
|
||||
def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
|
||||
ascend_config = MagicMock()
|
||||
mock_ascend_config.return_value = ascend_config
|
||||
ascend_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.get_head_size.return_value = 64
|
||||
mock_model_config.dtype = torch.float16
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.cache_config = MagicMock(block_size=16)
|
||||
mock_vllm_config.scheduler_config = MagicMock(
|
||||
enable_chunked_prefill=False)
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_device = torch.device('cpu')
|
||||
|
||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||
mock_vllm_config,
|
||||
mock_device)
|
||||
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
|
||||
|
||||
result = builder._get_graph_runner_block_tables(3, block_tables)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
self.assertEqual(result.shape[1], 4)
|
||||
self.assertTrue(torch.equal(result, block_tables[:, :4]))
|
||||
|
||||
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
|
||||
def test_get_graph_runner_block_tables_from_numpy(self,
|
||||
mock_ascend_config):
|
||||
ascend_config = MagicMock()
|
||||
mock_ascend_config.return_value = ascend_config
|
||||
ascend_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.max_model_len = 1024
|
||||
mock_model_config.get_head_size.return_value = 64
|
||||
mock_model_config.dtype = torch.float16
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.cache_config = MagicMock(block_size=16)
|
||||
mock_vllm_config.scheduler_config = MagicMock(
|
||||
max_num_seqs=4, enable_chunked_prefill=False)
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_device = torch.device('cpu')
|
||||
|
||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||
mock_vllm_config,
|
||||
mock_device)
|
||||
|
||||
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
|
||||
|
||||
result = builder._get_graph_runner_block_tables(3, block_tables)
|
||||
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
self.assertEqual(result.shape[1], 64)
|
||||
self.assertTrue(torch.equal(result[:, :10], block_tables))
|
||||
|
||||
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
|
||||
def test_build_dummy(self, mock_ascend_config):
|
||||
ascend_config = MagicMock()
|
||||
mock_ascend_config.return_value = ascend_config
|
||||
ascend_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.max_model_len = 1024
|
||||
mock_model_config.get_head_size.return_value = 64
|
||||
mock_model_config.dtype = torch.float16
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.cache_config = MagicMock(block_size=16)
|
||||
mock_vllm_config.scheduler_config = MagicMock(
|
||||
max_num_seqs=4, enable_chunked_prefill=False)
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_device = torch.device('cpu')
|
||||
|
||||
builder = AscendMLATorchairMetadataBuilder(
|
||||
None,
|
||||
None,
|
||||
mock_vllm_config,
|
||||
mock_device,
|
||||
metadata_cls=AscendMLATorchairMetadata)
|
||||
builder.rope_dim = 64
|
||||
|
||||
with patch.object(builder,
|
||||
"_get_graph_runner_block_tables",
|
||||
side_effect=lambda x, y: y):
|
||||
common_attn_metadata = TorchairCommonAttentionMetadata(
|
||||
num_reqs=3,
|
||||
num_actual_tokens=3,
|
||||
decode_token_per_req=1,
|
||||
actual_seq_lengths_q=[0, 1, 2],
|
||||
attn_mask=torch.zeros((1, 1), dtype=torch.bool),
|
||||
spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
|
||||
)
|
||||
metadata = builder.build_torchair_graph_dummy(common_attn_metadata)
|
||||
|
||||
sin_golden = torch.ones(3,
|
||||
1,
|
||||
1,
|
||||
64,
|
||||
dtype=torch.float16,
|
||||
device=mock_device)
|
||||
cos_golden = torch.ones(3,
|
||||
1,
|
||||
1,
|
||||
64,
|
||||
dtype=torch.float16,
|
||||
device=mock_device)
|
||||
|
||||
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
|
||||
self.assertEqual(metadata.num_input_tokens, 3)
|
||||
self.assertEqual(metadata.num_actual_tokens, 3)
|
||||
self.assertEqual(metadata.num_decodes, 1)
|
||||
self.assertEqual(metadata.num_decode_tokens, 1)
|
||||
self.assertEqual(metadata.num_prefills, 0)
|
||||
self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly)
|
||||
self.assertIsNone(metadata.prefill)
|
||||
self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
|
||||
self.assertEqual(metadata.block_tables.shape[0], 3)
|
||||
self.assertEqual(metadata.block_tables.shape[1], 64)
|
||||
self.assertEqual(metadata.seq_lens.shape[0], 3)
|
||||
self.assertEqual(metadata.slot_mapping.shape[0], 3)
|
||||
self.assertEqual(metadata.query_start_loc.shape[0], 3)
|
||||
assert torch.equal(sin_golden, metadata.decode.sin)
|
||||
assert torch.equal(cos_golden, metadata.decode.cos)
|
||||
|
||||
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
|
||||
def test_build_decode(self, mock_ascend_config):
|
||||
ascend_config = MagicMock()
|
||||
mock_ascend_config.return_value = ascend_config
|
||||
ascend_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.max_model_len = 1024
|
||||
mock_model_config.get_head_size.return_value = 64
|
||||
mock_model_config.dtype = torch.float16
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.cache_config = MagicMock(block_size=16)
|
||||
mock_vllm_config.scheduler_config = MagicMock(
|
||||
max_num_seqs=4, enable_chunked_prefill=False)
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_device = torch.device('cpu')
|
||||
|
||||
model = MagicMock(spec=nn.Module)
|
||||
model.model = MagicMock(spec=nn.Module)
|
||||
|
||||
builder = AscendMLATorchairMetadataBuilder(
|
||||
None,
|
||||
None,
|
||||
mock_vllm_config,
|
||||
mock_device,
|
||||
metadata_cls=AscendMLATorchairMetadata)
|
||||
builder.rope_dim = 64
|
||||
|
||||
builder.sin_cache = torch.tensor([10, 10])
|
||||
builder.cos_cache = torch.tensor([10, 10])
|
||||
|
||||
with patch.object(builder,
|
||||
"_get_graph_runner_block_tables",
|
||||
side_effect=lambda x, y: y):
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
||||
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
|
||||
seq_lens_cpu=torch.tensor([1, 1, 1]),
|
||||
num_reqs=3,
|
||||
num_actual_tokens=3,
|
||||
max_query_len=1,
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
positions=torch.tensor([1, 1]),
|
||||
attn_mask=torch.ones((15, 15)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
|
||||
metadata = builder.build(1, common_attn_metadata, model)
|
||||
|
||||
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
|
||||
self.assertEqual(metadata.num_input_tokens, 0)
|
||||
self.assertEqual(metadata.num_actual_tokens, 3)
|
||||
self.assertEqual(metadata.num_decodes, 3)
|
||||
self.assertEqual(metadata.num_decode_tokens, 3)
|
||||
self.assertEqual(metadata.num_prefills, 0)
|
||||
self.assertEqual(metadata.attn_state,
|
||||
AscendAttentionState.ChunkedPrefill)
|
||||
        self.assertIsNone(metadata.prefill)
        self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
        self.assertEqual(metadata.block_tables.shape[0], 3)
        self.assertEqual(metadata.block_tables.shape[1], 10)
        self.assertEqual(metadata.seq_lens.shape[0], 3)
        self.assertEqual(metadata.slot_mapping.shape[0], 3)
        self.assertEqual(metadata.query_start_loc.shape[0], 4)


class TestAscendMLATorchairImpl(TestBase):

    @patch('vllm.distributed.parallel_state._TP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_tensor_model_parallel_world_size",
           return_value=2)
    @patch("vllm.config.get_current_vllm_config")
    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
    def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
        mock_tp.world_size = 2
        ascend_config.torchair_graph_config.enabled = True
        ascend_config.torchair_graph_config.enable_kv_nz = False
        speculative_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config

        num_heads = 256
        head_size = 1024
        scale = 0.1
        num_kv_heads = 8
        kv_cache_dtype = "auto"

        kv_a_layernorm = MagicMock()
        kv_a_layernorm.weight = torch.randn(96)
        kv_a_layernorm.variance_epsilon = 1e-6
        kwargs = {
            "q_lora_rank": 64,
            "kv_lora_rank": 32,
            "qk_nope_head_dim": 64,
            "qk_rope_head_dim": 32,
            "qk_head_dim": 96,
            "v_head_dim": 128,
            "rotary_emb": MagicMock(),
            "q_proj": MagicMock(),
            "q_b_proj": MagicMock(),
            "kv_b_proj": MagicMock(),
            "o_proj": MagicMock(),
            "kv_a_proj_with_mqa": MagicMock(),
            "kv_a_layernorm": kv_a_layernorm,
        }

        self.impl = AscendMLATorchairImpl(num_heads=num_heads,
                                          head_size=head_size,
                                          scale=scale,
                                          num_kv_heads=num_kv_heads,
                                          alibi_slopes=None,
                                          sliding_window=None,
                                          kv_cache_dtype=kv_cache_dtype,
                                          blocksparse_params=None,
                                          logits_soft_cap=None,
                                          attn_type=None,
                                          kv_sharing_target_layer_name=None,
                                          **kwargs)

    def test_init(self):
        self.assertEqual(self.impl.num_heads, 256)
        self.assertEqual(self.impl.head_size, 1024)
        self.assertEqual(self.impl.scale, 0.1)
        self.assertEqual(self.impl.num_kv_heads, 8)
        self.assertEqual(self.impl.kv_cache_dtype, "auto")
        self.assertEqual(self.impl.q_lora_rank, 64)
        self.assertEqual(self.impl.kv_lora_rank, 32)
        self.assertEqual(self.impl.qk_nope_head_dim, 64)
        self.assertEqual(self.impl.qk_rope_head_dim, 32)
        self.assertEqual(self.impl.qk_head_dim, 96)
        self.assertEqual(self.impl.v_head_dim, 128)
        self.assertIsNotNone(self.impl.rotary_emb)
        self.assertIsNotNone(self.impl.q_proj)
        self.assertIsNotNone(self.impl.kv_b_proj)
        self.assertIsNotNone(self.impl.o_proj)
        self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
        self.assertIsNotNone(self.impl.kv_a_layernorm)
        self.assertEqual(self.impl.num_queries_per_kv, 32)
        self.assertEqual(self.impl.tp_size, 2)
        self.assertTrue(self.impl.torchair_graph_enabled)

    def test_v_up_proj_and_o_proj(self):
        batch_size = 4
        x = torch.randn(batch_size, self.impl.num_heads,
                        self.impl.kv_lora_rank)

        self.impl.o_proj.return_value = (torch.randn(
            batch_size, self.impl.num_heads * self.impl.v_head_dim), )
        if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
            self.impl.W_UV = torch.randn(self.impl.num_heads,
                                         self.impl.kv_lora_rank,
                                         self.impl.v_head_dim)
        result = self.impl._v_up_proj_and_o_proj(x)

        self.assertEqual(result.shape[0], batch_size)
        self.assertEqual(result.shape[1],
                         self.impl.num_heads * self.impl.v_head_dim)

    def test_q_proj_and_k_up_proj(self):
        batch_size = 4
        x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
        q_proj_output = torch.randn(batch_size, self.impl.num_heads,
                                    self.impl.qk_head_dim)
        self.impl.q_proj.return_value = (q_proj_output, )
        if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None:
            self.impl.W_UK_T = torch.randn(self.impl.num_heads,
                                           self.impl.qk_nope_head_dim,
                                           self.impl.kv_lora_rank)
        result = self.impl._q_proj_and_k_up_proj(x)
        ql_nope, q_pe = result
        self.assertEqual(ql_nope.shape[0], batch_size)
        self.assertEqual(ql_nope.shape[1], self.impl.num_heads)
        self.assertEqual(ql_nope.shape[2], self.impl.kv_lora_rank)
        self.assertEqual(q_pe.shape[0], batch_size)
        self.assertEqual(q_pe.shape[1], self.impl.num_heads)
        self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)

    def test_process_weights_after_loading(self):
        layer = MagicMock(spec=LinearBase)
        layer.input_size_per_partition = 10
        quant_method = MagicMock()
        apply = MagicMock()
        quant_method.apply = apply
        layer.quant_method = quant_method
        shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
                                         self.impl.v_head_dim)
        shape_1 = self.impl.kv_lora_rank
        layer.weight = torch.randn(shape_0, shape_1)
        self.impl.kv_b_proj = layer
        apply.return_value = layer.weight.T
        self.impl.process_weights_after_loading(torch.bfloat16)

        self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)
        self.assertEqual(self.impl.W_UK_T.shape[1], self.impl.qk_nope_head_dim)
        self.assertEqual(self.impl.W_UK_T.shape[2], self.impl.kv_lora_rank)

        self.assertEqual(self.impl.W_UV.shape[0], self.impl.num_heads)
        self.assertEqual(self.impl.W_UV.shape[1], self.impl.kv_lora_rank)
        self.assertEqual(self.impl.W_UV.shape[2], self.impl.v_head_dim)

    def test_compute_prefill_context_none(self):
        batch_size = 4
        kv_cache = torch.randn(10, 1, 1, 192)
        query = torch.randn(batch_size, self.impl.num_heads,
                            self.impl.qk_head_dim)
        metadata = MagicMock()
        metadata.prefill = None
        prefix_out = torch.randn(2, 16, 128)
        prefix_lse = torch.randn(2, 16, 8)
        out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
                                                      metadata, prefix_out,
                                                      prefix_lse)

        self.assertTrue(torch.equal(prefix_out, out))
        self.assertTrue(torch.equal(prefix_lse, lse))

    @patch("torch_npu.atb.npu_paged_cache_load")
    @patch("torch_npu.atb.npu_ring_mla")
    def test_compute_prefill_context(self, mock_ring, mock_load):
        S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
        _, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
        latent_kv_dim = self.impl.kv_lora_rank
        num_blocks, block_size = 100, 20
        query = torch.randn(S, N, D)
        kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
        kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
        kv_cache = [kv_cache_0, kv_cache_1]
        prefix_out = torch.randn(S, N, 128)
        prefix_lse = torch.randn(S, N)

        self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )

        chunk_ctx = MagicMock()
        chunk_ctx.seq_tot = [8]
        chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
        chunk_ctx.starts = [torch.tensor([0])]

        prefill_meta = MagicMock()
        prefill_meta.chunked_context = chunk_ctx
        prefill_meta.query_lens = [8]
        prefill_meta.block_table = torch.randint(0, 100, (S, 4))

        meta = MagicMock()
        meta.prefill = prefill_meta

        out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
                                                      meta, prefix_out,
                                                      prefix_lse)

        mock_load.assert_called_once()
        mock_ring.assert_called_once()

        self.assertEqual(out.shape, prefix_out.shape)
        self.assertEqual(lse.shape, prefix_lse.shape)

    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
    def test_exec_kv(self, mock_kv_cache):
        batch_size = 2
        hidden = torch.randn(batch_size, 128)
        cos = torch.randn(batch_size, 32)
        sin = torch.randn(batch_size, 32)
        kv_cache = (torch.randn(
            4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
                    torch.randn(
                        4, 8,
                        self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
        slots = torch.arange(batch_size, dtype=torch.long)

        proj_out = torch.randn(
            batch_size, self.impl.num_kv_heads, 1,
            self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
        self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )

        mock_kv_cache.return_value = (torch.randn(batch_size,
                                                  self.impl.num_kv_heads, 1,
                                                  self.impl.qk_rope_head_dim),
                                      torch.randn(batch_size,
                                                  self.impl.num_kv_heads, 1,
                                                  self.impl.kv_lora_rank),
                                      None, None)

        k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots)

        self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden)
        mock_kv_cache.assert_called_once()
        self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1,
                                      self.impl.qk_rope_head_dim))
        self.assertEqual(
            k_nope.shape,
            (batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank))
        self.assertEqual(kv.shape,
                         (batch_size, self.impl.num_kv_heads, 1,
                          self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))

    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
    def test_exec_kv_prefill(self, mock_kv):
        B, N, S, H = 2, self.impl.num_kv_heads, 1, 128
        hidden_states = torch.randn(B, N, S, H)
        cos = torch.randn(B, S, 32)
        sin = torch.randn(B, S, 32)
        kv_cache = (
            torch.randn(100, 8,
                        self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
            torch.randn(100, 8,
                        self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
        )

        slots = torch.arange(B * S, dtype=torch.long)

        proj_out = torch.randn(
            B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
        self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )

        mock_kv.return_value = (None, None,
                                torch.randn(B, self.impl.num_kv_heads, S,
                                            self.impl.qk_rope_head_dim),
                                torch.randn(B, self.impl.num_kv_heads, S,
                                            self.impl.kv_lora_rank))

        k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin,
                                                 kv_cache, slots)

        self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states)
        mock_kv.assert_called_once()

        self.assertEqual(
            k_pe.shape,
            (B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim))
        self.assertEqual(
            k_nope.shape,
            (B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank))

    @patch("torch_npu.npu_interleave_rope")
    def test_rope_single(self, mock_rope):
        B, N, D = 2, 16, 1024
        x = torch.randn(B, N, D)
        cos = torch.randn(B, N, 1, D)
        sin = torch.randn(B, N, 1, D)
        mock_rope.return_value = x.view(B, N, 1, D)
        result = self.impl.rope_single(x, cos, sin)
        self.assertEqual(result.shape[0], B)
        self.assertEqual(result.shape[1], N)
        self.assertEqual(result.shape[2], D)
        mock_rope.assert_called_once()

    @patch(
        "vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._v_up_proj_and_o_proj"
    )
    @patch("torch_npu._npu_paged_attention_mla")
    def test_forward_decode_without_graph(self, mock_page_attention_mla,
                                          mock_up_proj):
        self.impl.running_in_graph = False
        self.impl.running_chunkprefilll_with_torchair = False
        num_tokens = 100
        num_blocks = 256
        block_size = 4
        q_nope = torch.randn(num_tokens, self.impl.num_heads,
                             self.impl.qk_nope_head_dim)
        q_pe = torch.randn(num_tokens, self.impl.num_heads,
                           self.impl.qk_rope_head_dim)
        kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
                                          self.impl.num_heads,
                                          self.impl.kv_lora_rank)
        metadata = MagicMock()
        metadata.decode = MagicMock()
        metadata.decode.block_table = MagicMock()
        metadata.decode.seq_lens = 10
        mock_page_attention_mla.return_value = torch.randn(
            num_tokens, self.impl.num_heads, self.impl.kv_lora_rank)
        mock_up_proj.return_value = torch.randn(num_tokens,
                                                self.impl.num_heads,
                                                self.impl.v_head_dim)
        result = self.impl._forward_decode(q_nope, q_pe, None, None,
                                           kv_c_and_k_pe_cache, metadata)
        self.assertEqual(result.shape[0], num_tokens)
        self.assertEqual(result.shape[1], self.impl.num_heads)
        self.assertEqual(result.shape[2], self.impl.v_head_dim)
        mock_up_proj.assert_called_once()
        mock_page_attention_mla.assert_called_once()

    @patch(
        "vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._forward_prefill"
    )
    @patch("torch_npu._npu_reshape_and_cache")
    def test_forward_without_graph(self, _, mock_forward_prefill):
        self.impl.running_in_graph = False
        self.impl.torchair_graph_enabled = False

        num_tokens = 100
        num_blocks = 256
        block_size = 4
        rotary_emb_return_value = (torch.randn(num_tokens, 16,
                                               self.impl.kv_lora_rank),
                                   torch.randn(0, 1, self.impl.kv_lora_rank))
        self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
        self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
            1, num_blocks, 128)

        hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
        hidden_states_or_kv_c_normed = torch.randn(num_tokens,
                                                   self.impl.kv_lora_rank)
        k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
        kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
                                self.impl.kv_lora_rank),
                    torch.randn(num_blocks, block_size, self.impl.num_heads,
                                self.impl.qk_rope_head_dim))
        output = torch.randn(num_tokens, self.impl.num_heads,
                             self.impl.v_head_dim)

        metadata = MagicMock()
        metadata.num_decodes = 0
        metadata.num_prefills = num_tokens
        mock_forward_prefill.return_value = torch.randn(
            0, self.impl.num_heads * self.impl.v_head_dim)
        result = self.impl.forward(None, hidden_states_or_q_c,
                                   hidden_states_or_kv_c_normed, k_pe,
                                   kv_cache, metadata, output, False)
        self.assertEqual(result.shape[0], num_tokens)
@@ -1,45 +0,0 @@
from unittest.mock import MagicMock, Mock

import pytest
import torch
from pytest_mock import MockerFixture
from vllm.config import VllmConfig

from tests.ut.base import PytestBase
from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner


class TestNPUTorchairModelRunner(PytestBase):

    @pytest.fixture
    def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
        mocker.patch.object(NPUTorchairModelRunner, "__init__",
                            lambda self, *args, **kwargs: None)
        runner = NPUTorchairModelRunner(Mock(), Mock())

        runner.device = torch.device("cpu")
        runner.vllm_config = MagicMock(spec=VllmConfig)

        runner.speculative_config = MagicMock(
            method="mtp",
            num_speculative_tokens=4,
            disable_padded_drafter_batch=False)

        runner.ascend_config = MagicMock(enable_shared_expert_dp=False,
                                         torchair_graph_config=MagicMock(
                                             use_cached_graph=True,
                                             graph_batch_sizes=[1, 2, 4]))

        runner.decode_token_per_req = 2
        runner.is_kv_consumer = True
        runner.max_num_reqs = 100

        runner.model_config = MagicMock(hf_config=MagicMock(index_topk=2))
        runner.attn_backend = MagicMock(get_builder_cls=lambda: Mock())

        return runner

    def test_init(self, mocker: MockerFixture,
                  setup_npu_torchair_model_runner):
        runner = setup_npu_torchair_model_runner
        assert isinstance(runner, NPUTorchairModelRunner)
@@ -1,78 +0,0 @@
from unittest.mock import MagicMock, Mock

import pytest
import torch
from pytest_mock import MockerFixture
from vllm.config import CacheConfig, VllmConfig

from tests.ut.base import PytestBase
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer


class TestTorchairMtpProposer(PytestBase):

    @pytest.fixture
    def setup_torchair_mtp_proposer(self, mocker: MockerFixture):
        vllm_config = MagicMock(spec=VllmConfig)
        vllm_config.device_config = MagicMock()
        vllm_config.device_config.device = torch.device("cpu")
        vllm_config.speculative_config = MagicMock()
        vllm_config.speculative_config.draft_model_config = MagicMock()
        vllm_config.speculative_config.draft_model_config.dtype = torch.float16
        vllm_config.speculative_config.method = "mtp"
        vllm_config.speculative_config.num_speculative_tokens = 5
        vllm_config.load_config = MagicMock()
        cache_config = CacheConfig(block_size=16)
        vllm_config.cache_config = cache_config
        vllm_config.scheduler_config = MagicMock(max_num_batched_tokens=1024,
                                                 max_num_seqs=64)

        device = torch.device("cpu")
        runner = MagicMock()
        runner.pcp_size = 1
        runner.dcp_size = 1
        runner.pcp_rank = 0
        runner.max_num_tokens = 1024
        runner.max_num_reqs = 10
        runner._use_aclgraph.return_value = True

        mocker.patch(
            "vllm_ascend.torchair.torchair_mtp_proposer.MtpProposer.__init__",
            return_value=None)
        mock_set_default_dtype = mocker.patch(
            'vllm.utils.torch_utils.set_default_torch_dtype')
        mock_set_default_dtype.return_value.__enter__.return_value = None

        mock_model_loader = MagicMock()
        mocker.patch("vllm.model_executor.model_loader.get_model_loader",
                     return_value=mock_model_loader)
        mock_layers = {
            "target_attn_layer_1": Mock(),
            "draft_attn_layer_2": Mock()
        }
        mocker.patch("vllm.config.get_layers_from_vllm_config",
                     return_value=mock_layers)
        mock_set_current = mocker.patch("vllm.config.set_current_vllm_config")
        mock_set_current.return_value.__enter__.return_value = None
        mock_torchair_deepseek_mtp = MagicMock()
        mock_torchair_deepseek_mtp.to.return_value = mock_torchair_deepseek_mtp
        mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP",
            return_value=mock_torchair_deepseek_mtp)
        mocker.patch(
            "vllm.model_executor.model_loader.utils.process_weights_after_loading"
        )

        proposer = TorchairMtpProposer(vllm_config, device, runner)
        proposer.vllm_config = vllm_config
        proposer.device = device
        proposer.runner = runner
        proposer.speculative_config = vllm_config.speculative_config
        proposer.draft_model_config = vllm_config.speculative_config.draft_model_config
        proposer.method = vllm_config.speculative_config.method

        return proposer, mock_model_loader, mock_torchair_deepseek_mtp

    def test_init(self, setup_torchair_mtp_proposer):
        proposer, _, _ = setup_torchair_mtp_proposer
        assert isinstance(proposer, TorchairMtpProposer)
@@ -1,340 +0,0 @@
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.torchair.torchair_sfa import (
    AscendSFATorchairBackend, AscendSFATorchairDecodeMetadata,
    AscendSFATorchairImpl, AscendSFATorchairMetadata,
    AscendSFATorchairMetadataBuilder, AscendSFATorchairPrefillMetadata)


class TestAscendSFATorchairBackend(TestBase):

    def test_get_name(self):
        self.assertEqual(AscendSFATorchairBackend.get_name(),
                         "ASCEND_SFA_TORCHAIR")

    def test_get_builder_cls(self):
        self.assertEqual(AscendSFATorchairBackend.get_builder_cls(),
                         AscendSFATorchairMetadataBuilder)

    def test_get_kv_cache_shape(self):
        result = AscendSFATorchairBackend.get_kv_cache_shape(2, 4, 8, 128)
        self.assertEqual(result, (2, 4, 8, 128))

    def test_get_impl_cls(self):
        result = AscendSFATorchairBackend.get_impl_cls()
        self.assertEqual(result, AscendSFATorchairImpl)


class TestAscendSFATorchairPrefillMetadata(TestBase):

    def test_ascend_sfa_prefill_metadata_default(self):
        attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
        query_lens = [1, 2]
        seq_lens = [2, 2]
        context_lens = torch.tensor([1, 2])
        input_positions = torch.tensor([0, 1, 0, 1])
        query_start_loc = torch.tensor([0, 1, 3])
        block_table = torch.tensor([[0, 1], [2, 3]])
        max_query_len = 2
        max_seq_lens = 2

        metadata = AscendSFATorchairPrefillMetadata(
            attn_mask=attn_mask,
            query_lens=query_lens,
            seq_lens=seq_lens,
            context_lens=context_lens,
            input_positions=input_positions,
            query_start_loc=query_start_loc,
            block_table=block_table,
            max_query_len=max_query_len,
            sin=None,
            cos=None,
            max_seq_lens=max_seq_lens)
        self.assertIs(metadata.attn_mask, attn_mask)
        self.assertEqual(metadata.query_lens, query_lens)
        self.assertEqual(metadata.seq_lens, seq_lens)
        self.assertIs(metadata.context_lens, context_lens)
        self.assertIs(metadata.input_positions, input_positions)
        self.assertIs(metadata.query_start_loc, query_start_loc)
        self.assertIs(metadata.block_table, block_table)
        self.assertEqual(metadata.max_query_len, max_query_len)
        self.assertEqual(metadata.max_seq_lens, max_seq_lens)
        self.assertIsNone(metadata.chunked_context)

    def test_ascend_sfa_prefill_metadata_with_chunked_context(self):
        cu_seq_lens = torch.tensor([0, 2, 4])
        starts = torch.tensor([0, 2])
        seq_tot = [2, 2]
        max_seq_lens = [2, 2]
        workspace = torch.randn(2, 4)
        chunk_seq_lens = torch.tensor([2, 2])

        chunked_context = AscendSFATorchairPrefillMetadata.TorchairChunkedContextMetadata(
            cu_seq_lens=cu_seq_lens,
            starts=starts,
            seq_tot=seq_tot,
            max_seq_lens=max_seq_lens,
            workspace=workspace,
            chunk_seq_lens=chunk_seq_lens)

        metadata = AscendSFATorchairPrefillMetadata(
            attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
            query_lens=[1, 2],
            seq_lens=[2, 2],
            context_lens=torch.tensor([1, 2]),
            input_positions=torch.tensor([0, 1, 0, 1]),
            query_start_loc=torch.tensor([0, 1, 3]),
            block_table=torch.tensor([[0, 1], [2, 3]]),
            max_query_len=2,
            max_seq_lens=2,
            sin=None,
            cos=None,
            chunked_context=chunked_context)

        self.assertIsNotNone(metadata.chunked_context)
        self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
        self.assertIs(metadata.chunked_context.starts, starts)
        self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
        self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
        self.assertIs(metadata.chunked_context.workspace, workspace)
        self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)


class TestAscendSFATorchairDecodeMetadata(TestBase):

    def test_ascend_sfa_decode_metadata_default(self):
        input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
        block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
        seq_lens = torch.tensor([[2], [3]])
        max_seq_lens = 4
        seq_lens_list = [2, 3]
        attn_mask = None

        metadata = AscendSFATorchairDecodeMetadata(input_positions,
                                                   block_table, seq_lens,
                                                   max_seq_lens, seq_lens_list,
                                                   None, None, attn_mask)

        self.assertIs(metadata.input_positions, input_positions)
        self.assertIs(metadata.block_table, block_table)
        self.assertIs(metadata.seq_lens, seq_lens)
        self.assertEqual(metadata.max_seq_lens, max_seq_lens)
        self.assertEqual(metadata.seq_lens_list, seq_lens_list)
        self.assertIsNone(attn_mask)


class TestAscendSFATorchairMetadata(TestBase):

    def test_ascend_sfa_metadata_default(self):
        num_actual_tokens = 100
        slot_mapping = torch.randn(100, 4, 1024)
        query_start_loc = torch.tensor([1, 2, 3, 4])
        seq_lens = [30, 50]
        block_tables = torch.randint(0, 100, (100, 4))

        num_decodes = 4
        num_decode_tokens = 8
        num_prefills = 8

        num_input_tokens = 2

        query_lens = None
        head_dim = None
        attn_mask = None
        attn_state = AscendAttentionState.ChunkedPrefill

        decode = None
        prefill = None

        metadata = AscendSFATorchairMetadata(
            num_actual_tokens, slot_mapping, query_start_loc, seq_lens,
            block_tables, num_decodes, num_decode_tokens, num_prefills,
            num_input_tokens, query_lens, head_dim, attn_mask, attn_state,
            decode, prefill)

        self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
        self.assertIs(metadata.slot_mapping, slot_mapping)
        self.assertIs(metadata.query_start_loc, query_start_loc)
        self.assertEqual(metadata.seq_lens, seq_lens)
        self.assertIs(metadata.block_tables, block_tables)
        self.assertEqual(metadata.num_decodes, num_decodes)
        self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
        self.assertEqual(metadata.num_prefills, num_prefills)
        self.assertEqual(metadata.num_input_tokens, num_input_tokens)
        self.assertEqual(metadata.query_lens, query_lens)
        self.assertEqual(metadata.head_dim, head_dim)
        self.assertEqual(metadata.attn_mask, attn_mask)
        self.assertEqual(metadata.attn_state, attn_state)
        self.assertEqual(metadata.decode, decode)
        self.assertEqual(metadata.prefill, prefill)


class TestAscendSFATorchairMetadataBuilder(TestBase):

    def test_ascend_sfa_metadata_builder_default(self):
        mock_model_config = MagicMock()
        mock_model_config.max_model_len = 1024
        mock_model_config.get_head_size.return_value = 64
        mock_model_config.dtype = torch.float16

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config = mock_model_config
        mock_vllm_config.cache_config = MagicMock(block_size=16)
        mock_vllm_config.scheduler_config = MagicMock(
            max_num_seqs=4, enable_chunked_prefill=False)
        mock_vllm_config.speculative_config = None

        mock_device = torch.device('cpu')
        ascend_config = MagicMock()
        ascend_config.torchair_graph_config = MagicMock()
        ascend_config.torchair_graph_config.enabled = True
        with patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config",
                   return_value=ascend_config):
            builder = AscendSFATorchairMetadataBuilder(None, None,
                                                       mock_vllm_config,
                                                       mock_device)

        self.assertEqual(builder.block_size,
                         mock_vllm_config.cache_config.block_size)
        self.assertEqual(
            builder.chunked_prefill_enabled,
            mock_vllm_config.scheduler_config.enable_chunked_prefill)
        self.assertEqual(builder.torchair_graph_enabled, True)
        self.assertEqual(builder.max_blocks,
                         (mock_vllm_config.model_config.max_model_len +
                          mock_vllm_config.cache_config.block_size - 1) //
                         mock_vllm_config.cache_config.block_size)

    @patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
    def test_reorder_batch_with_torchair_graph(self, ascend_config):
        mock_model_config = MagicMock()
        mock_model_config.max_model_len = 1024
        mock_model_config.get_head_size.return_value = 64
        mock_model_config.dtype = torch.float16

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config = mock_model_config
        mock_vllm_config.cache_config = MagicMock(block_size=16)
        mock_vllm_config.scheduler_config = MagicMock(
            max_num_seqs=4, enable_chunked_prefill=False)
        mock_vllm_config.speculative_config = None

        mock_device = torch.device('cpu')
        ascend_config.torchair_graph_config = MagicMock()
        ascend_config.torchair_graph_config.enabled = True

        builder = AscendSFATorchairMetadataBuilder(None, None,
                                                   mock_vllm_config,
                                                   mock_device)

        input_batch = MagicMock()
        input_batch.req_ids = [0, 1, 2, 3]

        scheduler_output = MagicMock()
        scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
        scheduler_output.scheduled_spec_decode_tokens = {
            0: [1],
            1: [],
            2: [1, 1],
            3: []
        }

        input_batch.swap_states = MagicMock()

        modified = builder.reorder_batch(input_batch, scheduler_output)

        self.assertFalse(modified)
        input_batch.swap_states.assert_not_called()

    @patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
    def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False
        mock_model_config = MagicMock()
        mock_model_config.max_model_len = 1024
        mock_model_config.get_head_size.return_value = 64
        mock_model_config.dtype = torch.float16

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config = mock_model_config
        mock_vllm_config.cache_config = MagicMock(block_size=16)
        mock_vllm_config.scheduler_config = MagicMock(
            max_num_seqs=4, enable_chunked_prefill=False)
        mock_vllm_config.speculative_config = None
        mock_device = torch.device('cpu')

        builder = AscendSFATorchairMetadataBuilder(None, None,
                                                   mock_vllm_config,
                                                   mock_device)
        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

        result = builder._get_graph_runner_block_tables(3, block_tables)
        self.assertEqual(result.shape[0], 3)
        self.assertEqual(result.shape[1], 64)
        self.assertTrue(torch.equal(result[:, :10], block_tables))

    @patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
    def test_ge_graph_runner_block_tables_truncated(self, mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False
        mock_model_config = MagicMock()
        mock_model_config.max_model_len = 1024
        mock_model_config.get_head_size.return_value = 64
        mock_model_config.dtype = torch.float16

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config = mock_model_config
        mock_vllm_config.cache_config = MagicMock(block_size=16)
        mock_vllm_config.scheduler_config = MagicMock(
            max_num_seqs=4, enable_chunked_prefill=False)
        mock_vllm_config.speculative_config = None

        mock_device = torch.device('cpu')

        builder = AscendSFATorchairMetadataBuilder(None, None,
                                                   mock_vllm_config,
                                                   mock_device)

        builder.max_blocks = 4
        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

        result = builder._get_graph_runner_block_tables(3, block_tables)
        self.assertEqual(result.shape[0], 3)
        self.assertEqual(result.shape[1], 4)
        self.assertTrue(torch.equal(result, block_tables[:, :4]))

    @patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
    def test_get_graph_runner_block_tables_from_numpy(self,
                                                      mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False
        mock_model_config = MagicMock()
        mock_model_config.max_model_len = 1024
        mock_model_config.get_head_size.return_value = 64
        mock_model_config.dtype = torch.float16

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config = mock_model_config
        mock_vllm_config.cache_config = MagicMock(block_size=16)
        mock_vllm_config.scheduler_config = MagicMock(
            max_num_seqs=4, enable_chunked_prefill=False)
        mock_vllm_config.speculative_config = None

        mock_device = torch.device('cpu')
        builder = AscendSFATorchairMetadataBuilder(None, None,
                                                   mock_vllm_config,
                                                   mock_device)

        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

        result = builder._get_graph_runner_block_tables(3, block_tables)

        self.assertEqual(result.shape[0], 3)
        self.assertEqual(result.shape[1], 64)
        self.assertTrue(torch.equal(result[:, :10], block_tables))
@@ -1,111 +0,0 @@
from unittest.mock import MagicMock, patch

import torch
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig

from tests.ut.base import TestBase

init_cache_hf_modules_path = "vllm.utils.import_utils.init_cached_hf_modules"


class TestNPUTorchairWorker(TestBase):

    def setUp(self):
        self.cache_config_mock = MagicMock(spec=CacheConfig)
        self.cache_config_mock.cache_type = "auto"

        self.model_config_mock = MagicMock(spec=ModelConfig)
        self.model_config_mock.dtype = torch.float16
        self.model_config_mock.trust_remote_code = False

        self.hf_config_mock = MagicMock()
        self.hf_config_mock.model_type = "test_model"
        if hasattr(self.hf_config_mock, 'index_topk'):
            delattr(self.hf_config_mock, 'index_topk')

        self.model_config_mock.hf_config = self.hf_config_mock

        self.parallel_config_mock = MagicMock(spec=ParallelConfig)

        self.vllm_config_mock = MagicMock(spec=VllmConfig)
        self.vllm_config_mock.cache_config = self.cache_config_mock
        self.vllm_config_mock.model_config = self.model_config_mock
        self.vllm_config_mock.parallel_config = self.parallel_config_mock
        self.vllm_config_mock.additional_config = None
        self.vllm_config_mock.load_config = None
        self.vllm_config_mock.scheduler_config = None
        self.vllm_config_mock.device_config = None
        self.vllm_config_mock.compilation_config = None

        self.local_rank = 0
        self.rank = 0
        self.distributed_init_method = "tcp://localhost:12345"
        self.is_driver_worker = False

    @patch(
        "vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment"
    )
    @patch("vllm_ascend.worker.worker_v1.NPUPlatform")
    def test_init_device(self, mock_platform, mock_init_dist_env):
        from vllm_ascend.worker.worker_v1 import NPUWorker

        mock_platform.mem_get_info.return_value = (1000, 2000)

        with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
            worker = NPUWorker()
            worker.local_rank = 1
            worker.model_config = MagicMock()
            worker.model_config.seed = 42
            worker.vllm_config = MagicMock()
            worker.parallel_config = MagicMock()
            worker.parallel_config.local_world_size = 0
            worker.parallel_config.data_parallel_size = 1

            result = worker._init_device()

            mock_platform.set_device.assert_called_once()
            call_args = mock_platform.set_device.call_args[0][0]
            self.assertEqual(str(call_args), "npu:1")

            mock_platform.empty_cache.assert_called_once()
            mock_platform.seed_everything.assert_called_once_with(42)
            mock_platform.mem_get_info.assert_called_once()
            mock_init_dist_env.assert_called_once()

            self.assertEqual(str(result), "npu:1")
            self.assertEqual(worker.init_npu_memory, 1000)

    @patch(
        "vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment"
    )
    @patch("vllm_ascend.worker.worker_v1.NPUPlatform")
    def test_init_device_torchair_worker(self, mock_platform,
                                         mock_init_dist_env):
        from vllm_ascend.torchair.torchair_worker import NPUTorchairWorker

        mock_platform.mem_get_info.return_value = (1000, 2000)

        with patch.object(NPUTorchairWorker, "__init__",
                          lambda x, **kwargs: None):
            worker = NPUTorchairWorker()
            worker.local_rank = 1
            worker.model_config = MagicMock()
            worker.model_config.seed = 42
            worker.vllm_config = MagicMock()
            worker.parallel_config = MagicMock()
            worker.parallel_config.local_world_size = 0
            worker.parallel_config.data_parallel_size = 1

            result = worker._init_device()

            mock_platform.set_device.assert_called_once()
            call_args = mock_platform.set_device.call_args[0][0]
            self.assertEqual(str(call_args), "npu:1")

            mock_platform.empty_cache.assert_called_once()
            mock_platform.seed_everything.assert_called_once_with(42)
            mock_platform.mem_get_info.assert_called_once()
            mock_init_dist_env.assert_called_once()

            self.assertEqual(str(result), "npu:1")
            self.assertEqual(worker.init_npu_memory, 1000)
@@ -1,164 +0,0 @@
import os
from concurrent.futures import ThreadPoolExecutor
from unittest import mock
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.torchair import utils


class TestTorchairUtils(TestBase):

    def test_get_torchair_current_work_dir(self):
        cache_dir = utils.TORCHAIR_CACHE_DIR
        work_dir = utils._get_torchair_current_work_dir()
        self.assertEqual(cache_dir, work_dir)
        work_dir = utils._get_torchair_current_work_dir("test")
        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)

    def test_torchair_cache_dir(self):
        utils.write_kv_cache_bytes_to_file(0, 100)
        self.assertTrue(utils.check_torchair_cache_exist(),
                        "Create torchair cache dir failed")
        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
                        "Create kv cache bytes cache dir failed")
        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
        self.assertEqual(100, kv_cache_bytes)
        utils.delete_torchair_cache_file()
        self.assertFalse(utils.check_torchair_cache_exist(),
                         "Delete torchair cache dir failed")
        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                         "Delete kv cache bytes cache dir failed")

    def test_torchair_cache_dir_multiple_ranks(self):
        ranks = [0, 1, 2, 3]
        values = [100, 200, 300, 400]

        with ThreadPoolExecutor() as executor:
            executor.map(utils.write_kv_cache_bytes_to_file, ranks, values)
        for rank, expected in zip(ranks, values):
            self.assertEqual(expected,
                             utils.read_kv_cache_bytes_from_file(rank))
        utils.delete_torchair_cache_file()

        self.assertFalse(utils.check_torchair_cache_exist(),
                         "Delete torchair cache dir failed")
        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                         "Delete kv cache bytes cache dir failed")

    def test_delete_torchair_cache_file_multiple_times(self):
        utils.write_kv_cache_bytes_to_file(0, 100)
        utils.delete_torchair_cache_file()
        for i in range(5):
            try:
                utils.delete_torchair_cache_file()
            except FileNotFoundError:
                self.fail(
                    f"Unexpected FileNotFoundError on delete call #{i+2}")

    @patch('vllm.ModelRegistry')
    def test_register_torchair_model(self, mock_model_registry):
        mock_registry = MagicMock()
        mock_model_registry.return_value = mock_registry
        utils.register_torchair_model()

        self.assertEqual(mock_model_registry.register_model.call_count, 7)
        call_args_list = mock_model_registry.register_model.call_args_list

        expected_registrations = [
            ("DeepSeekMTPModel",
             "vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
             ),
            ("DeepseekV2ForCausalLM",
             "vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
             ),
            ("DeepseekV3ForCausalLM",
             "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
             ),
            ("DeepseekV32ForCausalLM",
             "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
             ),
            ("Qwen2ForCausalLM",
             "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM"),
            ("Qwen3MoeForCausalLM",
             "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM"
             ),
            ("PanguProMoEForCausalLM",
             "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
             )
        ]

        for i, (expected_name,
                expected_path) in enumerate(expected_registrations):
            args, kwargs = call_args_list[i]
            self.assertEqual(args[0], expected_name)
            self.assertEqual(args[1], expected_path)

    @mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
    @mock.patch('torch_npu.get_npu_format')
    @mock.patch('torch_npu.npu_format_cast')
    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                new=mock.MagicMock)
    def test_converting_weight_acl_format_to_nz(self, mock_npu_cast,
                                                mock_get_format, mock_is_nz):
        ACL_FORMAT_FRACTAL_NZ = 29
        mock_get_format.return_value = 1
        mock_npu_cast.return_value = 1
        mock_is_nz.return_value = 1

        fused_moe = mock.MagicMock()
        fused_moe.w13_weight = mock.MagicMock()
        fused_moe.w2_weight = mock.MagicMock()
        fused_moe.w13_weight.data = torch.randn(128, 256)
        fused_moe.w2_weight.data = torch.randn(256, 128)
        model = mock.MagicMock()
        model.modules.return_value = [fused_moe]

        utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
        self.assertEqual(fused_moe.w13_weight.data, 1)

    @mock.patch('torch_npu.get_npu_format')
    @mock.patch('torch_npu.npu_format_cast')
    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                new=mock.MagicMock)
    def test_converting_weight_acl_format_format_true(self, mock_npu_cast,
                                                      mock_get_format):
        ACL_FORMAT_FRACTAL_NZ = 29
        mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
        mock_npu_cast.return_value = 1

        fused_moe = mock.MagicMock()
        fused_moe.w13_weight = mock.MagicMock()
        fused_moe.w2_weight = mock.MagicMock()
        fused_moe.w13_weight.data = torch.randn(128, 256)
        fused_moe.w2_weight.data = torch.randn(256, 128)
        model = mock.MagicMock()
        model.modules.return_value = [fused_moe]

        utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
        mock_npu_cast.assert_not_called()

    @mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
    @mock.patch('torch_npu.get_npu_format')
    @mock.patch('torch_npu.npu_format_cast')
    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                new=mock.MagicMock)
    def test_converting_weight_acl_format_no_nz(self, mock_npu_cast,
                                                mock_get_format, mock_is_nz):
        ACL_FORMAT_FRACTAL_NZ = 29
        mock_get_format.return_value = 1
        mock_npu_cast.return_value = 1
        mock_is_nz.return_value = 0

        fused_moe = mock.MagicMock()
        fused_moe.w13_weight = mock.MagicMock()
        fused_moe.w2_weight = mock.MagicMock()
        fused_moe.w13_weight.data = torch.randn(128, 256)
        fused_moe.w2_weight.data = torch.randn(256, 128)
        model = mock.MagicMock()
        model.modules.return_value = [fused_moe]

        utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
        mock_npu_cast.assert_not_called()
@@ -18,15 +6,6 @@ from uuid import uuid4

from vllm.logger import logger

TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]


def _check_torchair_supported(model_type: str):
    for supported_model in TORCHAIR_MODEL_LIST:
        if supported_model in model_type.lower():
            return True
    return False


def check_kv_extra_config(vllm_config):

@@ -66,11 +57,6 @@ class AscendConfig:

    def __init__(self, vllm_config):
        additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
        torchair_graph_config = additional_config.get("torchair_graph_config",
                                                      {})

        self.torchair_graph_config = TorchairGraphConfig(
            torchair_graph_config, vllm_config, additional_config)

        xlite_graph_config = additional_config.get("xlite_graph_config", {})
        self.xlite_graph_config = XliteGraphConfig(xlite_graph_config,
@@ -107,8 +93,8 @@ class AscendConfig:
        self.chunked_prefill_for_mla = additional_config.get(
            "chunked_prefill_for_mla", False)
        self.enable_shared_expert_dp = additional_config.get(
            "enable_shared_expert_dp", False
        ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
            "enable_shared_expert_dp",
            False) and vllm_config.parallel_config.enable_expert_parallel
        if self.enable_shared_expert_dp:
            from vllm_ascend.utils import enable_sp
            assert enable_sp(vllm_config=vllm_config,
@@ -215,86 +201,6 @@ class AscendCompilationConfig:
    # Add more compilation related configs here as needed


class TorchairGraphConfig:
    """
    Configuration Object for torchair_graph_config from additional_config
    """

    def __init__(self, torchair_graph_config, vllm_config, additional_config):
        self.enabled = torchair_graph_config.get("enabled", False)
        self.mode = torchair_graph_config.get("mode", '')
        self.use_cached_graph = torchair_graph_config.get(
            "use_cached_graph", False)
        self.use_cached_kv_cache_bytes = torchair_graph_config.get(
            "use_cached_kv_cache_bytes", False)
        self.graph_batch_sizes = torchair_graph_config.get(
            "graph_batch_sizes", [])
        self.graph_batch_sizes_init = torchair_graph_config.get(
            "graph_batch_sizes_init", False)
        self.enable_multistream_mla = torchair_graph_config.get(
            "enable_multistream_mla", False)
        self.enable_view_optimize = torchair_graph_config.get(
            "enable_view_optimize", True)
        self.enable_frozen_parameter = torchair_graph_config.get(
            "enable_frozen_parameter", True)
        self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
        self.enable_super_kernel = torchair_graph_config.get(
            "enable_super_kernel", False)

        if not isinstance(self.graph_batch_sizes, list):
            raise TypeError("graph_batch_sizes must be list[int]")
        if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0:
            raise ValueError(
                "graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
            )
        if not self.enabled:
            if self.mode:
                raise RuntimeError(
                    "mode is valid only when Torchair graph mode is enabled")
            if self.use_cached_graph:
                raise RuntimeError(
                    "use_cached_graph is valid only when Torchair graph mode is enabled"
                )
            if self.use_cached_kv_cache_bytes:
                raise RuntimeError(
                    "use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled"
                )
            if self.graph_batch_sizes:
                raise RuntimeError(
                    "graph_batch_sizes is valid only when Torchair graph mode is enabled"
                )
            if self.graph_batch_sizes_init:
                raise RuntimeError(
                    "graph_batch_sizes_init is valid only when Torchair graph mode is enabled"
                )
            if self.enable_multistream_mla:
                raise RuntimeError(
                    "enable_multistream_mla is valid only when Torchair graph mode is enabled"
                )
            if self.enable_kv_nz:
                raise RuntimeError(
                    "enable_kv_nz is valid only when Torchair graph mode is enabled"
                )
            if self.enable_super_kernel:
                raise RuntimeError(
                    "enable_super_kernel is valid only when Torchair graph mode is enabled"
                )
        if self.enable_super_kernel:
            if vllm_config.parallel_config.tensor_parallel_size != 1:
                raise RuntimeError(
                    "enable_super_kernel is valid only when tensor_parallel_size is 1"
                )
            if not additional_config.get("multistream_overlap_shared_expert",
                                         False):
                raise RuntimeError(
                    "enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled"
                )
        if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
            raise RuntimeError(
                "use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"
            )


class XliteGraphConfig:
    """
    Configuration Object for xlite_graph_config from additional_config
@@ -382,39 +288,7 @@ def get_ascend_config():
def check_ascend_config(vllm_config, enforce_eager):
    ascend_config = get_ascend_config()

    # for eager mode
    if enforce_eager:
        # torchair_graph cannot be enabled with eager mode.
        if ascend_config.torchair_graph_config.enabled:
            raise RuntimeError(
                "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
            )
    # for graph mode
    else:
        # torchair_graph case
        if ascend_config.torchair_graph_config.enabled:
            # torchair_graph is supported for deepseek/pangu/qwen model only.
            if vllm_config.model_config:
                model_type = vllm_config.model_config.hf_config.model_type
                if not _check_torchair_supported(model_type):
                    raise NotImplementedError(
                        "Torchair graph mode only works with following model types:"
                        f"{TORCHAIR_MODEL_LIST}.")
            if ascend_config.enable_shared_expert_dp:
                logger.warning(
                    "enable_shared_expert_dp is not supported for torchair graph mode currently, "
                    "it has been disabled automatically.")
        # aclgraph case
        else:
            if ascend_config.ascend_compilation_config.enable_quantization_fusion:
                logger.info(
                    "Quantization fusion enabled! op fusion on quantization are expected. "
                )

            if vllm_config.model_config:
                model_type = vllm_config.model_config.hf_config.model_type
                if "qwen" not in model_type:
                    logger.warning(
                        "ACL Graph is currently experimental. Please "
                        "raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
                        " if you encounter any Error")
    if ascend_config.ascend_compilation_config.enable_quantization_fusion:
        logger.info(
            "Quantization fusion enabled! op fusion on quantization are expected. "
        )
@@ -857,7 +857,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.ring_mla_mask_size = 512
|
||||
@@ -1248,7 +1247,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv_no_split = kv_no_split.view(
|
||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||
cache_mode = "PA"
|
||||
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
@@ -1276,7 +1275,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv_no_split = kv_no_split.view(
|
||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||
cache_mode = "PA"
|
||||
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
@@ -1318,18 +1317,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# shape of knope/k_pe for npu graph mode should be:
|
||||
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
||||
actual_seq_lengths = None
|
||||
if self.enable_kv_nz:
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads,
|
||||
self.kv_lora_rank // 16, block_size, 16)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads,
|
||||
self.qk_rope_head_dim // 16, block_size, 16)
|
||||
input_layout = "BSND"
|
||||
else:
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
||||
self.kv_lora_rank)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||
self.qk_rope_head_dim)
|
||||
input_layout = "BNSD"
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
||||
self.kv_lora_rank)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||
self.qk_rope_head_dim)
|
||||
input_layout = "BNSD"
|
||||
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.SpecDecoding,
|
||||
@@ -1346,14 +1338,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
||||
else:
|
||||
if self.enable_kv_nz:
|
||||
q_nope = q_nope.view(num_tokens, 1, self.num_heads,
|
||||
-1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
|
||||
else:
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
|
||||
-1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
|
||||
-1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
sparse_mode = 0
|
||||
spec_attn_mask = None
|
||||
|
||||
|
||||
@@ -345,7 +345,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||
|
||||
assert self.indexer is not None, "Indexer is required for DSA."
|
||||
@@ -534,7 +533,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv_no_split = kv_no_split.view(
|
||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||
cache_mode = "PA"
|
||||
|
||||
if self.enable_sfa_cp:
|
||||
assert slots_cp is not None
|
||||
|
||||
@@ -453,7 +453,6 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
def _cat_kv_cache(self, block_ids: list[list[int]]):
|
||||
# Get necessary parameters
|
||||
k_cache = list(self.kv_caches.values())[0][0]
|
||||
kv_shape = k_cache.shape
|
||||
dtype = k_cache.dtype
|
||||
device = k_cache.device
|
||||
head_dim = self.model_config.hf_config.head_dim
|
||||
@@ -494,13 +493,6 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
|
||||
# Process each layer in the KV cache
|
||||
for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
|
||||
if len(
|
||||
k_cache_layer.shape
|
||||
) == 3: # kv shape in torchair model is [num_block, block_size, num_kv_head*head_dim]
|
||||
k_cache_layer = k_cache_layer.view(kv_shape[0], kv_shape[1],
|
||||
num_kv_head, head_dim)
|
||||
v_cache_layer = v_cache_layer.view(kv_shape[0], kv_shape[1],
|
||||
num_kv_head, head_dim)
|
||||
# Load cache data into buffers
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
k_cache_layer,
|
||||
|
||||
@@ -99,8 +99,6 @@ class MoECommMethod(ABC):
|
||||
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
# For TorchAir graph
|
||||
is_torchair: bool = False,
|
||||
# For Cube/Vector parallel
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
@@ -283,8 +281,6 @@ class FusedAlltoAllCommImpl(MoECommMethod):
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
# For TorchAir graph
|
||||
is_torchair: bool = False,
|
||||
# For Cube/Vector parallel
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
|
||||
@@ -26,10 +26,7 @@ from vllm.platforms import Platform, PlatformEnum
|
||||
# todo: please remove it when solve cuda hard code in vllm
|
||||
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
|
||||
|
||||
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
||||
init_ascend_config)
|
||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||
delete_torchair_cache_file)
|
||||
from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
|
||||
from vllm_ascend.utils import refresh_block_size
|
||||
|
||||
# isort: off
|
||||
@@ -204,25 +201,6 @@ class NPUPlatform(Platform):
|
||||
compilation_config.mode)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
logger.info(
|
||||
"Torchair compilation enabled on NPU. Setting CUDAGraphMode to NONE"
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
# Note: We delete the torchair cache folder here to prevent runtime issues caused by dimension
|
||||
# mismatches or configuration inconsistencies when users reuse cached computation graphs. Though
|
||||
# this will increase graph compilation duration, it significantly enhances robustness and decreases
|
||||
# graph launching time during inference.
|
||||
if check_torchair_cache_exist(
|
||||
) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
|
||||
logger.warning(
|
||||
"Torchair cache folder is deleted here to prevent runtime issues caused by dimension "
|
||||
"mismatches or configuration inconsistencies when users reuse cached computation graphs. "
|
||||
"In order to decrease torchair graph compilation time, users can enable both use_cached_graph "
|
||||
"and use_cached_kv_cache_bytes in torchair_graph_config.")
|
||||
delete_torchair_cache_file()
|
||||
|
||||
# set cudaprah sizes before extending `compilation_config.splitting_ops`
|
||||
vllm_config._set_cudagraph_sizes()
|
||||
# There are cases where default cudagraph_capture_sizes are not friendly
|
||||
@@ -303,9 +281,7 @@ class NPUPlatform(Platform):
if parallel_config and parallel_config.worker_cls == "auto":
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
parallel_config.all2all_backend = "flashinfer_all2allv"
if ascend_config.torchair_graph_config.enabled:
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
elif ascend_config.xlite_graph_config.enabled:
if ascend_config.xlite_graph_config.enabled:
logger.info(
"Euler Xlite enabled. See: https://gitee.com/openeuler/GVirt/tree/master/xlite"
)
@@ -390,29 +366,14 @@ class NPUPlatform(Platform):
use_sparse=False,
attn_type: str | None = None,
):
ascend_config = get_ascend_config()

if use_mla and ascend_config.enable_shared_expert_dp:
if use_mla and use_sparse:
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"

use_torchair = ascend_config.torchair_graph_config.enabled
# choose attention backend based on use_mla and use_torchair
# choose attention backend based on use_mla
backend_map = {
(True, False, True):
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
(True, False, False):
"vllm_ascend.attention.mla_v1.AscendMLABackend",
(False, False, True):
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
(False, False, False):
(True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
(False, False):
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
(True, True, False):
"vllm_ascend.attention.sfa_v1.AscendSFABackend",
(True, True, True):
"vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
(True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
}
return backend_map[(use_mla, use_sparse, use_torchair)]
return backend_map[(use_mla, use_sparse)]

@classmethod
def get_punica_wrapper(cls) -> str:
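For quick reference, a minimal sketch of the selection that remains after this hunk; the mapping is copied from the lines above, while pick_attn_backend is a hypothetical helper name standing in for the NPUPlatform classmethod shown in the diff.

# Sketch: (use_mla, use_sparse) -> attention backend class path, no torchair key anymore.
BACKEND_MAP = {
    (True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
    (False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
    (True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
}

def pick_attn_backend(use_mla: bool, use_sparse: bool) -> str:
    # The former use_torchair dimension is gone, so a plain 2-tuple lookup suffices.
    return BACKEND_MAP[(use_mla, use_sparse)]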
@@ -111,10 +111,9 @@ class AscendW8A8DynamicFusedMoEMethod:

vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
self.use_aclgraph = (
vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
and not vllm_config.model_config.enforce_eager
and not ascend_config.torchair_graph_config.enabled)
self.use_aclgraph = (vllm_config.compilation_config.mode
== CompilationMode.VLLM_COMPILE
and not vllm_config.model_config.enforce_eager)

self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
self.in_dtype = vllm_config.model_config.dtype
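Collapsed into a single expression for readability (attribute names taken from the hunk above, nothing else assumed), the predicate that remains once the torchair clause is dropped:

# aclgraph is used iff the model was compiled by vLLM and eager execution is not forced.
use_aclgraph = (vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
                and not vllm_config.model_config.enforce_eager)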
@@ -20,21 +20,14 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer


def get_spec_decode_method(method,
vllm_config,
device,
runner,
is_torchair_graph=False):
def get_spec_decode_method(method, vllm_config, device, runner):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
elif method in ("eagle", "eagle3"):
return EagleProposer(vllm_config, device, runner)
elif method == "mtp":
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
return MtpProposer(vllm_config, device, runner)
elif method == 'suffix':
return SuffixDecodingProposer(vllm_config, device, runner)
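A hedged usage sketch of the slimmed-down factory above; vllm_config, device and runner are assumed to be supplied by the model runner, and MtpProposer is the class imported in this hunk.

# Sketch: with the torchair branch removed, "mtp" always resolves to MtpProposer.
proposer = get_spec_decode_method("mtp", vllm_config, device, runner)
assert isinstance(proposer, MtpProposer)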
@@ -1,357 +0,0 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import vllm
|
||||
from torch import nn
|
||||
from transformers import Qwen2Config
|
||||
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Attention # noqa: F401
|
||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM # noqa: F401
|
||||
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
||||
PPMissingLayer, maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import set_default_rope_theta
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
|
||||
|
||||
def all_gather_and_maybe_unpad(
|
||||
hidden_states: torch.Tensor,
|
||||
pad_size: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
||||
if pad_size > 0:
|
||||
return hidden_states[:-pad_size, :]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def maybe_pad_and_reduce_scatter(
|
||||
hidden_states: torch.Tensor,
|
||||
pad_size: int,
|
||||
) -> torch.Tensor:
|
||||
if pad_size > 0:
|
||||
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
|
||||
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen2Attention(Qwen2Attention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_parameters: Optional[dict[str, Any]] = None,
|
||||
max_position: int = 4096 * 32,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
max_position=max_position,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
attn_type=attn_type,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
rope_parameters=rope_parameters)
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.torchair_graph_enabled and attn_metadata is not None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
q, k = self.rotary_emb(positions,
|
||||
q,
|
||||
k,
|
||||
is_prefill=False,
|
||||
is_qwen_torchair=True)
|
||||
forward_kwargs = {}
|
||||
output_shape = q.shape
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
forward_kwargs['output'] = output
|
||||
|
||||
attn_output = self.attn.impl.forward(self.attn,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
**forward_kwargs)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
else:
|
||||
if type(self.rotary_emb) is RotaryEmbedding:
|
||||
q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
|
||||
else:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class CustomQwen2DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
set_default_rope_theta(config, default_theta=1000000)
|
||||
|
||||
dual_chunk_attention_config = getattr(config,
|
||||
"dual_chunk_attention_config",
|
||||
None)
|
||||
|
||||
# By default, Qwen2 uses causal attention as it is a decoder-only model.
|
||||
# You can override the HF config with `is_causal=False` to enable
|
||||
# bidirectional attention, which is used in some embedding models
|
||||
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
|
||||
if getattr(config, "is_causal", True):
|
||||
attn_type = AttentionType.DECODER
|
||||
else:
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
self.self_attn = CustomQwen2Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_parameters=config.rope_parameters,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.mlp = Qwen2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
|
||||
# otherwise (seq_len, ).
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
class CustomQwen2Model(Qwen2Model):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
decoder_layer_type=decoder_layer_type)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
kv_cache = kv_caches[i - self.start_layer] \
|
||||
if kv_caches is not None else None
|
||||
hidden_states, residual = layer(positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# add `CustomQwen2Model` to init self.model
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata=None, # type: ignore
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM
|
||||
@@ -1,527 +0,0 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
||||
# This file is a part of the vllm-ascend project.
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, CompilationMode, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.interfaces import (MixtureOfExperts,
|
||||
SupportsLoRA, SupportsPP)
|
||||
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
|
||||
Qwen3MoeDecoderLayer,
|
||||
Qwen3MoeForCausalLM,
|
||||
Qwen3MoeMLP, Qwen3MoeModel,
|
||||
Qwen3MoeSparseMoeBlock)
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, extract_layer_index,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
|
||||
init_metadata_for_sp)
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
|
||||
|
||||
|
||||
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.num_experts}.")
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.experts = TorchairAscendFusedMoE(
|
||||
num_experts=config.num_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
self.dp_size = get_dp_group().world_size
|
||||
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attn_metadata=None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
):
|
||||
if attn_metadata is None:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# when profile runs, force experts to load balanced tokens
|
||||
# to avoid high memory consumption on a single rank.
|
||||
enable_force_load_balance = get_forward_context().in_profile_run
|
||||
is_prefill = get_forward_context().with_prefill
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=self.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=None,
|
||||
_metadata_for_padding=_metadata_for_padding,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen3MoeAttention(Qwen3MoeAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_parameters: dict[str, Any],
|
||||
max_position_embeddings: int = 8192,
|
||||
head_dim: Optional[int] = None,
|
||||
rms_norm_eps: float = 1e-06,
|
||||
qkv_bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=rope_parameters,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
@staticmethod
|
||||
def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int,
|
||||
head_dim: int, q_norm, k_norm):
|
||||
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
|
||||
q_by_head = q_norm(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
|
||||
k_by_head = k_norm(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size,
|
||||
self.head_dim, self.q_norm, self.k_norm)
|
||||
|
||||
if (self.torchair_graph_enabled and attn_metadata is not None and
|
||||
attn_metadata.attn_state == AscendAttentionState.DecodeOnly):
|
||||
q, k = self.rotary_emb(positions,
|
||||
q,
|
||||
k,
|
||||
is_prefill=False,
|
||||
is_qwen_torchair=True)
|
||||
forward_kwargs = {}
|
||||
output_shape = q.shape
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
forward_kwargs['output'] = output
|
||||
|
||||
attn_output = self.attn.impl.forward(self.attn,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
**forward_kwargs)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
else:
|
||||
q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: Optional[VllmConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = config.hidden_size
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = CustomQwen3MoeAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_parameters=config.rope_parameters,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=getattr(config, 'attention_bias', False),
|
||||
head_dim=getattr(config, 'head_dim', None),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
# `mlp_only_layers` in the config.
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||
config.mlp_only_layers)
|
||||
self.use_aclgraph = (vllm_config is not None
|
||||
and vllm_config.compilation_config.mode
|
||||
== CompilationMode.VLLM_COMPILE
|
||||
and not vllm_config.model_config.enforce_eager)
|
||||
if (layer_idx not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and
|
||||
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
if not self.use_aclgraph:
|
||||
# FIXME: custom sparse moe block doesn't work with aclgraph.
|
||||
self.mlp = CustomSparseMoeBlock(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
self.enable_sequence_parallelism = (
|
||||
vllm_config.compilation_config.pass_config.enable_sp
|
||||
if vllm_config is not None else False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# To prevent precision issues during the decoder phase when only prefilling enables SP
|
||||
if not self.enable_sequence_parallelism:
|
||||
self.self_attn.o_proj.reduce_results = True
|
||||
else:
|
||||
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
|
||||
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
residual = _metadata_for_padding.padding_slice(residual)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
||||
hidden_states)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
|
||||
hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if not self.use_aclgraph:
|
||||
hidden_states = self.mlp(
|
||||
hidden_states, _metadata_for_padding=_metadata_for_padding)
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomQwen3MoeModel(Qwen3MoeModel):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: CustomQwen3MoeDecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
kv_caches[i -
|
||||
self.start_layer] if kv_caches is not None else None,
|
||||
attn_metadata,
|
||||
_metadata_for_padding=_metadata_for_padding)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
||||
hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
SupportsPP.__init__(self)
|
||||
SupportsLoRA.__init__(self)
|
||||
MixtureOfExperts.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sp
|
||||
# Set MoE hyperparameters
|
||||
self.expert_weights: list[torch.Tensor] = []
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
example_layer = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Qwen3MoeDecoderLayer)
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
example_layer = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_layer is None:
|
||||
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
|
||||
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
_metadata_for_padding = init_metadata_for_sp(
|
||||
input_ids, self.enable_sequence_parallelism)
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds, _metadata_for_padding)
|
||||
return hidden_states
|
||||
@@ -1,221 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/deepseek_mtp.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
||||
SharedHead)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
TorchairDeepseekV2DecoderLayer
|
||||
|
||||
|
||||
class TorchairDeepSeekShareHead(SharedHead):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "head"))
|
||||
|
||||
|
||||
class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = TorchairDeepSeekShareHead(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix,
|
||||
"shared_head"))
|
||||
self.mtp_block = TorchairDeepseekV2DecoderLayer(
|
||||
config, prefix, model_config, cache_config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
|
||||
torch.zeros_like(inputs_embeds),
|
||||
inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
del inputs_embeds, previous_hidden_states
|
||||
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
|
||||
|
||||
hidden_states, residual = self.mtp_block(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=None,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
replace_allreduce=replace_allreduce)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
TorchairDeepSeekMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
# Note: torch._dynamo.exc.Unsupported: builtin: str
|
||||
self.layers_list = [
|
||||
self.layers[str(idx)]
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
]
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
step_kv_cache = kv_caches[
|
||||
current_step_idx] if kv_caches is not None else None
|
||||
return self.layers_list[current_step_idx](
|
||||
input_ids,
|
||||
positions,
|
||||
step_kv_cache,
|
||||
attn_metadata,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
mtp_layer = self.layers_list[current_step_idx]
|
||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||
mtp_layer.shared_head(hidden_states))
|
||||
return logits
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class TorchairDeepSeekMTP(DeepSeekMTP):
|
||||
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
|
||||
# NOTE 2.The description file generated by the current msmodelslim tool does not have
|
||||
# MTP layer info. Please manually add it and set the value to FLOAT.
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = TorchairDeepSeekMultiTokenPredictor(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, hidden_states, inputs_embeds,
|
||||
spec_step_idx)
|
||||
return hidden_states
|
||||
File diff suppressed because it is too large
@@ -1,28 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
TorchairDeepseekV2ForCausalLM
|
||||
|
||||
|
||||
class TorchairDeepseekV3ForCausalLM(TorchairDeepseekV2ForCausalLM):
|
||||
pass
|
||||
File diff suppressed because it is too large
@@ -1,120 +0,0 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
|
||||
class MetadataForPadding:
|
||||
|
||||
def __init__(self,
|
||||
padding_flag=False,
|
||||
lengths_sum_padding=0,
|
||||
lengths_sum_unpadding=0,
|
||||
pad_size=0,
|
||||
not_dummy_and_is_prefill=False):
|
||||
self.padding_flag = padding_flag
|
||||
self.not_dummy_and_is_prefill = not_dummy_and_is_prefill
|
||||
|
||||
self.lengths_sum_padding = lengths_sum_padding
|
||||
self.lengths_sum_unpadding = lengths_sum_unpadding
|
||||
self.pad_size = pad_size
|
||||
|
||||
self.tp_size = get_tp_group().world_size
|
||||
self.tp_rank_in_group = get_tp_group().rank_in_group
|
||||
|
||||
assert self.lengths_sum_padding % self.tp_size == 0
|
||||
self.slice_size = self.lengths_sum_padding // self.tp_size
|
||||
|
||||
self.mc2_mask = torch.zeros(
|
||||
self.lengths_sum_padding,
|
||||
dtype=torch.bool,
|
||||
device=NPUPlatform.device_type,
|
||||
)
|
||||
self.mc2_mask[:lengths_sum_unpadding] = True
|
||||
|
||||
def padding_aligned_reduce_scatter(self,
|
||||
data: torch.Tensor) -> torch.Tensor:
|
||||
if self.padding_flag:
|
||||
pad_size = self.pad_size
|
||||
padded_data = F.pad(data, (0, 0, 0, pad_size))
|
||||
else:
|
||||
padded_data = data
|
||||
padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
|
||||
padded_data, 0)
|
||||
|
||||
return padded_data_reduce_scatter
|
||||
|
||||
def allgather_unpadding_aligned(self,
|
||||
padded_data: torch.Tensor) -> torch.Tensor:
|
||||
padded_data_allgather = tensor_model_parallel_all_gather(
|
||||
padded_data, 0)
|
||||
if self.padding_flag:
|
||||
lengths_sum_unpadding = self.lengths_sum_unpadding
|
||||
unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
|
||||
else:
|
||||
unpadding_data = padded_data_allgather
|
||||
return unpadding_data
|
||||
|
||||
def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
padded_data = F.pad(data, (0, 0, 0, self.pad_size))
|
||||
start = self.tp_rank_in_group * self.slice_size
|
||||
end = start + self.slice_size
|
||||
slice_data = padded_data[start:end]
|
||||
|
||||
return slice_data
|
||||
|
||||
def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if self.padding_flag:
|
||||
pad_size = self.pad_size
|
||||
padded_data = F.pad(data, (0, 0, 0, pad_size))
|
||||
else:
|
||||
padded_data = data
|
||||
# padded_data = data
|
||||
padded_data = torch.tensor_split(padded_data, self.tp_size, dim=0)
|
||||
|
||||
padded_data_reduce_scatter = padded_data[self.tp_rank_in_group]
|
||||
|
||||
return padded_data_reduce_scatter
|
||||
|
||||
|
||||
def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
|
||||
if not enable_sequence_parallelism:
|
||||
return MetadataForPadding(padding_flag=False,
|
||||
not_dummy_and_is_prefill=False)
|
||||
|
||||
is_perifll = 0
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if attn_metadata is not None:
|
||||
if hasattr(attn_metadata,
|
||||
'is_only_prefill') and attn_metadata.is_only_prefill:
|
||||
is_perifll = 1
|
||||
if hasattr(attn_metadata,
|
||||
'num_prefills') and attn_metadata.num_prefills > 0:
|
||||
is_perifll = 1
|
||||
|
||||
if is_perifll:
|
||||
lengths_sum_unpadding = input_ids.shape[0]
|
||||
lengths_sum_padding = (
|
||||
(lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
|
||||
if lengths_sum_unpadding == lengths_sum_padding:
|
||||
padding_flag = False
|
||||
else:
|
||||
padding_flag = True
|
||||
pad_size = lengths_sum_padding - lengths_sum_unpadding
|
||||
_metadata_for_padding = MetadataForPadding(
|
||||
lengths_sum_unpadding=lengths_sum_unpadding,
|
||||
lengths_sum_padding=lengths_sum_padding,
|
||||
padding_flag=padding_flag,
|
||||
pad_size=pad_size,
|
||||
not_dummy_and_is_prefill=True)
|
||||
|
||||
return _metadata_for_padding
|
||||
|
||||
return MetadataForPadding(padding_flag=False,
|
||||
not_dummy_and_is_prefill=False)
|
||||
@@ -1,245 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
|
||||
def dispose_tensor(x: torch.Tensor):
|
||||
x.set_(torch.empty([], device=x.device, dtype=x.dtype))
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerMetadata:
|
||||
"""Metadata for a layer.
|
||||
"""
|
||||
layer: Optional[LinearBase] # The layer object.
|
||||
post_method: Callable[[
|
||||
torch.nn.Module
|
||||
], None] # The `process_weights_after_loading` method from the quant method.
|
||||
weight: torch.Tensor # The weight tensor.
|
||||
window_idx: int # The index of the window.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedWindowMetadata:
|
||||
"""Metadata for a shared window.
|
||||
"""
|
||||
weight: torch.Tensor # The weight tensor to be shared by layers.
|
||||
data_layer_idx: int # The index of the layer this window's weight is equal to.
|
||||
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeriesMetadata:
|
||||
"""Metadata for a weight shared series.
|
||||
"""
|
||||
group: GroupCoordinator
|
||||
start_layer: int
|
||||
end_layer: int
|
||||
num_layers: int
|
||||
prefetch_step: int
|
||||
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
|
||||
layers: list[LayerMetadata]
|
||||
shared_windows: list[
|
||||
SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
|
||||
window_offset: int # The index of the window for the next coming layer.
|
||||
|
||||
def is_source(self, layer_idx) -> bool:
|
||||
        return layer_idx % self.group.world_size == self.group.rank_in_group

    def post_process_after_loading(self):
        # This method only needs to be called once per series.
        if self.shared_windows:
            return
        for layer_idx in range(self.start_layer, self.end_layer):
            layer = self.layers[layer_idx - self.start_layer]
            is_source = self.is_source(layer_idx)
            # If the layer uses the dummy weight, make a temporary copy so
            # that the post method call won't affect other layers that also
            # use the dummy weight.
            if not is_source:
                layer.weight.set_(torch.empty_like(self.dummy_weight))
            # Broadcast to get the true weight.
            dist.broadcast(layer.weight,
                           src=self.group.ranks[layer_idx %
                                                self.group.world_size],
                           group=self.group.device_group)
            assert layer.layer is not None
            # Call `process_weights_after_loading` from the quant method.
            layer.post_method(layer.layer)
            step = layer_idx - self.start_layer
            if step < self.prefetch_step:
                # Build the windows for the first `prefetch_step` layers. The
                # weights can be used by the first `prefetch_step` layers in
                # `forward()`, so also clone the weights.
                self.shared_windows.append(
                    SharedWindowMetadata(
                        weight=layer.weight.clone().detach(),
                        data_layer_idx=layer_idx,
                        work=None,
                    ))
                layer.window_idx = step
                # When the layer is not intended to be stored on this device,
                # link to the corresponding window's tensor.
                if not is_source:
                    layer.weight.set_(self.shared_windows[-1].weight)
            else:
                # Build one more window for prefetch. The weight is useless,
                # so just keep the shape.
                if step == self.prefetch_step:
                    self.shared_windows.append(
                        SharedWindowMetadata(
                            weight=torch.empty_like(layer.weight),
                            data_layer_idx=-1,
                            work=None,
                        ))
                # When the layer is not intended to be stored on this device,
                # dispose the tensor.
                if not is_source:
                    dispose_tensor(layer.weight)

        dispose_tensor(self.dummy_weight)

    def reach_layer(self, layer_idx: int):
        # The index of the layer to be prefetched.
        next_layer_idx = (layer_idx + self.prefetch_step
                          ) % self.num_layers + self.start_layer
        next_layer = self.layers[next_layer_idx - self.start_layer]
        # The index of the window to store the weight for the coming layer.
        next_layer.window_idx = self.window_offset
        window = self.shared_windows[next_layer.window_idx]
        # When the layer is not intended to be stored on this device, link to
        # the corresponding window's tensor.
        if not self.is_source(next_layer_idx):
            next_layer.weight.set_(window.weight)
        # Update `window_offset` by rolling one step.
        self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
                                                         1)
        assert window.data_layer_idx != next_layer_idx
        window.data_layer_idx = next_layer_idx
        # Start asynchronous broadcast work.
        window.work = dist.broadcast(
            next_layer.weight,
            src=self.group.ranks[next_layer_idx % self.group.world_size],
            group=self.group.device_group,
            async_op=True)

    def wait_weight(self, layer_idx: int):
        # Find the asynchronous broadcast work and wait for it.
        assert self.shared_windows
        window = self.shared_windows[self.layers[layer_idx -
                                                 self.start_layer].window_idx]
        # Make sure the data in the corresponding shared window is for the
        # current layer.
        assert window.data_layer_idx == layer_idx
        if window.work is not None:
            window.work.wait()
            window.work = None


@dataclass
class LayerExternalMetadata:
    """External metadata for a layer."""
    series: SeriesMetadata
    layer_idx: int


_series_dict: dict[str, SeriesMetadata] = {}

_layer_external_dict: dict[int, LayerExternalMetadata] = {}


def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
                            layer_idx: int) -> Callable:

    def wrapped_forward(*args, **kwargs):
        # Wait for the weight.
        series.wait_weight(layer_idx)
        return forward(*args, **kwargs)

    return wrapped_forward


"""
Register linear layers into a shared-storage series.

In a parallel group, each device stores a distinct, non-overlapping subset of the layers in the series. All layers in a series must have the same structure (i.e. they are isomorphic). The weight matrix of the i-th layer is stored on device (i % n), where n is the number of devices.

After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of the series to complete the initialization.

During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that controls asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` triggers an asynchronous prefetch of the weights of the k-th layer after `current_layer` within the series.

Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
- total_layers = end_layer - start_layer
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer

To hold the weights of the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers is created for the series. A minimal usage sketch is given after the helper functions below.

Arguments:
    series_name: The name identifying which series this layer belongs to.
    group: The group coordinator that handles the asynchronous communications. It is recommended to create a new group coordinator for each new series.
    start_layer: The index of the first layer in the series (inclusive).
    end_layer: The index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer).
    layer_idx: The index of the current layer.
    layer: The linear layer object to register.
    prefetch_step: An integer that controls asynchronous weight prefetching. Setting it to 0 or 1 covers most cases.
"""


def register_layer_to_shared_weight_series(
    series_name: str,
    group: GroupCoordinator,
    start_layer: int,
    end_layer: int,
    layer_idx: int,
    layer: LinearBase,
    prefetch_step: int = 1,
):
    global _series_dict
    if series_name not in _series_dict:
        num_layers = end_layer - start_layer
        assert num_layers > 0
        assert prefetch_step >= 0 and prefetch_step <= num_layers - 2
        _series_dict[series_name] = SeriesMetadata(
            group=group,
            start_layer=start_layer,
            end_layer=end_layer,
            num_layers=num_layers,
            prefetch_step=prefetch_step,
            dummy_weight=torch.empty_like(layer.weight),
            layers=[
                LayerMetadata(
                    layer=None,
                    post_method=lambda layer: None,
                    weight=torch.empty([]),
                    window_idx=-1,
                ) for _ in range(num_layers)
            ],
            shared_windows=[],
            window_offset=prefetch_step,
        )
    series = _series_dict[series_name]
    assert layer.quant_method is not None
    series.layers[layer_idx - start_layer] = LayerMetadata(
        layer=layer,
        post_method=layer.quant_method.process_weights_after_loading,
        weight=layer.weight,
        window_idx=-1,
    )
    # Discard the original `process_weights_after_loading` method so that it
    # won't be called by others.
    layer.quant_method.process_weights_after_loading = lambda layer: None
    # When the layer is not intended to be stored on this device, dispose the
    # tensor and skip weight loading.
    if not series.is_source(layer_idx):
        dispose_tensor(layer.weight)
        layer.weight.weight_loader = lambda *args, **kwargs: None
    layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx)
    global _layer_external_dict
    _layer_external_dict[id(layer)] = LayerExternalMetadata(
        series=series,
        layer_idx=layer_idx,
    )


def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
    ext = _layer_external_dict[id(layer)]
    ext.series.post_process_after_loading()


def reach_layer_for_shared_weight_series(layer: LinearBase):
    ext = _layer_external_dict[id(layer)]
    ext.series.reach_layer(ext.layer_idx)
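A minimal usage sketch of the API above. The module layout and the `down_proj` attribute are hypothetical and only illustrate the call order described in the docstring; the real callers live in the model code.

```python
import torch.nn as nn


def register_shared_down_proj_series(layers: list[nn.Module], group) -> None:
    # Step 1: register every isomorphic linear layer at model-construction
    # time; the weight of layer i will be stored on rank (i % world_size).
    for idx, layer in enumerate(layers):
        register_layer_to_shared_weight_series(
            series_name="down_proj",   # any unique name for the series
            group=group,               # a GroupCoordinator created for this series
            start_layer=0,
            end_layer=len(layers),
            layer_idx=idx,
            layer=layer.down_proj,     # hypothetical LinearBase attribute
            prefetch_step=1,           # one prefetch window is usually enough
        )


def finish_loading(layers: list[nn.Module]) -> None:
    # Step 2: after weight loading, finish initialization once per series;
    # any registered layer can serve as the handle.
    post_process_after_loading_for_shared_weight_series(layers[0].down_proj)


def forward_all(layers: list[nn.Module], hidden_states):
    # Step 3: during forward, announce each layer as it is reached so that the
    # weights of layer (idx + prefetch_step) % num_layers are broadcast in the
    # background; the wrapped forward() waits for its own weights.
    for layer in layers:
        reach_layer_for_shared_weight_series(layer.down_proj)
        hidden_states = layer(hidden_states)
    return hidden_states
```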
@@ -1,37 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch


def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
    """AscendSiluAndMul forward in torchair mode.

    The key difference from the original implementation is the removal of operators
    from the torch.ops.vllm namespace, as these operators only work in non-torchair
    modes. Adding them back would cause the graph compilation to fail.
    """

    import torch_npu

    from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type

    if get_ascend_device_type() == AscendDeviceType._310P:
        out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
    else:
        out = torch_npu.npu_swiglu(x)
    return out
File diff suppressed because it is too large
@@ -1,78 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from typing import Optional, Tuple, Union

import torch
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm

_original_re_init = RMSNorm.__init__


def torchair_rmsnorm_init_(
    self,
    hidden_size: int,
    eps: float = 1e-6,
    var_hidden_size: Optional[int] = None,
    has_weight: bool = True,
    dtype: Optional[torch.dtype] = None,
) -> None:
    _original_re_init(self, hidden_size, eps, var_hidden_size, has_weight,
                      dtype)
    vllm_config = get_current_vllm_config()
    self.bias = None
    # Quantization with anti_method m4 will generate a non-zero norm bias.
    if vllm_config.quant_config is not None and \
            any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)


def torchair_rmsnorm_forward_oot(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """AscendRMSNorm forward in torchair mode.

    The key difference from the original implementation is the removal of operators
    from the torch.ops.vllm namespace, as these operators only work in non-torchair
    modes. Adding them back would cause the graph compilation to fail.
    """

    import torch_npu

    from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
    if residual is not None:
        if get_ascend_device_type() == AscendDeviceType._310P:
            orig_dtype = residual.dtype
            x = x + residual.to(x.dtype)
            residual = x.to(orig_dtype)
            x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                          self.variance_epsilon)
        else:
            x, _, residual = torch_npu.npu_add_rms_norm(
                x, residual, self.weight, self.variance_epsilon)
        if self.bias is not None:
            x.add_(self.bias)
        return x, residual

    x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
    if self.bias is not None:
        x.add_(self.bias)
    return x
@@ -1,367 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
|
||||
get_ascend_device_type)
|
||||
|
||||
|
||||
def custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def rope_forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
is_qwen_torchair: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if get_ascend_config(
|
||||
).torchair_graph_config.enabled and not is_qwen_torchair:
|
||||
return self.forward_native(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets,
|
||||
)
|
||||
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if custom_rotary_embedding_enabled(
|
||||
query, neox_style, self.head_size) and get_ascend_device_type(
|
||||
) != AscendDeviceType._310P:
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
def native_rope_deepseek_forward(self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None):
|
||||
if len(key.shape) == 2:
|
||||
key = key[:, None, :]
|
||||
# Note: we implement the non-neox-style method by shuffling the last dim and
# then using the neox-style calculation, which is also more compute-friendly
# on Ascend machines; see
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
|
||||
neox_style = True
|
||||
if self.is_neox_style is False:
|
||||
b, h_q, d = query.shape
|
||||
query = query.view(b, h_q, d // 2, 2).transpose(3,
|
||||
2).reshape(b, h_q, d)
|
||||
b, h_k, d = key.shape
|
||||
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
|
||||
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
|
||||
neox_style)
|
||||
return q_pe, k_pe
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
def yarn_find_correction_dim(num_rotations,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
return (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))
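For reference, a LaTeX restatement of the expression implemented by `yarn_find_correction_dim` above (this is just the code's formula written out, under the reading that $d$ is the rotary `dim`, $L$ is `max_position_embeddings`, $b$ is `base`, and $r$ is `num_rotations`):

```latex
d_{\mathrm{corr}}(r) = \frac{d \, \ln\!\bigl(L / (2\pi r)\bigr)}{2 \ln b}
```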
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def yarn_find_correction_range(low_rot,
|
||||
high_rot,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
low = torch.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = torch.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
# Note: use torch instead of max/min to solve MTP compilation error.
|
||||
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_value, max_value, dim):
|
||||
# Note: an if branch is deliberately not used here, to avoid an MTP
# compilation error.
|
||||
max_value += (min_value == max_value).float() * 0.001
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) -
|
||||
min_value) / (max_value - min_value)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offset position ids when working with a KV-cache.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding; a small shape sketch follows this function.
|
||||
"""
|
||||
cos = cos[position_ids]
|
||||
sin = sin[position_ids]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
|
||||
if len(q.shape) == 3:
|
||||
q = q[:, :, None, :]
|
||||
if len(k.shape) == 2:
|
||||
k = k[:, None, None, :]
|
||||
elif len(k.shape) == 3:
|
||||
k = k[:, :, None, :]
|
||||
|
||||
b, h_q, s, d = q.shape
|
||||
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
|
||||
|
||||
b, h_k, s, d = k.shape
|
||||
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
q_embed = q_embed.view(b, h_q, d)
|
||||
k_embed = k_embed.view(b, h_k, d)
|
||||
|
||||
return q_embed, k_embed
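A small shape sketch for the helper above. The sizes are made up, and random cos/sin stand in for the real caches, so this only illustrates the expected shapes, not the values:

```python
import torch

num_tokens, num_heads, head_dim, max_pos = 4, 8, 64, 128
q = torch.randn(num_tokens, num_heads, head_dim)  # 3-D query, one token per row
k = torch.randn(num_tokens, head_dim)             # 2-D key (e.g. a single rope head)
cos = torch.randn(max_pos, head_dim)              # stand-ins for the cos/sin caches
sin = torch.randn(max_pos, head_dim)
position_ids = torch.arange(num_tokens)

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_embed.shape == (num_tokens, num_heads, head_dim)
assert k_embed.shape == (num_tokens, 1, head_dim)
```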
|
||||
|
||||
|
||||
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
freq_inter = 1.0 / (self.scaling_factor * self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
|
||||
device=device, dtype=torch.float32)
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
|
||||
cos_cached = cos_cached.to(dtype)
|
||||
sin_cached = sin_cached.to(dtype)
|
||||
cache = torch.cat([freqs.cos() * self.mscale,
|
||||
freqs.sin() * self.mscale],
|
||||
dim=-1).to(dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
|
||||
|
||||
def __set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
inv_freq = 1.0 / (self.base**(torch.arange(
|
||||
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
|
||||
(1 / self.rotary_dim)))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
t = torch.arange(self.max_position_embeddings,
|
||||
device=self.inv_freq.device,
|
||||
dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
|
||||
self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
|
||||
self.embed = F.embedding
|
||||
|
||||
|
||||
_original_re_init = RotaryEmbedding.__init__
|
||||
|
||||
|
||||
def qwen_rope_init_func(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
_original_re_init(self, head_size, rotary_dim, max_position_embeddings,
|
||||
base, is_neox_style, dtype)
|
||||
if get_ascend_config().torchair_graph_config.enabled:
|
||||
__set_cos_sin_cache(self,
|
||||
seq_len=max_position_embeddings,
|
||||
device="npu",
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
def rope_forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
max_seq_len: Optional[int] = None,
|
||||
is_prefill: Optional[bool] = True,
|
||||
is_qwen_torchair: Optional[bool] = False,
|
||||
):
|
||||
if get_ascend_config().torchair_graph_config.enabled \
|
||||
and is_qwen_torchair and not is_prefill:
|
||||
if max_seq_len is not None and torch.gt(max_seq_len,
|
||||
self.max_position_embeddings):
|
||||
__set_cos_sin_cache(self,
|
||||
seq_len=max_seq_len,
|
||||
device=query.device,
|
||||
dtype=torch.float32)
|
||||
|
||||
# bsnd/bnsd
|
||||
if positions is not None:
|
||||
cos = self.embed(positions, self.cos)
|
||||
sin = self.embed(positions, self.sin)
|
||||
self.cos_embed = cos
|
||||
self.sin_embed = sin
|
||||
else:
|
||||
cos = self.cos_embed
|
||||
sin = self.sin_embed
|
||||
|
||||
query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
|
||||
key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
|
||||
|
||||
cos = cos.unsqueeze(-2).unsqueeze(-2)
|
||||
sin = sin.unsqueeze(-2).unsqueeze(-2)
|
||||
|
||||
query = query.unsqueeze(1)
|
||||
key = key.unsqueeze(1)
|
||||
|
||||
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
|
||||
query, key, cos, sin)
|
||||
return q_embed.flatten(-2), k_embed.flatten(-2)
|
||||
else:
|
||||
return rope_forward_oot(self, positions, query, key, offsets,
|
||||
is_neox_style_override,
|
||||
is_qwen_torchair) # type: ignore
|
||||
|
||||
|
||||
def deepseek_rope_init_func(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation.
|
||||
self.mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
||||
attn_factor)
|
||||
super(DeepseekScalingRotaryEmbedding,
|
||||
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
# NOTE: For ascend friendly computing, reorder sin and cos cache
|
||||
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
|
||||
_set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu")
|
||||
@@ -1,38 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.distributed import tensor_model_parallel_all_reduce


def vocab_embedding_forward(self, input_):
    if self.tp_size > 1:
        # Build the mask.
        masked_input, input_mask = self._get_masked_input_and_mask(
            input_, self.shard_indices.org_vocab_start_index,
            self.shard_indices.org_vocab_end_index,
            self.shard_indices.num_org_vocab_padding,
            self.shard_indices.added_vocab_start_index,
            self.shard_indices.added_vocab_end_index)
    else:
        masked_input = input_
    # Get the embeddings.
    output_parallel = self.quant_method.embedding(self, masked_input.long())
    # Mask the output embedding.
    if self.tp_size > 1:
        output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
    # Reduce across all the model parallel GPUs.
    output = tensor_model_parallel_all_reduce(output_parallel)
    return output
@@ -1,501 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
|
||||
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
|
||||
|
||||
class TorchairAscendW4A8DynamicLinearMethod:
|
||||
"""Linear method for Ascend W4A8_DYNAMIC
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
self.new_quant_version = quant_version == "1.0.0"
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def get_weight(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
|
||||
if self.new_quant_version:
|
||||
pack_factor = 2
|
||||
actual_output_size = output_size // pack_factor
|
||||
params_dict["weight"] = torch.empty(actual_output_size,
|
||||
input_size,
|
||||
dtype=torch.int8)
|
||||
params_dict["_packed_dim"] = 0
|
||||
params_dict["_packed_factor"] = pack_factor
|
||||
else:
|
||||
params_dict["weight"] = torch.empty(output_size,
|
||||
input_size,
|
||||
dtype=torch.int8)
|
||||
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_scale_second"] = torch.empty(output_size,
|
||||
input_size //
|
||||
self.group_size,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset_second"] = torch.empty(output_size,
|
||||
input_size //
|
||||
self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
if self.new_quant_version:
|
||||
scale_bias_dim = 16 if layer_type == "row" else 1
|
||||
params_dict["scale_bias"] = torch.empty(output_size,
|
||||
scale_bias_dim,
|
||||
dtype=torch.float32)
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def process_scale_second(weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
per_group_scale: torch.Tensor,
|
||||
is_new_quant: bool = False):
|
||||
k, n = weight.shape
|
||||
group_num, n_scale = per_group_scale.shape
|
||||
|
||||
if is_new_quant:
|
||||
n = n * 2
|
||||
|
||||
bias = None
|
||||
if not is_new_quant:
|
||||
weight_high = weight.to(torch.float32).reshape(
|
||||
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
|
||||
weight_high = weight_high.reshape(k, n)
|
||||
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
|
||||
|
||||
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
|
||||
return antiquant_scale.npu(), bias
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch_npu.npu_weight_quant_batchmatmul(
|
||||
x,
|
||||
layer.weight,
|
||||
antiquant_scale=layer.weight_scale_second.to(x.dtype),
|
||||
antiquant_group_size=self.group_size,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
|
||||
torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
|
||||
layer.weight.data,
|
||||
layer.weight_scale.data,
|
||||
layer.weight_scale_second.data.transpose(0, 1).contiguous(),
|
||||
is_new_quant=self.new_quant_version,
|
||||
)
|
||||
|
||||
if self.new_quant_version:
|
||||
if hasattr(layer, "scale_bias"):
|
||||
if layer.scale_bias.data.shape[1] == 1:
|
||||
layer.scale_bias.data = layer.scale_bias.data.flatten()
|
||||
else:
|
||||
layer.scale_bias.data = layer.scale_bias.data.contiguous()
|
||||
else:
|
||||
if scale_bias is not None:
|
||||
param = torch.nn.Parameter(scale_bias, requires_grad=False)
|
||||
layer.register_parameter("weight_scale_bias", param)
|
||||
|
||||
if self.new_quant_version:
|
||||
assert layer.weight.data.shape[-1] % 4 == 0, \
f"the last dim of weight needs to be divisible by 4, got shape {layer.weight.data.shape}"
|
||||
layer.weight.data = layer.weight.data.view(
|
||||
torch.int32).contiguous()
|
||||
else:
|
||||
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
|
||||
layer.weight.data.to(torch.int32))
|
||||
|
||||
|
||||
class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W4A8_DYNAMIC.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
|
||||
self.is_per_channel_weight = self.group_size == 0
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
# NOTE: in the new quantized weight format, two int4 values are packed into one int8
|
||||
self.new_quant_version = quant_version == "1.0.0"
|
||||
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
|
||||
if self.new_quant_version and self.tp_size > 16:
|
||||
raise ValueError(
    "The current weights do not support tp > 16 for the MoE part.")
|
||||
|
||||
try:
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
||||
local_rank)
|
||||
except AttributeError:
|
||||
self.moe_all_to_all_group_name = ""
|
||||
|
||||
def get_weight(self, num_experts: int,
|
||||
intermediate_size_per_partition: int, hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
if self.new_quant_version:
|
||||
w13_output_size = intermediate_size_per_partition
|
||||
w2_output_size = hidden_sizes // 2
|
||||
else:
|
||||
w13_output_size = 2 * intermediate_size_per_partition
|
||||
w2_output_size = hidden_sizes
|
||||
|
||||
param_dict["w13_weight"] = torch.empty(num_experts,
|
||||
w13_output_size,
|
||||
hidden_sizes,
|
||||
dtype=torch.int8)
|
||||
param_dict["w2_weight"] = torch.empty(num_experts,
|
||||
w2_output_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8)
|
||||
return param_dict
|
||||
|
||||
def get_dynamic_quant_param(self, num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
|
||||
if not self.is_per_channel_weight:
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
if self.new_quant_version:
|
||||
param_dict["w13_scale_bias"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_scale_bias"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
16 // self.tp_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk currently is 8
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group, # fix: 4
|
||||
group_count=num_expert_group, # fix 8
|
||||
group_select_mode=
|
||||
1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# out_flag=False, # todo new api; should the third output be output
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
else:
|
||||
topk_weights, topk_ids = torchair_select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
shared_gate_up, shared_dequant_scale = None, None
|
||||
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(quantized_x_for_share, router_logits)
|
||||
share_up_out, _ = shared_experts.gate_up_proj(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||
0], share_up_out[1]
|
||||
|
||||
# This is a naive implementation of expert load balancing, so as to avoid
# accumulating too many tokens on a single rank.
# Currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return torchair_fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
is_torchair=self.torchair_graph_enabled,
|
||||
quantized_x_for_share=shared_gate_up,
|
||||
dynamic_scale_for_share=shared_dequant_scale,
|
||||
mc2_mask=kwargs.get("mc2_mask", None),
|
||||
dynamic_eplb=self.dynamic_eplb)
|
||||
else:
|
||||
# The current implementation of deepseek moe splits hidden_states
# according to tp_size before it is fed into the layers module.
# Therefore, all2all is needed no matter how dp/tp is set, so as to
# dispatch/combine tokens.
|
||||
return torchair_fused_experts_with_all2all(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
)
|
||||
|
||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||
scale = scale.transpose(1, 2).contiguous()
|
||||
if self.is_per_channel_weight:
|
||||
scale_np = scale.cpu().numpy()
|
||||
scale_np.dtype = np.uint32
|
||||
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
|
||||
np.int64)).npu()
|
||||
return scale_uint64_tensor, None
|
||||
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
|
||||
group_num, k, n = weight.shape
|
||||
# In the new weight version, n is halved by the int4 packing, so it needs to be restored here.
|
||||
if self.new_quant_version:
|
||||
n = n * 2
|
||||
per_group_scale = per_group_scale.reshape(group_num, -1, n)
|
||||
group_num, quantgroup_num, n = per_group_scale.shape
|
||||
bias = None
|
||||
if not self.new_quant_version:
|
||||
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
|
||||
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
|
||||
weight_high = weight_high.reshape([group_num, k, n])
|
||||
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
|
||||
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
|
||||
torch.float32)
|
||||
scale_fp32_np = scale_fp32.cpu().numpy()
|
||||
scale_fp32_np.dtype = np.uint32
|
||||
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
|
||||
dtype=np.uint32)
|
||||
|
||||
sscale_uint64[..., ::2] = scale_fp32_np
|
||||
|
||||
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
|
||||
dtype=np.int64).copy()
|
||||
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
|
||||
group_num, quantgroup_num, n)
|
||||
sscale_uint64_tensor = sscale_uint64_tensor.npu()
|
||||
return sscale_uint64_tensor, bias
|
||||
|
||||
def update_bias(self, layer, w13_bias, w2_bias):
|
||||
if self.new_quant_version:
|
||||
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
|
||||
1, 2).contiguous().sum(axis=1)
|
||||
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
|
||||
1, 2).contiguous().sum(axis=1)
|
||||
else:
|
||||
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
|
||||
layer.register_parameter("w13_scale_bias", w13_scale_bias)
|
||||
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
|
||||
layer.register_parameter("w2_scale_bias", w2_scale_bias)
|
||||
|
||||
def pack_to_int32(self, weight: torch.Tensor):
|
||||
if self.new_quant_version:
|
||||
# Pack 4 int8 values (each holding two int4) into one int32, because in
# PyTorch we need to use int32 to represent int4.
assert weight.shape[-1] % 4 == 0, "the last dim of weight needs to be divisible by 4"
|
||||
return weight.view(torch.int32).contiguous()
|
||||
else:
|
||||
return torch_npu.npu_quantize(weight.to(torch.float32),
|
||||
torch.tensor([1.]).npu(), None,
|
||||
torch.quint4x2, -1, False)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
|
||||
layer, "w13_weight_scale_second") else None
|
||||
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
|
||||
layer, "w2_weight_scale_second") else None
|
||||
layer.w13_weight_scale.data, w13_bias = self.process_scale(
|
||||
layer.w13_weight, layer.w13_weight_scale.data,
|
||||
w13_weight_scale_second)
|
||||
layer.w2_weight_scale.data, w2_bias = self.process_scale(
|
||||
layer.w2_weight, layer.w2_weight_scale.data,
|
||||
w2_weight_scale_second)
|
||||
if hasattr(layer, "w13_weight_scale_second"):
|
||||
# scale_second is no longer used, release this part of the memory
|
||||
del layer.w13_weight_scale_second
|
||||
del layer.w2_weight_scale_second
|
||||
del layer.w13_weight_offset_second
|
||||
del layer.w2_weight_offset_second
|
||||
|
||||
self.update_bias(layer, w13_bias, w2_bias)
|
||||
|
||||
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
File diff suppressed because it is too large
@@ -1,457 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
||||
AscendAttentionMetadataBuilder,
|
||||
AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
aligned_16, get_ascend_device_type, nd_to_nz_2d)
|
||||
|
||||
|
||||
class AscendAttentionTorchairBackend(AscendAttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ASCEND_TORCHAIR"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
|
||||
return AscendAttentionTorchairBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
|
||||
return AscendAttentionTorchairMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_bsh_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendDecodeMetadata:
|
||||
# Input positions for rotary embeddings, since for MLA the rotary
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
max_seq_lens: int
|
||||
seq_lens_list: list[int]
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendTorchairMetadata(AscendMetadata):
|
||||
|
||||
decode: Optional[AscendDecodeMetadata] = None
|
||||
|
||||
|
||||
class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
self.max_num_blocks_per_req = cdiv(
|
||||
self.model_config.max_model_len,
|
||||
self.vllm_config.cache_config.block_size)
|
||||
self.max_blocks = (self.model_config.max_model_len +
|
||||
self.vllm_config.cache_config.block_size -
|
||||
1) // self.vllm_config.cache_config.block_size
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
||||
max_blocks = self.max_blocks
|
||||
|
||||
graph_block_tables = torch.zeros((num_seqs, max_blocks),
|
||||
dtype=block_tables.dtype,
|
||||
device=block_tables.device)
|
||||
|
||||
num_blocks = block_tables.size(1)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[:num_seqs, :
|
||||
num_blocks] = block_tables[:num_seqs, :
|
||||
num_blocks]
|
||||
else:
|
||||
graph_block_tables[:num_seqs, :
|
||||
max_blocks] = block_tables[:num_seqs, :
|
||||
max_blocks]
|
||||
|
||||
return graph_block_tables[:, :max_blocks]
|
||||
|
||||
def build_torchair_graph_dummy(
|
||||
self, common_attn_metadata: TorchairCommonAttentionMetadata
|
||||
) -> AscendTorchairMetadata:
|
||||
device = self.device
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
block_table = torch.zeros((num_reqs, self.max_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_reqs, block_table)
|
||||
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
|
||||
input_positions = torch.zeros(num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=device).long()
|
||||
slot_mapping = torch.full((num_reqs, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
query_start_loc = torch.full((num_reqs, ),
|
||||
-1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
decode_metadata = AscendDecodeMetadata(input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=1)
|
||||
|
||||
attn_metadata = AscendTorchairMetadata(
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_lens=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
decode=decode_metadata)
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
block_table[:num_reqs, :self.max_num_blocks_per_req] = (
|
||||
block_table[:num_reqs])
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._310P and attn_state == AscendAttentionState.PrefillNoCache:
|
||||
mask_nz = nd_to_nz_2d(attn_mask)
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
+ 1]
|
||||
query_start_loc = query_start_loc_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
graph_pad_size = common_attn_metadata.graph_pad_size
|
||||
use_torchair_graph = graph_pad_size > -1
|
||||
if common_attn_metadata.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
]:
|
||||
max_seq_lens = seq_lens.max().item()
|
||||
num_seqs = len(seq_lens)
|
||||
if use_torchair_graph and common_attn_metadata.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
]:
|
||||
num_reqs_pad_size = 0
|
||||
num_token_pad_size = 0
|
||||
if graph_pad_size != 0:
|
||||
pad_value = 0
|
||||
num_token_pad_size = graph_pad_size - num_actual_tokens
|
||||
num_reqs_pad_size = (
|
||||
graph_pad_size //
|
||||
common_attn_metadata.decode_token_per_req - num_reqs)
|
||||
pad_value = 1
|
||||
padded_seq_lens = seq_lens.tolist() + [pad_value
|
||||
] * num_reqs_pad_size
|
||||
|
||||
seq_lens = torch.from_numpy(
|
||||
np.array(padded_seq_lens).astype(np.int32))
|
||||
padding = torch.full((num_token_pad_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=slot_mapping.dtype,
|
||||
device=slot_mapping.device)
|
||||
slot_mapping = torch.cat([slot_mapping, padding])
|
||||
block_table_padding = torch.zeros(
|
||||
(num_reqs_pad_size, ) + block_table.shape[1:],
|
||||
dtype=block_table.dtype,
|
||||
device=block_table.device)
|
||||
block_table = torch.cat([block_table, block_table_padding],
|
||||
dim=0)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_seqs + num_reqs_pad_size, block_table)
|
||||
padding_0 = torch.zeros(num_token_pad_size,
|
||||
dtype=input_positions.dtype,
|
||||
device=input_positions.device)
|
||||
input_positions = torch.cat([input_positions, padding_0])
|
||||
|
||||
decode_metadata = AscendDecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=None)
|
||||
|
||||
attn_metadata = AscendTorchairMetadata(
|
||||
decode=decode_metadata,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state)
|
||||
return attn_metadata
|
||||


class AscendAttentionTorchairBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        **kwargs,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.key_cache = None
        self.value_cache = None
        self.scale_tensor = torch.zeros((), device='npu', dtype=torch.int32)

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AscendTorchairMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.
        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size,
                               num_kv_heads, head_size]
                key_cache = [num_blocks, block_size,
                             num_kv_heads, head_size]
                value_cache = [num_blocks, block_size,
                               num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size * seq_len, num_heads, head_size]
        """
        num_tokens = query.shape[0]
        use_kv_cache_quant = (kv_cache is not None and len(kv_cache) > 0
                              and kv_cache[0].numel() > 0
                              and kv_cache[0].dtype == torch.int8)
        if output is None:
            output = torch.empty(num_tokens,
                                 self.num_heads,
                                 self.head_size,
                                 dtype=query.dtype,
                                 device=query.device)

        if hasattr(layer, 'quant_method') and use_kv_cache_quant:
            output = layer.quant_method.apply(layer, query, key, value,
                                              kv_cache, attn_metadata,
                                              self.attn_type, self.scale,
                                              output)
            return output.view(num_tokens, self.hidden_size)

        if attn_metadata is None:
            return output.view(num_tokens, self.hidden_size).fill_(0)

        output = output.view(-1, self.num_heads, self.head_size)

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        attn_type = self.attn_type
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "AscendAttentionTorchairBackendImpl")

        if kv_cache is not None and kv_cache[0].numel() > 0:
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping

            block_size = self.scale_tensor + key_cache.shape[1]
            slots_indices = slots.reshape(-1, 1)
            block_indices = slots_indices // block_size
            slots_indices = slots_indices % block_size
            indices = torch.cat((block_indices, slots_indices), dim=1)
            torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
            torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
            if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
                self.key_cache = key_cache
                self.value_cache = value_cache

        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            assert attn_metadata is not None
            assert attn_metadata.attn_mask is not None
            mask = attn_metadata.attn_mask

            # View q k v to BSH.
            query = query.view(-1, self.num_heads, self.head_size)
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)

            if get_ascend_device_type() == AscendDeviceType._310P:
                # align q k v output tensors
                query = aligned_16(query)
                key = aligned_16(key)
                value = aligned_16(value)
                output = aligned_16(output)

                # do reformat in case of broadcasted tensors
                mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
                mask = torch_npu.npu_format_cast(mask.contiguous(),
                                                 ACL_FORMAT_FRACTAL_NZ)

            torch_npu._npu_flash_attention(query=query,
                                           key=key,
                                           value=value,
                                           mask=mask,
                                           seq_len=attn_metadata.seq_lens,
                                           scale_value=self.scale,
                                           num_heads=self.num_heads,
                                           num_kv_heads=self.num_kv_heads,
                                           out=output)
            output = output[:num_tokens, :, :]
        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
            assert attn_metadata is not None
            assert attn_metadata.attn_mask is not None
            compress_mask = attn_metadata.attn_mask
            batch_size = attn_metadata.query_lens.shape[0]
            block_table = attn_metadata.block_tables[:batch_size, :]
            torch_npu._npu_flash_attention_qlens(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                block_table=block_table,
                mask=compress_mask,
                seq_len=attn_metadata.query_lens,
                context_lens=attn_metadata.seq_lens,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                out=output)
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            decode_meta = attn_metadata.decode
            assert decode_meta is not None
            seq_lens = decode_meta.seq_lens_list
            block_table = decode_meta.block_table
            block_size = key_cache.shape[1]
            query = query.view(num_tokens, 1,
                               self.num_heads * self.head_size).contiguous()
            output, _ = torch_npu.npu_fused_infer_attention_score(
                query=query,
                key=key_cache,
                value=value_cache,
                query_rope=None,
                key_rope=None,
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout='BSH',
                atten_mask=decode_meta.attn_mask,
                sparse_mode=0,
                scale=self.scale,
                antiquant_mode=0,
                antiquant_scale=None,
                block_table=block_table,
                block_size=block_size,
                actual_seq_lengths_kv=seq_lens,
            )
        else:
            raise NotImplementedError(
                "Torchair graph mode with non-MLA attention backend is still experimental."
                "v1 scheduler(chunked prefill) is not supported at this moment."
            )

        return output.view(num_tokens, self.hidden_size)
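The reshape-and-cache step in the forward pass above turns flat slot ids into (block, offset) pairs before calling npu_scatter_nd_update_. A small illustrative sketch of that index math (the tensor values below are made up):

import torch

block_size = 128                             # assumed block size
slots = torch.tensor([0, 1, 130, 257])       # hypothetical slot ids
slots_indices = slots.reshape(-1, 1)
block_indices = slots_indices // block_size  # [[0], [0], [1], [2]]
offsets = slots_indices % block_size         # [[0], [1], [2], [1]]
indices = torch.cat((block_indices, offsets), dim=1)
# Each row of `indices` addresses key_cache[num_blocks, block_size, ...] in the scatter update.
print(indices.tolist())  # [[0, 0], [0, 1], [1, 2], [2, 1]]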
File diff suppressed because it is too large
@@ -1,574 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||
# isort: skip_file
|
||||
|
||||
import math
|
||||
import types
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.spec_decode import get_spec_decode_method
|
||||
from vllm_ascend.torchair.utils import (
|
||||
TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata,
|
||||
check_torchair_cache_exist, converting_weight_acl_format,
|
||||
register_torchair_model, torchair_ops_patch,
|
||||
torchair_quant_method_register, write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
AscendDeviceType, get_ascend_device_type)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
self.ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
|
||||
super().__init__(vllm_config, device)
|
||||
if self.speculative_config:
|
||||
self.actual_seq_lengths_q = list(
|
||||
range(self.decode_token_per_req, self.max_num_tokens + 1,
|
||||
self.decode_token_per_req))
|
||||
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
||||
None, None, vllm_config, device)
|
||||
self.use_sparse = hasattr(self.model_config.hf_config, "index_topk")
|
||||
|
||||
register_torchair_model()
|
||||
torchair_ops_patch()
|
||||
torchair_quant_method_register()
|
||||
if self.enable_shared_expert_dp:
|
||||
return
|
||||
self.new_kv_cache_bytes = -1
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph
|
||||
self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
|
||||
self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes
|
||||
if self.ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
|
||||
self.update_torchair_graph_batch_sizes()
|
||||
|
||||
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||
self.torchair_graph_batch_sizes)
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
torch._logging.set_logs(
|
||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||
|
||||
self._check_batch_sizes_consistency()
|
||||
|
||||
def _set_up_drafter(self):
|
||||
super()._set_up_drafter()
|
||||
if self.speculative_config:
|
||||
# Torchair do not support disable_padded_drafter_batch
|
||||
# Enforce to disable this feature
|
||||
self.speculative_config.disable_padded_drafter_batch = True
|
||||
|
||||
def _get_drafter(self):
|
||||
return get_spec_decode_method(self.speculative_config.method,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
self,
|
||||
is_torchair_graph=True)
|
||||
|
||||
def _may_pad_kv_consumer_num_seq(self):
|
||||
# pd disaggregation scenario need redundant_batch_sizes to avoid each batch's seq_len exceed 16 tokens
|
||||
# self.max_num_reqs here is greater than the actual maximum request number
|
||||
if self.decode_token_per_req > 1 and self.is_kv_consumer:
|
||||
# applied only when speculative decoding is active
|
||||
FIA_SEQ_LEN_LIMIT = 16
|
||||
new_max_num_reqs = self.max_num_reqs + math.ceil(
|
||||
self.max_num_reqs / FIA_SEQ_LEN_LIMIT) + math.ceil(
|
||||
(self.max_num_reqs * self.decode_token_per_req) /
|
||||
(FIA_SEQ_LEN_LIMIT**2))
|
||||
if self.max_num_reqs < new_max_num_reqs:
|
||||
logger.warning(
|
||||
f"max_num_reqs is updated to {new_max_num_reqs}")
|
||||
self.max_num_reqs = new_max_num_reqs
|
||||
|
||||
def _init_mc2_tokens_capacity(self):
|
||||
# NOTE: To be clear, we need to make sure that during graph capture, the number of
|
||||
# tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
|
||||
# the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
|
||||
max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
# Use integer arithmetic for ceiling division.
|
||||
max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
|
||||
max_num_tokens, tp_size)
|
||||
self.mc2_tokens_capacity = max_graph_batch_size
|
||||
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910_93 and self.mc2_tokens_capacity > 512:
|
||||
logger.error(
|
||||
f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
|
||||
)
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910B and self.mc2_tokens_capacity > 256:
|
||||
logger.error(
|
||||
f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
|
||||
)
|
||||
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int,
|
||||
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
|
||||
"""Override from NPUModelRunner to pad num_tokens"""
|
||||
if self.enable_shared_expert_dp:
|
||||
# Padding is not required for shared_expert_dp cases in eager mode.
|
||||
return num_tokens, None, with_prefill
|
||||
if self.dp_size == 1:
|
||||
if not with_prefill:
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
num_tokens)
|
||||
return maybe_padded_num_tokens, None, with_prefill
|
||||
return num_tokens, None, with_prefill
|
||||
|
||||
num_tokens_across_dp = torch.zeros(self.dp_size + 1,
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
num_tokens_across_dp[self.dp_rank] = num_tokens
|
||||
num_tokens_across_dp[-1] = int(with_prefill)
|
||||
dist.all_reduce(num_tokens_across_dp,
|
||||
group=get_dp_group().device_group)
|
||||
with_prefill = bool(num_tokens_across_dp[-1])
|
||||
num_tokens_across_dp = num_tokens_across_dp[:-1]
|
||||
|
||||
if not with_prefill:
|
||||
max_num_token = num_tokens_across_dp.max().item()
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
max_num_token)
|
||||
num_tokens_across_dp = torch.full((self.dp_size, ),
|
||||
maybe_padded_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
else:
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill
|
||||
|
||||
def _build_dummy_attn_metadata(
|
||||
self,
|
||||
with_prefill: bool,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
max_query_len: int,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
||||
force_attention: bool = False,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
# NOTE: If torchair graph mode and not with_prefill,
|
||||
# we can't skip_attn, it will cause graph recompile.
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
attn_metadata = super()._build_dummy_attn_metadata(
|
||||
with_prefill, num_reqs, num_tokens, max_query_len,
|
||||
num_scheduled_tokens, aclgraph_runtime_mode, force_attention)
|
||||
else:
|
||||
common_attn_metadata = TorchairCommonAttentionMetadata(
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=1,
|
||||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
||||
attn_mask=self.attn_mask,
|
||||
spec_attn_mask=self.spec_attn_mask,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
)
|
||||
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
|
||||
common_attn_metadata)
|
||||
return attn_metadata
|
||||
|
||||
def _generate_dummy_run_hidden_states(self, with_prefill,
|
||||
is_torchair_compile, input_ids,
|
||||
positions, attn_metadata, num_tokens,
|
||||
intermediate_tensors, inputs_embeds):
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
hidden_states = super()._generate_dummy_run_hidden_states(
|
||||
with_prefill, is_torchair_compile, input_ids, positions,
|
||||
attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
|
||||
else:
|
||||
# Only mark static while compiling
|
||||
if is_torchair_compile:
|
||||
torch._dynamo.mark_static(input_ids)
|
||||
torch._dynamo.mark_static(positions)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
|
||||
torch._dynamo.mark_static(get_forward_context().mc2_mask)
|
||||
if hasattr(attn_metadata.decode, "sin"):
|
||||
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
if self.speculative_config:
|
||||
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
|
||||
for kv in self.kv_caches:
|
||||
assert isinstance(kv, tuple), "kv_cache must be a tuple"
|
||||
torch._dynamo.mark_static(kv[0])
|
||||
torch._dynamo.mark_static(kv[1])
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
|
||||
model_kwargs = {}
|
||||
model_kwargs["kv_caches"] = self.kv_caches
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=None,
|
||||
**model_kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def _convert_torch_format(self, kv_cache):
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._convert_torch_format(kv_cache)
|
||||
kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
|
||||
return kv_cache
|
||||
|
||||
def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
|
||||
# Trigger torchair graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
||||
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
|
||||
|
||||
def _capture_model(self):
|
||||
"""Override from NPUModelRunner to use torchair graph capture."""
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._capture_model()
|
||||
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
|
||||
# torchair graph capture can cause some issues, so now we just
|
||||
# temporarily split the codepath for the two different graph patterns.
|
||||
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
|
||||
graph_num = len(torchair_graph_batch_sizes)
|
||||
|
||||
if self.use_cached_npu_graph and not check_torchair_cache_exist():
|
||||
# If caching is enabled but does not exist (either
|
||||
# use_cached_kv_cache_bytes is disabled or kv_cache_bytes are
|
||||
# different), we will compile the model twice. The first time is
|
||||
# used to generate the cache, and the second time is used to load the
|
||||
# cache to skip the overhead caused by Dynamo guard mechanism.
|
||||
logger.info(
|
||||
"Cache compilation for torchair graph is enabled. Now we compile graph to genetate"
|
||||
" torchair cache, this usually takes %.1f~%.1f mins.",
|
||||
0.5 * graph_num, 1.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
NPUPlatform.synchronize()
|
||||
# Note: We reset dynamo and reload the compiled torchair cached computation graph below
|
||||
# that was compiled above. This operation reduces graph launch time by 2-4ms and avoids
|
||||
# runtime errors caused by configuration mismatches in graph mode.
|
||||
torch._dynamo.reset()
|
||||
self.torchair_compiled_models.clear()
|
||||
if self.use_cached_npu_graph:
|
||||
logger.info(
|
||||
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
|
||||
0.3 * graph_num, 0.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
else:
|
||||
logger.info(
|
||||
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
|
||||
0.5 * graph_num, 1.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
|
||||
if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0:
|
||||
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
|
||||
self.new_kv_cache_bytes)
|
||||
|
||||
def _use_aclgraph(self) -> bool:
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._use_aclgraph()
|
||||
return False
|
||||
|
||||
def _check_batch_sizes_consistency(self) -> None:
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
|
||||
local = torch.tensor(self.torchair_graph_batch_sizes,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
gathered_graph_batch_size = local.clone()
|
||||
dist.all_reduce(gathered_graph_batch_size,
|
||||
group=get_dp_group().cpu_group)
|
||||
expected = local * self.dp_size
|
||||
|
||||
if not torch.equal(gathered_graph_batch_size, expected):
|
||||
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
|
||||
as_tuple=False).flatten().tolist()
|
||||
raise AssertionError(
|
||||
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
|
||||
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
|
||||
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
|
||||
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
|
||||
)
|
||||
|
||||
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
super()._update_graph_pad_size(with_prefill, graph_pad_size)
|
||||
else:
|
||||
self.graph_pad_size = graph_pad_size
|
||||
|
||||
def _update_input_ids_and_positions(self, input_ids, positions,
|
||||
num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp):
|
||||
"""Override from NPUModelRunner to update input_ids and positions"""
|
||||
input_ids, positions = super()._update_input_ids_and_positions(
|
||||
input_ids, positions, num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp)
|
||||
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
return input_ids, positions
|
||||
else:
|
||||
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
||||
positions = self.positions[:padded_num_tokens_across_dp]
|
||||
return input_ids, positions
|
||||
|
||||
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||
padded_num_tokens_across_dp,
|
||||
input_ids, positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds):
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._generate_process_reqs_hidden_states(
|
||||
attn_metadata, with_prefill, padded_num_tokens_across_dp,
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds)
|
||||
model_kwargs = {
|
||||
"kv_caches": self.kv_caches,
|
||||
"attn_metadata": attn_metadata
|
||||
}
|
||||
if not with_prefill:
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_num_tokens_across_dp)
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
assert self.model is not None
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
|
||||
raise ValueError(
|
||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
|
||||
)
|
||||
|
||||
compiled_model = self.torchair_compiled_models.get(
|
||||
batch_size
|
||||
) if self.use_cached_npu_graph else self.torchair_compiled_model
|
||||
|
||||
if compiled_model:
|
||||
return compiled_model
|
||||
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
patch_for_hcom()
|
||||
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
||||
from vllm_ascend.patch.platform.patch_distributed import \
|
||||
communication_adaptation_310p
|
||||
communication_adaptation_310p()
|
||||
|
||||
config = torchair.CompilerConfig()
|
||||
if self.ascend_config.torchair_graph_config.mode:
|
||||
config.mode = self.ascend_config.torchair_graph_config.mode
|
||||
config.experimental_config.frozen_parameter = \
|
||||
self.ascend_config.torchair_graph_config.enable_frozen_parameter
|
||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||
# disable it on 300I Duo platform now.
|
||||
config.experimental_config.tiling_schedule_optimize = get_ascend_device_type(
|
||||
) != AscendDeviceType._310P
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
self.ascend_config.torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
if not self.use_cached_npu_graph:
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
self.torchair_compiled_model = torch.compile(
|
||||
self.model,
|
||||
dynamic=not self.use_sparse,
|
||||
fullgraph=True,
|
||||
backend=npu_backend)
|
||||
return self.torchair_compiled_model
|
||||
else:
|
||||
# Generate a new forward proxy code object to prevent the invalidation of
|
||||
# compilation cache caused by dynamo retracing
|
||||
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
||||
forward_fn = self.model.forward
|
||||
code = forward_fn.__code__
|
||||
# Mark code object with a new proxy name
|
||||
modified_code = code.replace(co_name=forward_proxy_name, )
|
||||
|
||||
modified_func = types.FunctionType(modified_code,
|
||||
forward_fn.__globals__,
|
||||
name=forward_proxy_name,
|
||||
argdefs=forward_fn.__defaults__)
|
||||
|
||||
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
||||
self.model, nn.Module)
|
||||
self.torchair_compiled_models[
|
||||
batch_size] = torchair.inference.cache_compile(
|
||||
self.model.__dict__[forward_proxy_name],
|
||||
dynamic=not self.use_sparse,
|
||||
fullgraph=True,
|
||||
cache_dir=TORCHAIR_CACHE_DIR,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
return self.torchair_compiled_models[batch_size]
|
||||
|
||||
def init_torchair_graph_batch_sizes(self):
|
||||
start_graph_batch_size = 4
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
|
||||
start_graph_batch_size = max(start_graph_batch_size, tp_size)
|
||||
|
||||
while (start_graph_batch_size <= self.max_num_reqs):
|
||||
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
|
||||
start_graph_batch_size *= 2
|
||||
|
||||
def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
|
||||
tp_size):
|
||||
cur_graph_batch_size = (old_graph_batch_size + tp_size -
|
||||
1) // tp_size * tp_size
|
||||
# MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size,
|
||||
# Both adapter multi-dp and FIA operator
|
||||
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
|
||||
cur_graph_batch_size = (tp_size * old_graph_batch_size) \
|
||||
// math.gcd(tp_size, old_graph_batch_size)
|
||||
return cur_graph_batch_size
|
||||
|
||||
def select_torchair_padded_batch_size(self, batch_size: int):
|
||||
for padded_batch_size in self.torchair_graph_batch_sizes:
|
||||
if batch_size <= padded_batch_size:
|
||||
# we treat batch_size as num of requests
|
||||
return padded_batch_size
|
||||
raise ValueError(
|
||||
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
|
||||
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
|
||||
)
|
||||
|
||||
def update_torchair_graph_batch_sizes(self):
|
||||
# return graph_batch_sizes according to the max number of tokens
|
||||
# first pad according to the number of requests
|
||||
if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'mtp':
|
||||
# pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs
|
||||
self.torchair_graph_batch_sizes = [self.max_num_reqs]
|
||||
logger.warning(
|
||||
f"is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs] {[self.max_num_reqs]}"
|
||||
)
|
||||
elif len(self.torchair_graph_batch_sizes) == 0:
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
else:
|
||||
self.torchair_graph_batch_sizes = sorted(
|
||||
self.torchair_graph_batch_sizes)
|
||||
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.pop()
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
logger.warning(
|
||||
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
|
||||
)
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
|
||||
|
||||
# padded max number tokens = max_num_req * decode_token_per_req
|
||||
self.torchair_graph_batch_sizes = [
|
||||
graph_batch_size * self.decode_token_per_req
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes
|
||||
]
|
||||
|
||||
# NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size`
|
||||
# Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same
|
||||
# on all EP ranks
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910_93 and self.parallel_config.enable_expert_parallel:
|
||||
self._align_graph_size_divisible_by_tp_size()
|
||||
|
||||
def _align_graph_size_divisible_by_tp_size(self):
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
new_graph_batch_sizes = []
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes:
|
||||
cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
|
||||
graph_batch_size, tp_size)
|
||||
if cur_graph_batch_size not in new_graph_batch_sizes and \
|
||||
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
|
||||
new_graph_batch_sizes.append(cur_graph_batch_size)
|
||||
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
|
||||
and self.decode_token_per_req > 1:
|
||||
logger.warning(
|
||||
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
|
||||
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
|
||||
)
|
||||
new_max_num_reqs = math.ceil(
|
||||
max(new_graph_batch_sizes) / self.decode_token_per_req)
|
||||
if self.max_num_reqs != new_max_num_reqs:
|
||||
logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}")
|
||||
self.max_num_reqs = new_max_num_reqs
|
||||
if not (self.decode_token_per_req > 1 and self.is_kv_consumer):
|
||||
# Do not update scheduler_config.max_num_seqs in KV consumer + MTP
|
||||
# Since FIA need extra space for padding
|
||||
# Enforce self.max_num_seqs > self.scheduler_config.max_num_seqs in KV consumer + MTP
|
||||
self.scheduler_config.max_num_seqs = new_max_num_reqs
|
||||
|
||||
if new_graph_batch_sizes != self.torchair_graph_batch_sizes:
|
||||
logger.warning(
|
||||
f"torchair_graph_batch_sizes are updated to {new_graph_batch_sizes}."
|
||||
)
|
||||
self.torchair_graph_batch_sizes = new_graph_batch_sizes
|
||||
|
||||
def _build_drafter_prepare_inputs_torchair_param(self):
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._build_drafter_prepare_inputs_torchair_param()
|
||||
else:
|
||||
return True
|
||||
@@ -1,543 +0,0 @@
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchair
|
||||
from torchair import patch_for_hcom
|
||||
from vllm.config import (CUDAGraphMode, VllmConfig,
|
||||
get_layers_from_vllm_config, set_current_vllm_config)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import \
|
||||
process_weights_after_loading
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.spec_decode import MtpProposer
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
|
||||
TorchairDeepSeekMTP
|
||||
from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
|
||||
TorchairCommonAttentionMetadata)
|
||||
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
|
||||
class TorchairMtpProposer(MtpProposer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device,
|
||||
runner,
|
||||
):
|
||||
super().__init__(vllm_config, device, runner)
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
|
||||
def load_model(self, model) -> None:
|
||||
loader = get_model_loader(self.vllm_config.load_config)
|
||||
|
||||
target_attn_layer_names = set(
|
||||
get_layers_from_vllm_config(self.vllm_config,
|
||||
AttentionLayerBase).keys())
|
||||
draft_model_config = \
|
||||
self.vllm_config.speculative_config.draft_model_config
|
||||
target_device = self.vllm_config.device_config.device
|
||||
|
||||
with set_default_torch_dtype(
|
||||
draft_model_config.dtype), set_current_vllm_config(
|
||||
self.vllm_config):
|
||||
|
||||
self.model = TorchairDeepSeekMTP(
|
||||
vllm_config=self.vllm_config).to(target_device)
|
||||
|
||||
draft_attn_layer_names = (get_layers_from_vllm_config(
|
||||
self.vllm_config, AttentionLayerBase).keys() -
|
||||
target_attn_layer_names)
|
||||
|
||||
assert len(draft_attn_layer_names) == 1
|
||||
self.attn_layer_name = list(draft_attn_layer_names)
|
||||
|
||||
self.model.load_weights(
|
||||
loader.get_all_weights(
|
||||
self.vllm_config.speculative_config.draft_model_config,
|
||||
self.model))
|
||||
process_weights_after_loading(self.model, draft_model_config,
|
||||
target_device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(self,
|
||||
num_tokens: int,
|
||||
with_prefill: bool = False,
|
||||
skip_attn: bool = False,
|
||||
num_reqs: int = 0,
|
||||
num_tokens_across_dp=None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor=None,
|
||||
dummy_compute_logits=lambda hidden_states: None) -> None:
|
||||
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
|
||||
|
||||
if not with_prefill:
|
||||
skip_attn = False
|
||||
if skip_attn:
|
||||
attn_metadata = None
|
||||
else:
|
||||
common_attn_metadata = TorchairCommonAttentionMetadata(
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=1,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
)
|
||||
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
|
||||
common_attn_metadata)
|
||||
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
positions = self.positions[:num_tokens]
|
||||
previous_hidden_states = self.hidden_states[:num_tokens]
|
||||
for _ in range(self.num_speculative_tokens):
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
with_prefill=with_prefill,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
moe_comm_type=moe_comm_type,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=0):
|
||||
if not with_prefill:
|
||||
assert attn_metadata is not None
|
||||
torch._dynamo.mark_static(input_ids)
|
||||
torch._dynamo.mark_static(positions)
|
||||
torch._dynamo.mark_static(previous_hidden_states)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
||||
torch._dynamo.mark_static(
|
||||
attn_metadata.decode.input_positions)
|
||||
if hasattr(attn_metadata.decode, "sin"):
|
||||
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
||||
torch._dynamo.mark_static(get_forward_context().mc2_mask)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
|
||||
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
num_tokens)
|
||||
torchair_compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
hidden_states=previous_hidden_states,
|
||||
inputs_embeds=None,
|
||||
intermediate_tensors=None,
|
||||
attn_metadata=attn_metadata,
|
||||
kv_caches=self.runner.kv_caches[-1:],
|
||||
spec_step_idx=0)
|
||||
else:
|
||||
self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
hidden_states=previous_hidden_states)
|
||||
dummy_compute_logits(previous_hidden_states)
|
||||
if with_prefill:
|
||||
break
|
||||
|
||||
def generate_token_ids(self,
|
||||
valid_sampled_token_ids: list[list[int]],
|
||||
sampling_metadata: SamplingMetadata = None,
|
||||
scheduler_output: SchedulerOutput = None,
|
||||
spec_decode_metadata: SpecDecodeMetadata = None,
|
||||
positions: torch.Tensor = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
hidden_states: torch.Tensor = None,
|
||||
attn_metadata=None,
|
||||
aux_hidden_states: torch.Tensor = None):
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
|
||||
next_token_ids: list[int] = []
|
||||
for i, token_ids in enumerate(valid_sampled_token_ids):
|
||||
if token_ids:
|
||||
# Common case.
|
||||
next_token_id = token_ids[-1]
|
||||
else:
|
||||
# Partial prefill (rare case).
|
||||
# Get the next token id from the request state.
|
||||
req_id = self.runner.input_batch.req_ids[i]
|
||||
req_state = self.runner.requests[req_id]
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
next_token_id = req_state.get_token_id(seq_len)
|
||||
next_token_ids.append(next_token_id)
|
||||
next_token_ids = torch.tensor(next_token_ids,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
accepted_token_indices = None
|
||||
if spec_decode_metadata is None:
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
target_slot_mapping = attn_metadata.slot_mapping
|
||||
cu_num_tokens = attn_metadata.query_start_loc
|
||||
else:
|
||||
# TODO(woosuk): Refactor this.
|
||||
num_draft_tokens = spec_decode_metadata.num_draft_tokens
|
||||
num_rejected_tokens = [
|
||||
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
|
||||
for i, n in enumerate(num_draft_tokens)
|
||||
]
|
||||
num_rejected_tokens = torch.tensor(
|
||||
num_rejected_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
cu_num_tokens, accepted_token_indices, target_token_ids, \
|
||||
target_positions, target_hidden_states, target_slot_mapping = self._torchair_prepare_inputs(
|
||||
attn_metadata.query_start_loc,
|
||||
num_rejected_tokens,
|
||||
self.runner.input_ids[:num_scheduled_tokens],
|
||||
positions[:num_scheduled_tokens],
|
||||
hidden_states[:num_scheduled_tokens],
|
||||
attn_metadata.slot_mapping[:num_scheduled_tokens],
|
||||
)
|
||||
|
||||
draft_token_ids = self._propose_torchair(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
target_slot_mapping=target_slot_mapping,
|
||||
next_token_ids=next_token_ids,
|
||||
cu_num_tokens=cu_num_tokens,
|
||||
block_table=attn_metadata.block_tables,
|
||||
sampling_metadata=sampling_metadata,
|
||||
token_indices=accepted_token_indices)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
return spec_token_ids
|
||||
|
||||
def _torchair_prepare_inputs(
|
||||
self,
|
||||
# [batch_size + 1]
|
||||
cu_target_query_lens: torch.Tensor,
|
||||
# [batch_size]
|
||||
num_rejected_tokens: torch.Tensor,
|
||||
token_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
torch.Tensor, torch.Tensor]:
|
||||
# cu_target_query_lens: [0, a, a + b, a + b + c]
|
||||
# num_rejected_tokens: [n1, n2, n3]
|
||||
# num_tokens_per_req: [a - n1, b - n2, c - n3]
|
||||
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
|
||||
# token_indices: [0, 1, ..., a - n1 - 1,
|
||||
# a, a + 1, ..., a + b - n2 - 1,
|
||||
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
|
||||
# [0, a, a + b, a + b + c] -> [a, b, c]
|
||||
query_len_per_req = (cu_target_query_lens[1:] -
|
||||
cu_target_query_lens[:-1])
|
||||
# [a, b, c] -> [a - n1, b - n2, c - n3]
|
||||
|
||||
cu_num_tokens = cu_target_query_lens
|
||||
relative_index = query_len_per_req - num_rejected_tokens - 1
|
||||
token_indices = cu_num_tokens[:-1] + relative_index
|
||||
# the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model
|
||||
target_token_ids = token_ids
|
||||
target_positions = positions
|
||||
target_hidden_states = hidden_states
|
||||
target_slot_mapping = slot_mapping
|
||||
|
||||
return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
|
||||
|
||||
def _propose_torchair(
|
||||
self,
|
||||
# [num_tokens]
|
||||
target_token_ids: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_positions: torch.Tensor,
|
||||
# [num_tokens, hidden_size]
|
||||
target_hidden_states: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_slot_mapping: torch.Tensor,
|
||||
# [batch_size]
|
||||
next_token_ids: torch.Tensor,
|
||||
# [batch_size + 1] starting with 0
|
||||
cu_num_tokens: torch.Tensor,
|
||||
# [batch_size, max_num_blocks_per_req]
|
||||
block_table: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
token_indices=None) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
last_token_indices = cu_num_tokens[1:] - 1
|
||||
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
if token_indices is not None:
|
||||
last_token_indices = token_indices
|
||||
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
|
||||
max_query_len = query_lens.max().item()
|
||||
|
||||
# FIXME: reorder_batch() needs to be called before build()
|
||||
# because fields of attn_metadata_builder needs to be updated.
|
||||
# However, currently reorder_batch() takes input_batch and
|
||||
# scheduler_output as arguments, we should probably refactor
|
||||
# the method to use new data structures which are independent
|
||||
# from input_batch and scheduler_output.
|
||||
# self.runner.attn_metadata_builder.reorder_batch(
|
||||
# input_batch=self.runner.input_batch,
|
||||
# scheduler_output=self.runner.scheduler_output,
|
||||
# )
|
||||
|
||||
if not self.runner.with_prefill:
|
||||
# Torchair graph mode, padding is same as the main model
|
||||
num_input_tokens = self.runner.graph_pad_size
|
||||
elif (self.runner.use_aclgraph
|
||||
and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
|
||||
# Acl graph mode, add padding to the batch size
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
else:
|
||||
# Eager mode, no padding needed
|
||||
num_input_tokens = num_tokens
|
||||
|
||||
seq_lens = target_positions[last_token_indices] + 1
|
||||
seq_lens = seq_lens.int()
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=cu_num_tokens[:batch_size + 1],
|
||||
query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
|
||||
seq_lens_cpu=seq_lens.cpu(),
|
||||
num_reqs=batch_size,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor(),
|
||||
slot_mapping=target_slot_mapping,
|
||||
positions=target_positions,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
graph_pad_size=self.runner.graph_pad_size,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
|
||||
attn_metadata = self.runner.attn_metadata_builder.build(
|
||||
0, common_attn_metadata, self.runner.get_model())
|
||||
|
||||
self.positions[:num_tokens] = target_positions
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
|
||||
# torchair mode can reuse self.runner.num_tokens_across_dp
|
||||
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
||||
with_prefill = self.runner.with_prefill
|
||||
moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens)
|
||||
|
||||
for step in range(self.num_speculative_tokens):
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
with_prefill=with_prefill,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
moe_comm_type=moe_comm_type,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=num_tokens):
|
||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||
model_kwargs = {}
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
|
||||
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
|
||||
if not self.runner.with_prefill:
|
||||
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
num_input_tokens)
|
||||
hidden_states = torchair_compiled_model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.
|
||||
hidden_states[:num_input_tokens],
|
||||
inputs_embeds=None,
|
||||
intermediate_tensors=None,
|
||||
spec_step_idx=0,
|
||||
**model_kwargs)
|
||||
else:
|
||||
hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens]
|
||||
)
|
||||
|
||||
num_indices = last_token_indices.shape[0]
|
||||
if lmhead_tp_enable():
|
||||
if not self.runner.with_prefill:
|
||||
max_num_reqs_across_dp = num_input_tokens
|
||||
else:
|
||||
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
|
||||
last_token_indices = nn.functional.pad(
|
||||
last_token_indices,
|
||||
(0, max_num_reqs_across_dp - num_indices))
|
||||
|
||||
sample_hidden_states = hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
if lmhead_tp_enable() and num_indices < logits.shape[0]:
|
||||
logits = logits[:num_indices]
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
if self.num_speculative_tokens == 1:
|
||||
# [batch_size, 1]
|
||||
return draft_token_ids.view(-1, 1)
|
||||
|
||||
if step == 0:
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
else:
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
# prepare next mtp inputs
|
||||
# mtp>1: prefill skip or decode skip last loop
|
||||
if with_prefill:
|
||||
for _ in range(self.num_speculative_tokens - 1):
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
if step == self.num_speculative_tokens - 1 or with_prefill:
|
||||
break
|
||||
|
||||
attn_metadata_i = attn_metadata
|
||||
|
||||
if step == 0:
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
slot_mapping = attn_metadata_i.slot_mapping[last_token_indices]
|
||||
attn_metadata_i.slot_mapping.fill_(-1)
|
||||
attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
|
||||
last_token_indices = self.arange[:batch_size]
|
||||
if attn_metadata_i.num_decode_tokens != 0:
|
||||
attn_metadata_i.num_decode_tokens = batch_size
|
||||
if not self.runner.with_prefill:
|
||||
attn_metadata_i.num_actual_tokens = batch_size
|
||||
attn_metadata_i.query_lens = [1] * batch_size
|
||||
|
||||
input_ids = draft_token_ids_list[-1].int()
|
||||
positions += 1
|
||||
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length. Since it is complex
|
||||
# to remove such requests from the batch, we keep them in the batch
|
||||
# but adjust the position ids and slot mappings to avoid the
|
||||
# out-of-range access during the model execution. The draft tokens
|
||||
# generated with this adjustment should be ignored.
|
||||
exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||
positions)
|
||||
# Increment the sequence lengths.
|
||||
attn_metadata_i.seq_lens[:batch_size] += 1
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
|
||||
attn_metadata_i.seq_lens.device, non_blocking=True)
|
||||
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
|
||||
exceeds_max_model_len_cpu, 1)
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
slot_mapping += 1
|
||||
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
self.hidden_states[:hidden_states.shape[0]] = hidden_states
|
||||
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
|
||||
|
||||
if attn_metadata_i.prefill is not None:
|
||||
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
|
||||
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
|
||||
)
|
||||
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
|
||||
attn_metadata_i.prefill.input_positions = self.positions[:
|
||||
num_input_tokens]
|
||||
attn_metadata_i.prefill.max_seq_lens += 1
|
||||
attn_metadata_i.prefill.max_seq_lens = min(
|
||||
attn_metadata_i.prefill.max_seq_lens,
|
||||
self.runner.model_config.max_model_len)
|
||||
if attn_metadata_i.decode is not None:
|
||||
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
|
||||
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
|
||||
)
|
||||
attn_metadata_i.decode.input_positions = self.positions[:
|
||||
num_input_tokens]
|
||||
attn_metadata_i.decode.max_seq_lens += 1
|
||||
attn_metadata_i.decode.max_seq_lens = min(
|
||||
attn_metadata_i.decode.max_seq_lens,
|
||||
self.runner.model_config.max_model_len)
|
||||
|
||||
# mtp>1: [batch_size, k]
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
|
||||
|
||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
|
||||
-1]:
|
||||
raise ValueError(
|
||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}"
|
||||
)
|
||||
|
||||
compiled_model = self.torchair_compiled_models.get(
|
||||
batch_size
|
||||
) if self.runner.use_cached_npu_graph else self.torchair_compiled_model
|
||||
|
||||
if compiled_model:
|
||||
return compiled_model
|
||||
|
||||
patch_for_hcom()
|
||||
config = torchair.CompilerConfig()
|
||||
config.experimental_config.frozen_parameter = True
|
||||
config.experimental_config.tiling_schedule_optimize = True
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
get_ascend_config().torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
if not self.runner.use_cached_npu_graph:
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
self.torchair_compiled_model = torch.compile(
|
||||
self.model,
|
||||
dynamic=not self.use_sparse,
|
||||
fullgraph=True,
|
||||
backend=npu_backend)
|
||||
return self.torchair_compiled_model
|
||||
else:
|
||||
# Generate a new forward proxy code object to prevent the invalidation of
|
||||
# compilation cache caused by dynamo retracing
|
||||
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
||||
forward_fn = self.model.forward
|
||||
code = forward_fn.__code__
|
||||
# Mark code object with a new proxy name
|
||||
modified_code = code.replace(co_name=forward_proxy_name, )
|
||||
|
||||
modified_func = types.FunctionType(modified_code,
|
||||
forward_fn.__globals__,
|
||||
name=forward_proxy_name,
|
||||
argdefs=forward_fn.__defaults__)
|
||||
|
||||
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
||||
self.model, nn.Module)
|
||||
self.torchair_compiled_models[
|
||||
batch_size] = torchair.inference.cache_compile(
|
||||
self.model.__dict__[forward_proxy_name],
|
||||
dynamic=not self.use_sparse,
|
||||
fullgraph=True,
|
||||
cache_dir=TORCHAIR_CACHE_DIR,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
return self.torchair_compiled_models[batch_size]
|
||||
File diff suppressed because it is too large
@@ -1,63 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from vllm.logger import logger

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist,
                                        delete_torchair_cache_file,
                                        read_kv_cache_bytes_from_file)
from vllm_ascend.worker.worker_v1 import NPUWorker


class NPUTorchairWorker(NPUWorker):
    """Torchair worker based on NPUWorker. Only torchair-specific code should be added in this class."""

    def determine_available_memory(self) -> int:
        """Override determine_available_memory to use cached torchair kv_cache_bytes."""

        available_kv_cache_memory = super().determine_available_memory()
        ascend_config = get_ascend_config()
        if ascend_config.enable_shared_expert_dp:
            return available_kv_cache_memory
        if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
            if check_kv_cache_bytes_cache_exist():
                old_kv_cache_bytes = read_kv_cache_bytes_from_file(
                    torch.distributed.get_rank())
                if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
                    logger.info(
                        f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
                    )
                    self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
                    return old_kv_cache_bytes
                else:
                    logger.info(
                        "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
                    )
                    delete_torchair_cache_file()
            bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
            available_kv_cache_memory -= bytes_floating_tolerance
            logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
        self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
        return available_kv_cache_memory

    def init_device(self):
        """Override init_device to init the torchair model runner."""
        device = self._init_device()
        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = NPUTorchairModelRunner(self.vllm_config, device)
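The memory override above only reuses a persisted byte budget when it still fits inside the freshly profiled one; otherwise it drops the stale cache and keeps a safety margin on the new measurement. A compact sketch of that decision rule (hypothetical helper; tolerance in MiB, mirroring the env var used above):

def pick_kv_cache_bytes(profiled: int, cached: int | None,
                        tolerance_mib: int = 64) -> int:
    # Reuse the cached byte count only if it is positive and still fits.
    if cached is not None and 0 < cached <= profiled:
        return cached
    # Otherwise fall back to the new measurement minus a floating tolerance.
    return profiled - tolerance_mib * 1024 * 1024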
@@ -1,275 +0,0 @@
import fcntl
import os
import shutil
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass

import torch
import torch_npu
from torchair.scope import super_kernel as _super_kernel

try:
    # Recent release of torchair has moved these ops to `.scope`.
    from torchair.scope import npu_stream_switch as _npu_stream_switch
    from torchair.scope import npu_wait_tensor as _npu_wait_tensor
except ImportError:
    from torchair.ops import NpuStreamSwitch as _npu_stream_switch
    from torchair.ops import npu_wait_tensor as _npu_wait_tensor

import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz

KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
TORCHAIR_CACHE_DIR = os.path.join(
    os.getenv('TORCHAIR_CACHE_HOME', os.getcwd()), TORCHAIR_CACHE_PATH_NAME)


@dataclass
class TorchairCommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions.
    """

    num_reqs: int
    """Number of requests"""

    num_actual_tokens: int
    """Total number of tokens in batch"""

    decode_token_per_req: int

    actual_seq_lengths_q: list[int]

    attn_mask: torch.Tensor = None

    spec_attn_mask: torch.Tensor = None

    graph_pad_size: int = -1
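With the constants above, every cache artifact resolves relative to TORCHAIR_CACHE_HOME (the working directory by default). A small illustrative sketch of where the per-rank kv_cache_bytes file ends up (rank 0 shown):

import os

cache_root = os.path.join(os.getenv("TORCHAIR_CACHE_HOME", os.getcwd()),
                          ".torchair_cache")
kv_bytes_file = os.path.join(cache_root, ".kv_cache_bytes", "0_kv_cache_bytes")
# e.g. ./.torchair_cache/.kv_cache_bytes/0_kv_cache_bytes
print(kv_bytes_file)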
@contextmanager
def _file_lock(file_descriptor, lock_type):
    fcntl.flock(file_descriptor, lock_type)
    try:
        yield
    finally:
        fcntl.flock(file_descriptor, fcntl.LOCK_UN)


def _get_torchair_current_work_dir(file_name=None):
    if file_name is None:
        return TORCHAIR_CACHE_DIR
    return os.path.join(TORCHAIR_CACHE_DIR, file_name)


def check_torchair_cache_exist():
    res = False
    torch_air_abs_path = _get_torchair_current_work_dir()
    if os.path.exists(torch_air_abs_path):
        file_list = os.listdir(torch_air_abs_path)
        if len(file_list) != 0:
            res = True
    return res


def check_kv_cache_bytes_cache_exist():
    res = False
    kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
        KV_CACHE_BYTES_CACHE_PATH_NAME)
    if os.path.exists(kv_cache_bytes_cache_abs_path):
        file_list = os.listdir(kv_cache_bytes_cache_abs_path)
        if len(file_list) != 0:
            res = True
    return res


def read_kv_cache_bytes_from_file(rank) -> int:
    kv_cache_bytes = -1
    kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
        KV_CACHE_BYTES_CACHE_PATH_NAME)
    kv_cache_bytes_file = os.path.join(
        kv_cache_bytes_cache_abs_path,
        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
    with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
        with _file_lock(f, fcntl.LOCK_SH):
            kv_cache_bytes = int(f.readline())
    return kv_cache_bytes


def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
    kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
        KV_CACHE_BYTES_CACHE_PATH_NAME)
    os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
    kv_cache_bytes_file = os.path.join(
        kv_cache_bytes_cache_abs_path,
        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
    with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
        with _file_lock(f, fcntl.LOCK_EX):
            f.write(f"{kv_cache_bytes}")


def delete_torchair_cache_file():
    torch_air_abs_path = _get_torchair_current_work_dir()
    try:
        shutil.rmtree(torch_air_abs_path)
    except FileNotFoundError:
        pass
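A minimal round trip with the helpers above (rank 0, illustrative 8 GiB value): the writer takes an exclusive lock and the reader a shared one, so concurrent ranks never observe a half-written file.

write_kv_cache_bytes_to_file(0, 8 * 1024 ** 3)        # persist the budget
if check_kv_cache_bytes_cache_exist():
    assert read_kv_cache_bytes_from_file(0) == 8 * 1024 ** 3
delete_torchair_cache_file()                          # wipes .torchair_cache entirely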
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
    return _npu_stream_switch(tag, priority) if enabled else nullcontext()


def npu_wait_tensor(self: torch.Tensor,
                    dependency: torch.Tensor,
                    *,
                    enabled: bool = True):
    return _npu_wait_tensor(self, dependency) if enabled else self
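Typical usage of the two wrappers above, as a sketch with made-up tensor names: run a matmul on a secondary NPU stream, but only after its input has been produced on the default stream. With enabled=False both calls degrade to no-ops, which keeps eager-mode code paths unchanged.

hidden = npu_wait_tensor(hidden, router_logits, enabled=use_multistream)
with npu_stream_switch("moe_secondary", 0, enabled=use_multistream):
    shared_out = torch.matmul(hidden, shared_expert_weight)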
def converting_weight_acl_format(model, format):
    # Currently, some operations do not support ACL_FORMAT_FRACTAL_NZ in eager
    # mode but do support it in torchair graph mode. Since ACL_FORMAT_FRACTAL_NZ
    # is strongly preferred over ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
    # conversion when using torchair graph mode on that platform.
    # TODO: remove this conversion once npu_quant_grouped_matmul_dequant accepts
    # weights in ACL_FORMAT_FRACTAL_NZ in eager mode.
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE

    for module in model.modules():
        if isinstance(module, FusedMoE):
            if torch_npu.get_npu_format(module.w13_weight.data) == format:
                return
            if format == ACL_FORMAT_FRACTAL_NZ \
                    and not is_enable_nz():
                return
            module.w13_weight.data = torch_npu.npu_format_cast(
                module.w13_weight.data, format)
            module.w2_weight.data = torch_npu.npu_format_cast(
                module.w2_weight.data, format)
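For example, a 300I Duo deployment would convert the MoE weights once before graph capture and could convert them back for an eager fallback. This is only a usage sketch under the assumption that `model` is an already-loaded MoE model and that ACL_FORMAT_FRACTAL_ND is imported from vllm_ascend.utils:

converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)   # before graph capture
converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_ND)   # back for eager kernels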
def register_torchair_model():
    from vllm import ModelRegistry

    ModelRegistry.register_model(
        "DeepSeekMTPModel",
        "vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
    )

    ModelRegistry.register_model(
        "DeepseekV2ForCausalLM",
        "vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
    )

    ModelRegistry.register_model(
        "DeepseekV3ForCausalLM",
        "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
    )

    ModelRegistry.register_model(
        "DeepseekV32ForCausalLM",
        "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
    )

    ModelRegistry.register_model(
        "Qwen2ForCausalLM",
        "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM")

    ModelRegistry.register_model(
        "Qwen3MoeForCausalLM",
        "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")

    ModelRegistry.register_model(
        "PanguProMoEForCausalLM",
        "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
    )
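ModelRegistry.register_model accepts the lazy "module:Class" string form, so the class is only imported when that architecture is actually requested. The same mechanism works for any out-of-tree model; the module path below is made up for illustration:

from vllm import ModelRegistry

ModelRegistry.register_model(
    "MyCustomForCausalLM", "my_plugin.models.custom:MyCustomForCausalLM")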
def torchair_quant_method_register():
    from vllm_ascend.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP
    from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
        TorchairAscendW4A8DynamicFusedMoEMethod,
        TorchairAscendW4A8DynamicLinearMethod)
    from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
        TorchairAscendW8A8DynamicFusedMoEMethod,
        TorchairAscendW8A8DynamicLinearMethod)

    ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][
        "linear"] = TorchairAscendW8A8DynamicLinearMethod
    ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][
        "moe"] = TorchairAscendW8A8DynamicFusedMoEMethod
    ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][
        "linear"] = TorchairAscendW4A8DynamicLinearMethod
    ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][
        "moe"] = TorchairAscendW4A8DynamicFusedMoEMethod
def torchair_ops_patch():
    from vllm_ascend.ops.activation import AscendSiluAndMul
    from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
    from vllm_ascend.ops.rotary_embedding import (
        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
    from vllm_ascend.ops.vocab_parallel_embedding import \
        AscendVocabParallelEmbedding
    from vllm_ascend.torchair.ops import (torchair_activation,
                                          torchair_layernorm)
    from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
        deepseek_rope_init_func, native_rope_deepseek_forward,
        qwen_rope_init_func, rope_forward)
    from vllm_ascend.torchair.ops.torchair_vocab_parallel_embedding import \
        vocab_embedding_forward

    AscendRotaryEmbedding.__init__ = qwen_rope_init_func  # type: ignore[method-assign]
    AscendRotaryEmbedding.forward_oot = rope_forward  # type: ignore[method-assign]

    AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func  # type: ignore[method-assign]
    AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward  # type: ignore[method-assign]

    AscendRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_  # type: ignore[method-assign]
    AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot  # type: ignore[method-assign]

    AscendQuantRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_  # type: ignore[method-assign]
    AscendQuantRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot  # type: ignore[method-assign]

    AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot  # type: ignore[method-assign]
    AscendVocabParallelEmbedding.forward = vocab_embedding_forward  # type: ignore[method-assign]


def super_kernel(prefix: str, option: str, enabled: bool = True):
    return _super_kernel(prefix, option) if enabled else nullcontext()
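torchair_ops_patch swaps implementations by plain class-attribute assignment, so every existing and future instance picks up the torchair kernels without subclassing. A tiny self-contained sketch of that monkey-patching pattern, using hypothetical classes rather than the real ops:

class AscendOp:
    def forward_oot(self, x):
        return x            # default eager implementation

def graph_mode_forward(self, x):
    return x + 1            # replacement installed for graph mode

AscendOp.forward_oot = graph_mode_forward   # patch applied at registration time
assert AscendOp().forward_oot(1) == 2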
# TODO(ttanzhiqiang): rm_router_logits
# Only triggered when dp > 1. In theory this optimization applies only to the
# AllGather and AllGatherEP paths: in the DP scenario the previous order was
# gate + two communications, which becomes one communication + gate, saving
# some communication time. All MoE AllGather / AllGatherEP paths could follow
# this logic, but the DP paths of other MoE models (e.g. Qwen3-235B) have not
# been adjusted yet, so a switch guards it to avoid breaking them.
def get_rm_router_logits_state(ep_size: int, dp_size: int,
                               is_deepseek_v3_r1: bool):
    # The fused operator torch_npu.npu_grouped_matmul_finalize_routing used by
    # AllGather-EP only supports DeepSeek V3/R1.
    if dp_size > 1:
        if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
                and is_deepseek_v3_r1):
            return True
        elif ep_size == 1 and is_deepseek_v3_r1:
            return True
    return False


# TODO(ttanzhiqiang): all_reduce merge
# With all_reduce_merge enabled, shared_experts skips its own all_reduce in the
# MLP and waits until shared_experts + router_experts are both finished before
# a single all_reduce is issued. It is currently enabled by default for the
# AllGather, AllGatherEP and NaiveMulticast paths of the DeepSeek models.
def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
    # The fused operator torch_npu.npu_grouped_matmul_finalize_routing used by
    # AllGather-EP only supports DeepSeek V3/R1.
    if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
            and is_deepseek_v3_r1):
        return True
    elif ep_size == 1 and is_deepseek_v3_r1:
        return True
    return False
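Concretely, the two switches above resolve as follows for a couple of illustrative layouts (these cases do not depend on the VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP setting):

# DeepSeek V3/R1, pure data parallel, no expert parallelism: both enabled.
assert get_rm_router_logits_state(ep_size=1, dp_size=2, is_deepseek_v3_r1=True)
assert get_all_reduce_merge_state(ep_size=1, is_deepseek_v3_r1=True)

# A non-DeepSeek model stays disabled regardless of the parallel layout.
assert not get_rm_router_logits_state(ep_size=8, dp_size=4,
                                      is_deepseek_v3_r1=False)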
@@ -143,7 +143,6 @@ from vllm_ascend.spec_decode import get_spec_decode_method
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
 from vllm_ascend.spec_decode.interface import SpecDcodeType
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
-from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendDeviceType, ProfileExecuteDuration,
                                enable_sp, get_ascend_device_type, is_enable_nz,
@@ -638,7 +637,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         # Set up speculative decoding.
         self.spec_attn_mask = None
         self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
-                                     TorchairMtpProposer,
                                      SuffixDecodingProposer]] = None
         self.actual_seq_lengths_q: list[int] = []
         self.decode_token_per_req = 1
@@ -2917,8 +2915,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
 
         return attn_metadata
 
-    def _generate_dummy_run_hidden_states(self, with_prefill,
-                                          is_torchair_compile, input_ids,
+    def _generate_dummy_run_hidden_states(self, with_prefill, input_ids,
                                           positions, attn_metadata, num_tokens,
                                           intermediate_tensors, inputs_embeds):
         hidden_states = self.model(input_ids=input_ids,
@@ -2960,7 +2957,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         self,
         num_tokens: int,
         with_prefill: bool = False,
-        is_torchair_compile: bool = False,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
@@ -3136,9 +3132,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                 model_instance=self.model,
                 weight_prefetch_method=self.weight_prefetch_method):
             hidden_states = self._generate_dummy_run_hidden_states(
-                with_prefill, is_torchair_compile, input_ids, positions,
-                attn_metadata, num_tokens_padded, intermediate_tensors,
-                inputs_embeds)
+                with_prefill, input_ids, positions, attn_metadata,
+                num_tokens_padded, intermediate_tensors, inputs_embeds)
             dummy_compute_logits(hidden_states)
 
             if self.drafter:
@@ -4262,9 +4257,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
 
         return list(model.pooler.get_supported_tasks())
 
-    def _build_drafter_prepare_inputs_torchair_param(self):
-        return False
-
     def _update_tokens_for_pcp(self, tokens):
         num_reqs = self.input_batch.num_reqs
         self.num_pcp_pads = self.num_pcp_pads[:num_reqs]